mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-01-25 00:34:19 +03:00
commit
bb3a564500
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
from .telegram_bare_client import TelegramBareClient
|
||||
from .telegram_client import TelegramClient
|
||||
from .client.telegramclient import TelegramClient
|
||||
from .network import connection
|
||||
from .tl import types, functions
|
||||
from . import version, events, utils
|
||||
|
|
22
telethon/client/__init__.py
Normal file
22
telethon/client/__init__.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
"""
|
||||
This package defines clients as subclasses of others, and then a single
|
||||
`telethon.client.telegramclient.TelegramClient` which is subclass of them
|
||||
all to provide the final unified interface while the methods can live in
|
||||
different subclasses to be more maintainable.
|
||||
|
||||
The ABC is `telethon.client.telegrambaseclient.TelegramBaseClient` and the
|
||||
first implementor is `telethon.client.users.UserMethods`, since calling
|
||||
requests require them to be resolved first, and that requires accessing
|
||||
entities (users).
|
||||
"""
|
||||
from .telegrambaseclient import TelegramBaseClient
|
||||
from .users import UserMethods # Required for everything
|
||||
from .messageparse import MessageParseMethods # Required for messages
|
||||
from .uploads import UploadMethods # Required for messages to send files
|
||||
from .messages import MessageMethods
|
||||
from .chats import ChatMethods
|
||||
from .dialogs import DialogMethods
|
||||
from .downloads import DownloadMethods
|
||||
from .auth import AuthMethods
|
||||
from .updates import UpdateMethods
|
||||
from .telegramclient import TelegramClient
|
428
telethon/client/auth.py
Normal file
428
telethon/client/auth.py
Normal file
|
@ -0,0 +1,428 @@
|
|||
import getpass
|
||||
import hashlib
|
||||
import sys
|
||||
|
||||
import os
|
||||
|
||||
from .messageparse import MessageParseMethods
|
||||
from .users import UserMethods
|
||||
from .. import utils, helpers, errors
|
||||
from ..tl import types, functions
|
||||
|
||||
|
||||
class AuthMethods(MessageParseMethods, UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
async def start(
|
||||
self,
|
||||
phone=lambda: input('Please enter your phone: '),
|
||||
password=lambda: getpass.getpass('Please enter your password: '),
|
||||
bot_token=None, force_sms=False, code_callback=None,
|
||||
first_name='New User', last_name=''):
|
||||
"""
|
||||
Convenience method to interactively connect and sign in if required,
|
||||
also taking into consideration that 2FA may be enabled in the account.
|
||||
|
||||
If the phone doesn't belong to an existing account (and will hence
|
||||
`sign_up` for a new one), **you are agreeing to Telegram's
|
||||
Terms of Service. This is required and your account
|
||||
will be banned otherwise.** See https://telegram.org/tos
|
||||
and https://core.telegram.org/api/terms.
|
||||
|
||||
Example usage:
|
||||
>>> client = ...
|
||||
>>> client.start(phone)
|
||||
Please enter the code you received: 12345
|
||||
Please enter your password: *******
|
||||
(You are now logged in)
|
||||
|
||||
Args:
|
||||
phone (`str` | `int` | `callable`):
|
||||
The phone (or callable without arguments to get it)
|
||||
to which the code will be sent.
|
||||
|
||||
password (`callable`, optional):
|
||||
The password for 2 Factor Authentication (2FA).
|
||||
This is only required if it is enabled in your account.
|
||||
|
||||
bot_token (`str`):
|
||||
Bot Token obtained by `@BotFather <https://t.me/BotFather>`_
|
||||
to log in as a bot. Cannot be specified with ``phone`` (only
|
||||
one of either allowed).
|
||||
|
||||
force_sms (`bool`, optional):
|
||||
Whether to force sending the code request as SMS.
|
||||
This only makes sense when signing in with a `phone`.
|
||||
|
||||
code_callback (`callable`, optional):
|
||||
A callable that will be used to retrieve the Telegram
|
||||
login code. Defaults to `input()`.
|
||||
|
||||
first_name (`str`, optional):
|
||||
The first name to be used if signing up. This has no
|
||||
effect if the account already exists and you sign in.
|
||||
|
||||
last_name (`str`, optional):
|
||||
Similar to the first name, but for the last. Optional.
|
||||
|
||||
Returns:
|
||||
This `TelegramClient`, so initialization
|
||||
can be chained with ``.start()``.
|
||||
"""
|
||||
|
||||
if code_callback is None:
|
||||
def code_callback():
|
||||
return input('Please enter the code you received: ')
|
||||
elif not callable(code_callback):
|
||||
raise ValueError(
|
||||
'The code_callback parameter needs to be a callable '
|
||||
'function that returns the code you received by Telegram.'
|
||||
)
|
||||
|
||||
if not phone and not bot_token:
|
||||
raise ValueError('No phone number or bot token provided.')
|
||||
|
||||
if phone and bot_token and not callable(phone):
|
||||
raise ValueError('Both a phone and a bot token provided, '
|
||||
'must only provide one of either')
|
||||
|
||||
if not self.is_connected():
|
||||
await self.connect()
|
||||
|
||||
if await self.is_user_authorized():
|
||||
return self
|
||||
|
||||
if bot_token:
|
||||
await self.sign_in(bot_token=bot_token)
|
||||
return self
|
||||
|
||||
# Turn the callable into a valid phone number
|
||||
while callable(phone):
|
||||
phone = utils.parse_phone(phone()) or phone
|
||||
|
||||
me = None
|
||||
attempts = 0
|
||||
max_attempts = 3
|
||||
two_step_detected = False
|
||||
|
||||
sent_code = await self.send_code_request(phone, force_sms=force_sms)
|
||||
sign_up = not sent_code.phone_registered
|
||||
while attempts < max_attempts:
|
||||
try:
|
||||
if sign_up:
|
||||
me = await self.sign_up(
|
||||
code_callback(), first_name, last_name)
|
||||
else:
|
||||
# Raises SessionPasswordNeededError if 2FA enabled
|
||||
me = await self.sign_in(phone, code_callback())
|
||||
break
|
||||
except errors.SessionPasswordNeededError:
|
||||
two_step_detected = True
|
||||
break
|
||||
except errors.PhoneNumberOccupiedError:
|
||||
sign_up = False
|
||||
except errors.PhoneNumberUnoccupiedError:
|
||||
sign_up = True
|
||||
except (errors.PhoneCodeEmptyError,
|
||||
errors.PhoneCodeExpiredError,
|
||||
errors.PhoneCodeHashEmptyError,
|
||||
errors.PhoneCodeInvalidError):
|
||||
print('Invalid code. Please try again.', file=sys.stderr)
|
||||
|
||||
attempts += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'{} consecutive sign-in attempts failed. Aborting'
|
||||
.format(max_attempts)
|
||||
)
|
||||
|
||||
if two_step_detected:
|
||||
if not password:
|
||||
raise ValueError(
|
||||
"Two-step verification is enabled for this account. "
|
||||
"Please provide the 'password' argument to 'start()'."
|
||||
)
|
||||
# TODO If callable given make it retry on invalid
|
||||
if callable(password):
|
||||
password = password()
|
||||
me = await self.sign_in(phone=phone, password=password)
|
||||
|
||||
# We won't reach here if any step failed (exit by exception)
|
||||
signed, name = 'Signed in successfully as', utils.get_display_name(me)
|
||||
try:
|
||||
print(signed, name)
|
||||
except UnicodeEncodeError:
|
||||
# Some terminals don't support certain characters
|
||||
print(signed, name.encode('utf-8', errors='ignore')
|
||||
.decode('ascii', errors='ignore'))
|
||||
|
||||
return self
|
||||
|
||||
async def is_user_authorized(self):
|
||||
return await self.get_me() is not None
|
||||
|
||||
async def sign_in(
|
||||
self, phone=None, code=None, password=None,
|
||||
bot_token=None, phone_code_hash=None):
|
||||
"""
|
||||
Starts or completes the sign in process with the given phone number
|
||||
or code that Telegram sent.
|
||||
|
||||
Args:
|
||||
phone (`str` | `int`):
|
||||
The phone to send the code to if no code was provided,
|
||||
or to override the phone that was previously used with
|
||||
these requests.
|
||||
|
||||
code (`str` | `int`):
|
||||
The code that Telegram sent. Note that if you have sent this
|
||||
code through the application itself it will immediately
|
||||
expire. If you want to send the code, obfuscate it somehow.
|
||||
If you're not doing any of this you can ignore this note.
|
||||
|
||||
password (`str`):
|
||||
2FA password, should be used if a previous call raised
|
||||
SessionPasswordNeededError.
|
||||
|
||||
bot_token (`str`):
|
||||
Used to sign in as a bot. Not all requests will be available.
|
||||
This should be the hash the @BotFather gave you.
|
||||
|
||||
phone_code_hash (`str`):
|
||||
The hash returned by .send_code_request. This can be set to None
|
||||
to use the last hash known.
|
||||
|
||||
Returns:
|
||||
The signed in user, or the information about
|
||||
:meth:`send_code_request`.
|
||||
"""
|
||||
me = await self.get_me()
|
||||
if me:
|
||||
return me
|
||||
|
||||
if phone and not code and not password:
|
||||
return await self.send_code_request(phone)
|
||||
elif code:
|
||||
phone = utils.parse_phone(phone) or self._phone
|
||||
phone_code_hash = \
|
||||
phone_code_hash or self._phone_code_hash.get(phone, None)
|
||||
|
||||
if not phone:
|
||||
raise ValueError(
|
||||
'Please make sure to call send_code_request first.'
|
||||
)
|
||||
if not phone_code_hash:
|
||||
raise ValueError('You also need to provide a phone_code_hash.')
|
||||
|
||||
# May raise PhoneCodeEmptyError, PhoneCodeExpiredError,
|
||||
# PhoneCodeHashEmptyError or PhoneCodeInvalidError.
|
||||
result = await self(functions.auth.SignInRequest(
|
||||
phone, phone_code_hash, str(code)))
|
||||
elif password:
|
||||
salt = (await self(
|
||||
functions.account.GetPasswordRequest())).current_salt
|
||||
result = await self(functions.auth.CheckPasswordRequest(
|
||||
helpers.get_password_hash(password, salt)
|
||||
))
|
||||
elif bot_token:
|
||||
result = await self(functions.auth.ImportBotAuthorizationRequest(
|
||||
flags=0, bot_auth_token=bot_token,
|
||||
api_id=self.api_id, api_hash=self.api_hash
|
||||
))
|
||||
else:
|
||||
raise ValueError(
|
||||
'You must provide a phone and a code the first time, '
|
||||
'and a password only if an RPCError was raised before.'
|
||||
)
|
||||
|
||||
self._self_input_peer = utils.get_input_peer(
|
||||
result.user, allow_self=False
|
||||
)
|
||||
return result.user
|
||||
|
||||
async def sign_up(self, code, first_name, last_name=''):
|
||||
"""
|
||||
Signs up to Telegram if you don't have an account yet.
|
||||
You must call .send_code_request(phone) first.
|
||||
|
||||
**By using this method you're agreeing to Telegram's
|
||||
Terms of Service. This is required and your account
|
||||
will be banned otherwise.** See https://telegram.org/tos
|
||||
and https://core.telegram.org/api/terms.
|
||||
|
||||
Args:
|
||||
code (`str` | `int`):
|
||||
The code sent by Telegram
|
||||
|
||||
first_name (`str`):
|
||||
The first name to be used by the new account.
|
||||
|
||||
last_name (`str`, optional)
|
||||
Optional last name.
|
||||
|
||||
Returns:
|
||||
The new created :tl:`User`.
|
||||
"""
|
||||
me = await self.get_me()
|
||||
if me:
|
||||
return me
|
||||
|
||||
if self._tos and self._tos.text:
|
||||
if self.parse_mode:
|
||||
t = self.parse_mode.unparse(self._tos.text, self._tos.entities)
|
||||
else:
|
||||
t = self._tos.text
|
||||
sys.stderr.write("{}\n".format(t))
|
||||
sys.stderr.flush()
|
||||
|
||||
result = await self(functions.auth.SignUpRequest(
|
||||
phone_number=self._phone,
|
||||
phone_code_hash=self._phone_code_hash.get(self._phone, ''),
|
||||
phone_code=str(code),
|
||||
first_name=first_name,
|
||||
last_name=last_name
|
||||
))
|
||||
|
||||
if self._tos:
|
||||
await self(
|
||||
functions.help.AcceptTermsOfServiceRequest(self._tos.id))
|
||||
|
||||
self._self_input_peer = utils.get_input_peer(
|
||||
result.user, allow_self=False
|
||||
)
|
||||
return result.user
|
||||
|
||||
async def send_code_request(self, phone, force_sms=False):
|
||||
"""
|
||||
Sends a code request to the specified phone number.
|
||||
|
||||
Args:
|
||||
phone (`str` | `int`):
|
||||
The phone to which the code will be sent.
|
||||
|
||||
force_sms (`bool`, optional):
|
||||
Whether to force sending as SMS.
|
||||
|
||||
Returns:
|
||||
An instance of :tl:`SentCode`.
|
||||
"""
|
||||
phone = utils.parse_phone(phone) or self._phone
|
||||
phone_hash = self._phone_code_hash.get(phone)
|
||||
|
||||
if not phone_hash:
|
||||
try:
|
||||
result = await self(functions.auth.SendCodeRequest(
|
||||
phone, self.api_id, self.api_hash))
|
||||
except errors.AuthRestartError:
|
||||
return self.send_code_request(phone, force_sms=force_sms)
|
||||
|
||||
self._tos = result.terms_of_service
|
||||
self._phone_code_hash[phone] = phone_hash = result.phone_code_hash
|
||||
else:
|
||||
force_sms = True
|
||||
|
||||
self._phone = phone
|
||||
|
||||
if force_sms:
|
||||
result = await self(
|
||||
functions.auth.ResendCodeRequest(phone, phone_hash))
|
||||
|
||||
self._phone_code_hash[phone] = result.phone_code_hash
|
||||
|
||||
return result
|
||||
|
||||
async def log_out(self):
|
||||
"""
|
||||
Logs out Telegram and deletes the current ``*.session`` file.
|
||||
|
||||
Returns:
|
||||
``True`` if the operation was successful.
|
||||
"""
|
||||
try:
|
||||
await self(functions.auth.LogOutRequest())
|
||||
except errors.RPCError:
|
||||
return False
|
||||
|
||||
await self.disconnect()
|
||||
self.session.delete()
|
||||
self._authorized = False
|
||||
return True
|
||||
|
||||
async def edit_2fa(
|
||||
self, current_password=None, new_password=None, hint='',
|
||||
email=None):
|
||||
"""
|
||||
Changes the 2FA settings of the logged in user, according to the
|
||||
passed parameters. Take note of the parameter explanations.
|
||||
|
||||
Has no effect if both current and new password are omitted.
|
||||
|
||||
current_password (`str`, optional):
|
||||
The current password, to authorize changing to ``new_password``.
|
||||
Must be set if changing existing 2FA settings.
|
||||
Must **not** be set if 2FA is currently disabled.
|
||||
Passing this by itself will remove 2FA (if correct).
|
||||
|
||||
new_password (`str`, optional):
|
||||
The password to set as 2FA.
|
||||
If 2FA was already enabled, ``current_password`` **must** be set.
|
||||
Leaving this blank or ``None`` will remove the password.
|
||||
|
||||
hint (`str`, optional):
|
||||
Hint to be displayed by Telegram when it asks for 2FA.
|
||||
Leaving unspecified is highly discouraged.
|
||||
Has no effect if ``new_password`` is not set.
|
||||
|
||||
email (`str`, optional):
|
||||
Recovery and verification email. Raises ``EmailUnconfirmedError``
|
||||
if value differs from current one, and has no effect if
|
||||
``new_password`` is not set.
|
||||
|
||||
Returns:
|
||||
``True`` if successful, ``False`` otherwise.
|
||||
"""
|
||||
if new_password is None and current_password is None:
|
||||
return False
|
||||
|
||||
pass_result = await self(functions.account.GetPasswordRequest())
|
||||
if isinstance(
|
||||
pass_result, types.account.NoPassword) and current_password:
|
||||
current_password = None
|
||||
|
||||
salt_random = os.urandom(8)
|
||||
salt = pass_result.new_salt + salt_random
|
||||
if not current_password:
|
||||
current_password_hash = salt
|
||||
else:
|
||||
current_password = (
|
||||
pass_result.current_salt
|
||||
+ current_password.encode()
|
||||
+ pass_result.current_salt
|
||||
)
|
||||
current_password_hash = hashlib.sha256(current_password).digest()
|
||||
|
||||
if new_password: # Setting new password
|
||||
new_password = salt + new_password.encode('utf-8') + salt
|
||||
new_password_hash = hashlib.sha256(new_password).digest()
|
||||
new_settings = types.account.PasswordInputSettings(
|
||||
new_salt=salt,
|
||||
new_password_hash=new_password_hash,
|
||||
hint=hint
|
||||
)
|
||||
if email: # If enabling 2FA or changing email
|
||||
new_settings.email = email # TG counts empty string as None
|
||||
return await self(functions.account.UpdatePasswordSettingsRequest(
|
||||
current_password_hash, new_settings=new_settings
|
||||
))
|
||||
else: # Removing existing password
|
||||
return await self(functions.account.UpdatePasswordSettingsRequest(
|
||||
current_password_hash,
|
||||
new_settings=types.account.PasswordInputSettings(
|
||||
new_salt=bytes(),
|
||||
new_password_hash=bytes(),
|
||||
hint=hint
|
||||
)
|
||||
))
|
||||
|
||||
# endregion
|
188
telethon/client/chats.py
Normal file
188
telethon/client/chats.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
from collections import UserList
|
||||
|
||||
from async_generator import async_generator, yield_
|
||||
|
||||
from .users import UserMethods
|
||||
from .. import utils
|
||||
from ..tl import types, functions
|
||||
|
||||
|
||||
class ChatMethods(UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
@async_generator
|
||||
async def iter_participants(
|
||||
self, entity, limit=None, search='',
|
||||
filter=None, aggressive=False, _total=None):
|
||||
"""
|
||||
Iterator over the participants belonging to the specified chat.
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
The entity from which to retrieve the participants list.
|
||||
|
||||
limit (`int`):
|
||||
Limits amount of participants fetched.
|
||||
|
||||
search (`str`, optional):
|
||||
Look for participants with this string in name/username.
|
||||
|
||||
filter (:tl:`ChannelParticipantsFilter`, optional):
|
||||
The filter to be used, if you want e.g. only admins
|
||||
Note that you might not have permissions for some filter.
|
||||
This has no effect for normal chats or users.
|
||||
|
||||
aggressive (`bool`, optional):
|
||||
Aggressively looks for all participants in the chat in
|
||||
order to get more than 10,000 members (a hard limit
|
||||
imposed by Telegram). Note that this might take a long
|
||||
time (over 5 minutes), but is able to return over 90,000
|
||||
participants on groups with 100,000 members.
|
||||
|
||||
This has no effect for groups or channels with less than
|
||||
10,000 members, or if a ``filter`` is given.
|
||||
|
||||
_total (`list`, optional):
|
||||
A single-item list to pass the total parameter by reference.
|
||||
|
||||
Yields:
|
||||
The :tl:`User` objects returned by :tl:`GetParticipantsRequest`
|
||||
with an additional ``.participant`` attribute which is the
|
||||
matched :tl:`ChannelParticipant` type for channels/megagroups
|
||||
or :tl:`ChatParticipants` for normal chats.
|
||||
"""
|
||||
if isinstance(filter, type):
|
||||
if filter in (types.ChannelParticipantsBanned,
|
||||
types.ChannelParticipantsKicked,
|
||||
types.ChannelParticipantsSearch):
|
||||
# These require a `q` parameter (support types for convenience)
|
||||
filter = filter('')
|
||||
else:
|
||||
filter = filter()
|
||||
|
||||
entity = await self.get_input_entity(entity)
|
||||
if search and (filter
|
||||
or not isinstance(entity, types.InputPeerChannel)):
|
||||
# We need to 'search' ourselves unless we have a PeerChannel
|
||||
search = search.lower()
|
||||
|
||||
def filter_entity(ent):
|
||||
return search in utils.get_display_name(ent).lower() or\
|
||||
search in (getattr(ent, 'username', '') or None).lower()
|
||||
else:
|
||||
def filter_entity(ent):
|
||||
return True
|
||||
|
||||
limit = float('inf') if limit is None else int(limit)
|
||||
if isinstance(entity, types.InputPeerChannel):
|
||||
if _total or (aggressive and not filter):
|
||||
total = (await self(functions.channels.GetFullChannelRequest(
|
||||
entity
|
||||
))).full_chat.participants_count
|
||||
if _total:
|
||||
_total[0] = total
|
||||
else:
|
||||
total = 0
|
||||
|
||||
if limit == 0:
|
||||
return
|
||||
|
||||
seen = set()
|
||||
if total > 10000 and aggressive and not filter:
|
||||
requests = [functions.channels.GetParticipantsRequest(
|
||||
channel=entity,
|
||||
filter=types.ChannelParticipantsSearch(search + chr(x)),
|
||||
offset=0,
|
||||
limit=200,
|
||||
hash=0
|
||||
) for x in range(ord('a'), ord('z') + 1)]
|
||||
else:
|
||||
requests = [functions.channels.GetParticipantsRequest(
|
||||
channel=entity,
|
||||
filter=filter or types.ChannelParticipantsSearch(search),
|
||||
offset=0,
|
||||
limit=200,
|
||||
hash=0
|
||||
)]
|
||||
|
||||
while requests:
|
||||
# Only care about the limit for the first request
|
||||
# (small amount of people, won't be aggressive).
|
||||
#
|
||||
# Most people won't care about getting exactly 12,345
|
||||
# members so it doesn't really matter not to be 100%
|
||||
# precise with being out of the offset/limit here.
|
||||
requests[0].limit = min(limit - requests[0].offset, 200)
|
||||
if requests[0].offset > limit:
|
||||
break
|
||||
|
||||
results = await self(requests)
|
||||
for i in reversed(range(len(requests))):
|
||||
participants = results[i]
|
||||
if not participants.users:
|
||||
requests.pop(i)
|
||||
else:
|
||||
requests[i].offset += len(participants.participants)
|
||||
users = {user.id: user for user in participants.users}
|
||||
for participant in participants.participants:
|
||||
user = users[participant.user_id]
|
||||
if not filter_entity(user) or user.id in seen:
|
||||
continue
|
||||
|
||||
seen.add(participant.user_id)
|
||||
user = users[participant.user_id]
|
||||
user.participant = participant
|
||||
await yield_(user)
|
||||
if len(seen) >= limit:
|
||||
return
|
||||
|
||||
elif isinstance(entity, types.InputPeerChat):
|
||||
# TODO We *could* apply the `filter` here ourselves
|
||||
full = await self(
|
||||
functions.messages.GetFullChatRequest(entity.chat_id))
|
||||
if not isinstance(
|
||||
full.full_chat.participants, types.ChatParticipants):
|
||||
# ChatParticipantsForbidden won't have ``.participants``
|
||||
_total[0] = 0
|
||||
return
|
||||
|
||||
if _total:
|
||||
_total[0] = len(full.full_chat.participants.participants)
|
||||
|
||||
have = 0
|
||||
users = {user.id: user for user in full.users}
|
||||
for participant in full.full_chat.participants.participants:
|
||||
user = users[participant.user_id]
|
||||
if not filter_entity(user):
|
||||
continue
|
||||
have += 1
|
||||
if have > limit:
|
||||
break
|
||||
else:
|
||||
user = users[participant.user_id]
|
||||
user.participant = participant
|
||||
await yield_(user)
|
||||
else:
|
||||
if _total:
|
||||
_total[0] = 1
|
||||
if limit != 0:
|
||||
user = await self.get_entity(entity)
|
||||
if filter_entity(user):
|
||||
user.participant = None
|
||||
await yield_(user)
|
||||
|
||||
async def get_participants(self, *args, **kwargs):
|
||||
"""
|
||||
Same as :meth:`iter_participants`, but returns a list instead
|
||||
with an additional ``.total`` attribute on the list.
|
||||
"""
|
||||
total = [0]
|
||||
kwargs['_total'] = total
|
||||
participants = UserList()
|
||||
async for x in self.iter_participants(*args, **kwargs):
|
||||
participants.append(x)
|
||||
participants.total = total[0]
|
||||
return participants
|
||||
|
||||
# endregion
|
136
telethon/client/dialogs.py
Normal file
136
telethon/client/dialogs.py
Normal file
|
@ -0,0 +1,136 @@
|
|||
import itertools
|
||||
from collections import UserList
|
||||
|
||||
from async_generator import async_generator, yield_
|
||||
|
||||
from .users import UserMethods
|
||||
from .. import utils
|
||||
from ..tl import types, functions, custom
|
||||
|
||||
|
||||
class DialogMethods(UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
@async_generator
|
||||
async def iter_dialogs(
|
||||
self, limit=None, offset_date=None, offset_id=0,
|
||||
offset_peer=types.InputPeerEmpty(), _total=None):
|
||||
"""
|
||||
Returns an iterator over the dialogs, yielding 'limit' at most.
|
||||
Dialogs are the open "chats" or conversations with other people,
|
||||
groups you have joined, or channels you are subscribed to.
|
||||
|
||||
Args:
|
||||
limit (`int` | `None`):
|
||||
How many dialogs to be retrieved as maximum. Can be set to
|
||||
``None`` to retrieve all dialogs. Note that this may take
|
||||
whole minutes if you have hundreds of dialogs, as Telegram
|
||||
will tell the library to slow down through a
|
||||
``FloodWaitError``.
|
||||
|
||||
offset_date (`datetime`, optional):
|
||||
The offset date to be used.
|
||||
|
||||
offset_id (`int`, optional):
|
||||
The message ID to be used as an offset.
|
||||
|
||||
offset_peer (:tl:`InputPeer`, optional):
|
||||
The peer to be used as an offset.
|
||||
|
||||
_total (`list`, optional):
|
||||
A single-item list to pass the total parameter by reference.
|
||||
|
||||
Yields:
|
||||
Instances of `telethon.tl.custom.dialog.Dialog`.
|
||||
"""
|
||||
limit = float('inf') if limit is None else int(limit)
|
||||
if limit == 0:
|
||||
if not _total:
|
||||
return
|
||||
# Special case, get a single dialog and determine count
|
||||
dialogs = await self(functions.messages.GetDialogsRequest(
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
offset_peer=offset_peer,
|
||||
limit=1
|
||||
))
|
||||
_total[0] = getattr(dialogs, 'count', len(dialogs.dialogs))
|
||||
return
|
||||
|
||||
seen = set()
|
||||
req = functions.messages.GetDialogsRequest(
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
offset_peer=offset_peer,
|
||||
limit=0
|
||||
)
|
||||
while len(seen) < limit:
|
||||
req.limit = min(limit - len(seen), 100)
|
||||
r = await self(req)
|
||||
|
||||
if _total:
|
||||
_total[0] = getattr(r, 'count', len(r.dialogs))
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
messages = {m.id: custom.Message(self, m, entities, None)
|
||||
for m in r.messages}
|
||||
|
||||
# Happens when there are pinned dialogs
|
||||
if len(r.dialogs) > limit:
|
||||
r.dialogs = r.dialogs[:limit]
|
||||
|
||||
for d in r.dialogs:
|
||||
peer_id = utils.get_peer_id(d.peer)
|
||||
if peer_id not in seen:
|
||||
seen.add(peer_id)
|
||||
await yield_(custom.Dialog(self, d, entities, messages))
|
||||
|
||||
if len(r.dialogs) < req.limit\
|
||||
or not isinstance(r, types.messages.DialogsSlice):
|
||||
# Less than we requested means we reached the end, or
|
||||
# we didn't get a DialogsSlice which means we got all.
|
||||
break
|
||||
|
||||
req.offset_date = r.messages[-1].date
|
||||
req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)]
|
||||
req.offset_id = r.messages[-1].id
|
||||
req.exclude_pinned = True
|
||||
|
||||
async def get_dialogs(self, *args, **kwargs):
|
||||
"""
|
||||
Same as :meth:`iter_dialogs`, but returns a list instead
|
||||
with an additional ``.total`` attribute on the list.
|
||||
"""
|
||||
total = [0]
|
||||
kwargs['_total'] = total
|
||||
dialogs = UserList()
|
||||
async for x in self.iter_dialogs(*args, **kwargs):
|
||||
dialogs.append(x)
|
||||
dialogs.total = total[0]
|
||||
return dialogs
|
||||
|
||||
@async_generator
|
||||
async def iter_drafts(self): # TODO: Ability to provide a `filter`
|
||||
"""
|
||||
Iterator over all open draft messages.
|
||||
|
||||
Instances of `telethon.tl.custom.draft.Draft` are yielded.
|
||||
You can call `telethon.tl.custom.draft.Draft.set_message`
|
||||
to change the message or `telethon.tl.custom.draft.Draft.delete`
|
||||
among other things.
|
||||
"""
|
||||
r = await self(functions.messages.GetAllDraftsRequest())
|
||||
for update in r.updates:
|
||||
await yield_(custom.Draft._from_update(self, update))
|
||||
|
||||
async def get_drafts(self):
|
||||
"""
|
||||
Same as :meth:`iter_drafts`, but returns a list instead.
|
||||
"""
|
||||
result = []
|
||||
async for x in self.iter_drafts():
|
||||
result.append(x)
|
||||
return result
|
||||
|
||||
# endregion
|
414
telethon/client/downloads.py
Normal file
414
telethon/client/downloads.py
Normal file
|
@ -0,0 +1,414 @@
|
|||
import datetime
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .users import UserMethods
|
||||
from .. import utils, helpers, errors
|
||||
from ..crypto import CdnDecrypter
|
||||
from ..tl import TLObject, types, functions
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadMethods(UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
async def download_profile_photo(
|
||||
self, entity, file=None, download_big=True):
|
||||
"""
|
||||
Downloads the profile photo of the given entity (user/chat/channel).
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
From who the photo will be downloaded.
|
||||
|
||||
file (`str` | `file`, optional):
|
||||
The output file path, directory, or stream-like object.
|
||||
If the path exists and is a file, it will be overwritten.
|
||||
|
||||
download_big (`bool`, optional):
|
||||
Whether to use the big version of the available photos.
|
||||
|
||||
Returns:
|
||||
``None`` if no photo was provided, or if it was Empty. On success
|
||||
the file path is returned since it may differ from the one given.
|
||||
"""
|
||||
# hex(crc32(x.encode('ascii'))) for x in
|
||||
# ('User', 'Chat', 'UserFull', 'ChatFull')
|
||||
ENTITIES = (0x2da17977, 0xc5af5d94, 0x1f4661b9, 0xd49a2697)
|
||||
# ('InputPeer', 'InputUser', 'InputChannel')
|
||||
INPUTS = (0xc91c90b6, 0xe669bf46, 0x40f202fd)
|
||||
if not isinstance(entity, TLObject) or entity.SUBCLASS_OF_ID in INPUTS:
|
||||
entity = await self.get_entity(entity)
|
||||
|
||||
possible_names = []
|
||||
if entity.SUBCLASS_OF_ID not in ENTITIES:
|
||||
photo = entity
|
||||
else:
|
||||
if not hasattr(entity, 'photo'):
|
||||
# Special case: may be a ChatFull with photo:Photo
|
||||
# This is different from a normal UserProfilePhoto and Chat
|
||||
if not hasattr(entity, 'chat_photo'):
|
||||
return None
|
||||
|
||||
return await self._download_photo(
|
||||
entity.chat_photo, file, date=None, progress_callback=None)
|
||||
|
||||
for attr in ('username', 'first_name', 'title'):
|
||||
possible_names.append(getattr(entity, attr, None))
|
||||
|
||||
photo = entity.photo
|
||||
|
||||
if isinstance(photo, (types.UserProfilePhoto, types.ChatPhoto)):
|
||||
loc = photo.photo_big if download_big else photo.photo_small
|
||||
else:
|
||||
try:
|
||||
loc = utils.get_input_location(photo)
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
file = self._get_proper_filename(
|
||||
file, 'profile_photo', '.jpg',
|
||||
possible_names=possible_names
|
||||
)
|
||||
|
||||
try:
|
||||
await self.download_file(loc, file)
|
||||
return file
|
||||
except errors.LocationInvalidError:
|
||||
# See issue #500, Android app fails as of v4.6.0 (1155).
|
||||
# The fix seems to be using the full channel chat photo.
|
||||
ie = await self.get_input_entity(entity)
|
||||
if isinstance(ie, types.InputPeerChannel):
|
||||
full = await self(functions.channels.GetFullChannelRequest(ie))
|
||||
return await self._download_photo(
|
||||
full.full_chat.chat_photo, file,
|
||||
date=None, progress_callback=None
|
||||
)
|
||||
else:
|
||||
# Until there's a report for chats, no need to.
|
||||
return None
|
||||
|
||||
async def download_media(self, message, file=None, progress_callback=None):
|
||||
"""
|
||||
Downloads the given media, or the media from a specified Message.
|
||||
|
||||
Note that if the download is too slow, you should consider installing
|
||||
``cryptg`` (through ``pip install cryptg``) so that decrypting the
|
||||
received data is done in C instead of Python (much faster).
|
||||
|
||||
message (:tl:`Message` | :tl:`Media`):
|
||||
The media or message containing the media that will be downloaded.
|
||||
|
||||
file (`str` | `file`, optional):
|
||||
The output file path, directory, or stream-like object.
|
||||
If the path exists and is a file, it will be overwritten.
|
||||
|
||||
progress_callback (`callable`, optional):
|
||||
A callback function accepting two parameters:
|
||||
``(received bytes, total)``.
|
||||
|
||||
Returns:
|
||||
``None`` if no media was provided, or if it was Empty. On success
|
||||
the file path is returned since it may differ from the one given.
|
||||
"""
|
||||
# TODO This won't work for messageService
|
||||
if isinstance(message, types.Message):
|
||||
date = message.date
|
||||
media = message.media
|
||||
else:
|
||||
date = datetime.datetime.now()
|
||||
media = message
|
||||
|
||||
if isinstance(media, types.MessageMediaWebPage):
|
||||
if isinstance(media.webpage, types.WebPage):
|
||||
media = media.webpage.document or media.webpage.photo
|
||||
|
||||
if isinstance(media, (types.MessageMediaPhoto, types.Photo,
|
||||
types.PhotoSize, types.PhotoCachedSize)):
|
||||
return await self._download_photo(
|
||||
media, file, date, progress_callback
|
||||
)
|
||||
elif isinstance(media, (types.MessageMediaDocument, types.Document)):
|
||||
return await self._download_document(
|
||||
media, file, date, progress_callback
|
||||
)
|
||||
elif isinstance(media, types.MessageMediaContact):
|
||||
return self._download_contact(
|
||||
media, file
|
||||
)
|
||||
|
||||
async def download_file(
|
||||
self, input_location, file=None, part_size_kb=None,
|
||||
file_size=None, progress_callback=None):
|
||||
"""
|
||||
Downloads the given input location to a file.
|
||||
|
||||
Args:
|
||||
input_location (:tl:`FileLocation` | :tl:`InputFileLocation`):
|
||||
The file location from which the file will be downloaded.
|
||||
See `telethon.utils.get_input_location` source for a complete
|
||||
list of supported types.
|
||||
|
||||
file (`str` | `file`, optional):
|
||||
The output file path, directory, or stream-like object.
|
||||
If the path exists and is a file, it will be overwritten.
|
||||
|
||||
If the file path is ``None``, then the result will be
|
||||
saved in memory and returned as `bytes`.
|
||||
|
||||
part_size_kb (`int`, optional):
|
||||
Chunk size when downloading files. The larger, the less
|
||||
requests will be made (up to 512KB maximum).
|
||||
|
||||
file_size (`int`, optional):
|
||||
The file size that is about to be downloaded, if known.
|
||||
Only used if ``progress_callback`` is specified.
|
||||
|
||||
progress_callback (`callable`, optional):
|
||||
A callback function accepting two parameters:
|
||||
``(downloaded bytes, total)``. Note that the
|
||||
``total`` is the provided ``file_size``.
|
||||
"""
|
||||
if not part_size_kb:
|
||||
if not file_size:
|
||||
part_size_kb = 64 # Reasonable default
|
||||
else:
|
||||
part_size_kb = utils.get_appropriated_part_size(file_size)
|
||||
|
||||
part_size = int(part_size_kb * 1024)
|
||||
# https://core.telegram.org/api/files says:
|
||||
# > part_size % 1024 = 0 (divisible by 1KB)
|
||||
#
|
||||
# But https://core.telegram.org/cdn (more recent) says:
|
||||
# > limit must be divisible by 4096 bytes
|
||||
# So we just stick to the 4096 limit.
|
||||
if part_size % 4096 != 0:
|
||||
raise ValueError(
|
||||
'The part size must be evenly divisible by 4096.')
|
||||
|
||||
in_memory = file is None
|
||||
if in_memory:
|
||||
f = io.BytesIO()
|
||||
elif isinstance(file, str):
|
||||
# Ensure that we'll be able to download the media
|
||||
helpers.ensure_parent_dir_exists(file)
|
||||
f = open(file, 'wb')
|
||||
else:
|
||||
f = file
|
||||
|
||||
# The used sender will change if ``FileMigrateError`` occurs
|
||||
sender = self._sender
|
||||
input_location = utils.get_input_location(input_location)
|
||||
|
||||
__log__.info('Downloading file in chunks of %d bytes', part_size)
|
||||
try:
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
result = await sender.send(functions.upload.GetFileRequest(
|
||||
input_location, offset, part_size
|
||||
))
|
||||
if isinstance(result, types.upload.FileCdnRedirect):
|
||||
# TODO Implement
|
||||
raise NotImplementedError
|
||||
except errors.FileMigrateError as e:
|
||||
__log__.info('File lives in another DC')
|
||||
sender = await self._get_exported_sender(e.new_dc)
|
||||
continue
|
||||
|
||||
offset += part_size
|
||||
if not result.bytes:
|
||||
if in_memory:
|
||||
f.flush()
|
||||
return f.getvalue()
|
||||
else:
|
||||
return getattr(result, 'type', '')
|
||||
|
||||
__log__.debug('Saving %d more bytes', len(result.bytes))
|
||||
f.write(result.bytes)
|
||||
if progress_callback:
|
||||
progress_callback(f.tell(), file_size)
|
||||
finally:
|
||||
if sender != self._sender:
|
||||
await sender.disconnect()
|
||||
if isinstance(file, str) or in_memory:
|
||||
f.close()
|
||||
|
||||
# endregion
|
||||
|
||||
# region Private methods
|
||||
|
||||
async def _download_photo(self, photo, file, date, progress_callback):
|
||||
"""Specialized version of .download_media() for photos"""
|
||||
# Determine the photo and its largest size
|
||||
if isinstance(photo, types.MessageMediaPhoto):
|
||||
photo = photo.photo
|
||||
if isinstance(photo, types.Photo):
|
||||
for size in reversed(photo.sizes):
|
||||
if not isinstance(size, types.PhotoSizeEmpty):
|
||||
photo = size
|
||||
break
|
||||
else:
|
||||
return
|
||||
if not isinstance(photo, (types.PhotoSize, types.PhotoCachedSize)):
|
||||
return
|
||||
|
||||
file = self._get_proper_filename(file, 'photo', '.jpg', date=date)
|
||||
if isinstance(photo, types.PhotoCachedSize):
|
||||
# No need to download anything, simply write the bytes
|
||||
if isinstance(file, str):
|
||||
helpers.ensure_parent_dir_exists(file)
|
||||
f = open(file, 'wb')
|
||||
else:
|
||||
f = file
|
||||
try:
|
||||
f.write(photo.bytes)
|
||||
finally:
|
||||
if isinstance(file, str):
|
||||
f.close()
|
||||
return file
|
||||
|
||||
await self.download_file(
|
||||
photo.location, file, file_size=photo.size,
|
||||
progress_callback=progress_callback)
|
||||
return file
|
||||
|
||||
async def _download_document(
|
||||
self, document, file, date, progress_callback):
|
||||
"""Specialized version of .download_media() for documents."""
|
||||
if isinstance(document, types.MessageMediaDocument):
|
||||
document = document.document
|
||||
if not isinstance(document, types.Document):
|
||||
return
|
||||
|
||||
file_size = document.size
|
||||
|
||||
kind = 'document'
|
||||
possible_names = []
|
||||
for attr in document.attributes:
|
||||
if isinstance(attr, types.DocumentAttributeFilename):
|
||||
possible_names.insert(0, attr.file_name)
|
||||
|
||||
elif isinstance(attr, types.DocumentAttributeAudio):
|
||||
kind = 'audio'
|
||||
if attr.performer and attr.title:
|
||||
possible_names.append('{} - {}'.format(
|
||||
attr.performer, attr.title
|
||||
))
|
||||
elif attr.performer:
|
||||
possible_names.append(attr.performer)
|
||||
elif attr.title:
|
||||
possible_names.append(attr.title)
|
||||
elif attr.voice:
|
||||
kind = 'voice'
|
||||
|
||||
file = self._get_proper_filename(
|
||||
file, kind, utils.get_extension(document),
|
||||
date=date, possible_names=possible_names
|
||||
)
|
||||
|
||||
await self.download_file(
|
||||
document, file, file_size=file_size,
|
||||
progress_callback=progress_callback)
|
||||
return file
|
||||
|
||||
@classmethod
|
||||
def _download_contact(cls, mm_contact, file):
|
||||
"""
|
||||
Specialized version of .download_media() for contacts.
|
||||
Will make use of the vCard 4.0 format.
|
||||
"""
|
||||
first_name = mm_contact.first_name
|
||||
last_name = mm_contact.last_name
|
||||
phone_number = mm_contact.phone_number
|
||||
|
||||
if isinstance(file, str):
|
||||
file = cls._get_proper_filename(
|
||||
file, 'contact', '.vcard',
|
||||
possible_names=[first_name, phone_number, last_name]
|
||||
)
|
||||
f = open(file, 'w', encoding='utf-8')
|
||||
else:
|
||||
f = file
|
||||
|
||||
try:
|
||||
# Remove these pesky characters
|
||||
first_name = first_name.replace(';', '')
|
||||
last_name = (last_name or '').replace(';', '')
|
||||
f.write('BEGIN:VCARD\n')
|
||||
f.write('VERSION:4.0\n')
|
||||
f.write('N:{};{};;;\n'.format(first_name, last_name))
|
||||
f.write('FN:{} {}\n'.format(first_name, last_name))
|
||||
f.write('TEL;TYPE=cell;VALUE=uri:tel:+{}\n'.format(phone_number))
|
||||
f.write('END:VCARD\n')
|
||||
finally:
|
||||
# Only close the stream if we opened it
|
||||
if isinstance(file, str):
|
||||
f.close()
|
||||
|
||||
return file
|
||||
|
||||
@staticmethod
|
||||
def _get_proper_filename(file, kind, extension,
|
||||
date=None, possible_names=None):
|
||||
"""Gets a proper filename for 'file', if this is a path.
|
||||
|
||||
'kind' should be the kind of the output file (photo, document...)
|
||||
'extension' should be the extension to be added to the file if
|
||||
the filename doesn't have any yet
|
||||
'date' should be when this file was originally sent, if known
|
||||
'possible_names' should be an ordered list of possible names
|
||||
|
||||
If no modification is made to the path, any existing file
|
||||
will be overwritten.
|
||||
If any modification is made to the path, this method will
|
||||
ensure that no existing file will be overwritten.
|
||||
"""
|
||||
if file is not None and not isinstance(file, str):
|
||||
# Probably a stream-like object, we cannot set a filename here
|
||||
return file
|
||||
|
||||
if file is None:
|
||||
file = ''
|
||||
elif os.path.isfile(file):
|
||||
# Make no modifications to valid existing paths
|
||||
return file
|
||||
|
||||
if os.path.isdir(file) or not file:
|
||||
try:
|
||||
name = None if possible_names is None else next(
|
||||
x for x in possible_names if x
|
||||
)
|
||||
except StopIteration:
|
||||
name = None
|
||||
|
||||
if not name:
|
||||
if not date:
|
||||
date = datetime.datetime.now()
|
||||
name = '{}_{}-{:02}-{:02}_{:02}-{:02}-{:02}'.format(
|
||||
kind,
|
||||
date.year, date.month, date.day,
|
||||
date.hour, date.minute, date.second,
|
||||
)
|
||||
file = os.path.join(file, name)
|
||||
|
||||
directory, name = os.path.split(file)
|
||||
name, ext = os.path.splitext(name)
|
||||
if not ext:
|
||||
ext = extension
|
||||
|
||||
result = os.path.join(directory, name + ext)
|
||||
if not os.path.isfile(result):
|
||||
return result
|
||||
|
||||
i = 1
|
||||
while True:
|
||||
result = os.path.join(directory, '{} ({}){}'.format(name, i, ext))
|
||||
if not os.path.isfile(result):
|
||||
return result
|
||||
i += 1
|
||||
|
||||
# endregion
|
129
telethon/client/messageparse.py
Normal file
129
telethon/client/messageparse.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
import itertools
|
||||
import re
|
||||
|
||||
from .users import UserMethods
|
||||
from .. import utils
|
||||
from ..tl import types, custom
|
||||
|
||||
|
||||
class MessageParseMethods(UserMethods):
|
||||
|
||||
# region Public properties
|
||||
|
||||
@property
|
||||
def parse_mode(self):
|
||||
"""
|
||||
This property is the default parse mode used when sending messages.
|
||||
Defaults to `telethon.extensions.markdown`. It will always
|
||||
be either ``None`` or an object with ``parse`` and ``unparse``
|
||||
methods.
|
||||
|
||||
When setting a different value it should be one of:
|
||||
|
||||
* Object with ``parse`` and ``unparse`` methods.
|
||||
* A ``callable`` to act as the parse method.
|
||||
* A ``str`` indicating the ``parse_mode``. For Markdown ``'md'``
|
||||
or ``'markdown'`` may be used. For HTML, ``'htm'`` or ``'html'``
|
||||
may be used.
|
||||
|
||||
The ``parse`` method should be a function accepting a single
|
||||
parameter, the text to parse, and returning a tuple consisting
|
||||
of ``(parsed message str, [MessageEntity instances])``.
|
||||
|
||||
The ``unparse`` method should be the inverse of ``parse`` such
|
||||
that ``assert text == unparse(*parse(text))``.
|
||||
|
||||
See :tl:`MessageEntity` for allowed message entities.
|
||||
"""
|
||||
return self._parse_mode
|
||||
|
||||
@parse_mode.setter
|
||||
def parse_mode(self, mode):
|
||||
self._parse_mode = utils.sanitize_parse_mode(mode)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Private methods
|
||||
|
||||
async def _parse_message_text(self, message, parse_mode):
|
||||
"""
|
||||
Returns a (parsed message, entities) tuple depending on ``parse_mode``.
|
||||
"""
|
||||
if parse_mode == utils.Default:
|
||||
parse_mode = self._parse_mode
|
||||
else:
|
||||
parse_mode = utils.sanitize_parse_mode(parse_mode)
|
||||
|
||||
if not parse_mode:
|
||||
return message, []
|
||||
|
||||
message, msg_entities = parse_mode.parse(message)
|
||||
for i, e in enumerate(msg_entities):
|
||||
if isinstance(e, types.MessageEntityTextUrl):
|
||||
m = re.match(r'^@|\+|tg://user\?id=(\d+)', e.url)
|
||||
if m:
|
||||
try:
|
||||
msg_entities[i] = types.InputMessageEntityMentionName(
|
||||
e.offset, e.length, await self.get_input_entity(
|
||||
int(m.group(1)) if m.group(1) else e.url
|
||||
)
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
# Make no replacement
|
||||
pass
|
||||
|
||||
return message, msg_entities
|
||||
|
||||
def _get_response_message(self, request, result, input_chat):
|
||||
"""
|
||||
Extracts the response message known a request and Update result.
|
||||
The request may also be the ID of the message to match.
|
||||
"""
|
||||
# Telegram seems to send updateMessageID first, then updateNewMessage,
|
||||
# however let's not rely on that just in case.
|
||||
if isinstance(request, int):
|
||||
msg_id = request
|
||||
else:
|
||||
msg_id = None
|
||||
for update in result.updates:
|
||||
if isinstance(update, types.UpdateMessageID):
|
||||
if update.random_id == request.random_id:
|
||||
msg_id = update.id
|
||||
break
|
||||
|
||||
if isinstance(result, types.UpdateShort):
|
||||
updates = [result.update]
|
||||
entities = {}
|
||||
elif isinstance(result, (types.Updates, types.UpdatesCombined)):
|
||||
updates = result.updates
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in
|
||||
itertools.chain(result.users, result.chats)}
|
||||
else:
|
||||
return
|
||||
|
||||
found = None
|
||||
for update in updates:
|
||||
if isinstance(update, (
|
||||
types.UpdateNewChannelMessage, types.UpdateNewMessage)):
|
||||
if update.message.id == msg_id:
|
||||
found = update.message
|
||||
break
|
||||
|
||||
elif (isinstance(update, types.UpdateEditMessage)
|
||||
and not isinstance(request.peer, types.InputPeerChannel)):
|
||||
if request.id == update.message.id:
|
||||
found = update.message
|
||||
break
|
||||
|
||||
elif (isinstance(update, types.UpdateEditChannelMessage)
|
||||
and utils.get_peer_id(request.peer) ==
|
||||
utils.get_peer_id(update.message.to_id)):
|
||||
if request.id == update.message.id:
|
||||
found = update.message
|
||||
break
|
||||
|
||||
if found:
|
||||
return custom.Message(self, found, entities, input_chat)
|
||||
|
||||
# endregion
|
657
telethon/client/messages.py
Normal file
657
telethon/client/messages.py
Normal file
|
@ -0,0 +1,657 @@
|
|||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from collections import UserList
|
||||
|
||||
from async_generator import async_generator, yield_
|
||||
|
||||
from .messageparse import MessageParseMethods
|
||||
from .uploads import UploadMethods
|
||||
from .. import utils
|
||||
from ..tl import types, functions, custom
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageMethods(UploadMethods, MessageParseMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
# region Message retrieval
|
||||
|
||||
@async_generator
|
||||
async def iter_messages(
|
||||
self, entity, limit=None, offset_date=None, offset_id=0,
|
||||
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
|
||||
from_user=None, batch_size=100, wait_time=None, ids=None,
|
||||
_total=None):
|
||||
"""
|
||||
Iterator over the message history for the specified entity.
|
||||
|
||||
If either `search`, `filter` or `from_user` are provided,
|
||||
:tl:`messages.Search` will be used instead of :tl:`messages.getHistory`.
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
The entity from whom to retrieve the message history.
|
||||
|
||||
limit (`int` | `None`, optional):
|
||||
Number of messages to be retrieved. Due to limitations with
|
||||
the API retrieving more than 3000 messages will take longer
|
||||
than half a minute (or even more based on previous calls).
|
||||
The limit may also be ``None``, which would eventually return
|
||||
the whole history.
|
||||
|
||||
offset_date (`datetime`):
|
||||
Offset date (messages *previous* to this date will be
|
||||
retrieved). Exclusive.
|
||||
|
||||
offset_id (`int`):
|
||||
Offset message ID (only messages *previous* to the given
|
||||
ID will be retrieved). Exclusive.
|
||||
|
||||
max_id (`int`):
|
||||
All the messages with a higher (newer) ID or equal to this will
|
||||
be excluded.
|
||||
|
||||
min_id (`int`):
|
||||
All the messages with a lower (older) ID or equal to this will
|
||||
be excluded.
|
||||
|
||||
add_offset (`int`):
|
||||
Additional message offset (all of the specified offsets +
|
||||
this offset = older messages).
|
||||
|
||||
search (`str`):
|
||||
The string to be used as a search query.
|
||||
|
||||
filter (:tl:`MessagesFilter` | `type`):
|
||||
The filter to use when returning messages. For instance,
|
||||
:tl:`InputMessagesFilterPhotos` would yield only messages
|
||||
containing photos.
|
||||
|
||||
from_user (`entity`):
|
||||
Only messages from this user will be returned.
|
||||
|
||||
batch_size (`int`):
|
||||
Messages will be returned in chunks of this size (100 is
|
||||
the maximum). While it makes no sense to modify this value,
|
||||
you are still free to do so.
|
||||
|
||||
wait_time (`int`):
|
||||
Wait time between different :tl:`GetHistoryRequest`. Use this
|
||||
parameter to avoid hitting the ``FloodWaitError`` as needed.
|
||||
If left to ``None``, it will default to 1 second only if
|
||||
the limit is higher than 3000.
|
||||
|
||||
ids (`int`, `list`):
|
||||
A single integer ID (or several IDs) for the message that
|
||||
should be returned. This parameter takes precedence over
|
||||
the rest (which will be ignored if this is set). This can
|
||||
for instance be used to get the message with ID 123 from
|
||||
a channel. Note that if the message doesn't exist, ``None``
|
||||
will appear in its place, so that zipping the list of IDs
|
||||
with the messages can match one-to-one.
|
||||
|
||||
_total (`list`, optional):
|
||||
A single-item list to pass the total parameter by reference.
|
||||
|
||||
Yields:
|
||||
Instances of `telethon.tl.custom.message.Message`.
|
||||
|
||||
Notes:
|
||||
Telegram's flood wait limit for :tl:`GetHistoryRequest` seems to
|
||||
be around 30 seconds per 3000 messages, therefore a sleep of 1
|
||||
second is the default for this limit (or above). You may need
|
||||
an higher limit, so you're free to set the ``batch_size`` that
|
||||
you think may be good.
|
||||
"""
|
||||
# It's possible to get messages by ID without their entity, so only
|
||||
# fetch the input version if we're not using IDs or if it was given.
|
||||
if not ids or entity:
|
||||
entity = await self.get_input_entity(entity)
|
||||
|
||||
if ids:
|
||||
if not utils.is_list_like(ids):
|
||||
ids = (ids,)
|
||||
async for x in self._iter_ids(entity, ids, total=_total):
|
||||
await yield_(x)
|
||||
return
|
||||
|
||||
# Telegram doesn't like min_id/max_id. If these IDs are low enough
|
||||
# (starting from last_id - 100), the request will return nothing.
|
||||
#
|
||||
# We can emulate their behaviour locally by setting offset = max_id
|
||||
# and simply stopping once we hit a message with ID <= min_id.
|
||||
offset_id = max(offset_id, max_id)
|
||||
if offset_id and min_id:
|
||||
if offset_id - min_id <= 1:
|
||||
return
|
||||
|
||||
limit = float('inf') if limit is None else int(limit)
|
||||
if search is not None or filter or from_user:
|
||||
if filter is None:
|
||||
filter = types.InputMessagesFilterEmpty()
|
||||
request = functions.messages.SearchRequest(
|
||||
peer=entity,
|
||||
q=search or '',
|
||||
filter=filter() if isinstance(filter, type) else filter,
|
||||
min_date=None,
|
||||
max_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
add_offset=add_offset,
|
||||
limit=1,
|
||||
max_id=0,
|
||||
min_id=0,
|
||||
hash=0,
|
||||
from_id=(
|
||||
await self.get_input_entity(from_user)
|
||||
if from_user else None
|
||||
)
|
||||
)
|
||||
else:
|
||||
request = functions.messages.GetHistoryRequest(
|
||||
peer=entity,
|
||||
limit=1,
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
min_id=0,
|
||||
max_id=0,
|
||||
add_offset=add_offset,
|
||||
hash=0
|
||||
)
|
||||
|
||||
if limit == 0:
|
||||
if not _total:
|
||||
return
|
||||
# No messages, but we still need to know the total message count
|
||||
result = await self(request)
|
||||
if isinstance(result, types.messages.MessagesNotModified):
|
||||
_total[0] = result.count
|
||||
else:
|
||||
_total[0] = getattr(result, 'count', len(result.messages))
|
||||
return
|
||||
|
||||
if wait_time is None:
|
||||
wait_time = 1 if limit > 3000 else 0
|
||||
|
||||
have = 0
|
||||
last_id = float('inf')
|
||||
batch_size = min(max(batch_size, 1), 100)
|
||||
while have < limit:
|
||||
start = asyncio.get_event_loop().time()
|
||||
# Telegram has a hard limit of 100
|
||||
request.limit = min(limit - have, batch_size)
|
||||
r = await self(request)
|
||||
if _total:
|
||||
_total[0] = getattr(r, 'count', len(r.messages))
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
for message in r.messages:
|
||||
if message.id <= min_id:
|
||||
return
|
||||
|
||||
if isinstance(message, types.MessageEmpty)\
|
||||
or message.id >= last_id:
|
||||
continue
|
||||
|
||||
# There has been reports that on bad connections this method
|
||||
# was returning duplicated IDs sometimes. Using ``last_id``
|
||||
# is an attempt to avoid these duplicates, since the message
|
||||
# IDs are returned in descending order.
|
||||
last_id = message.id
|
||||
|
||||
await yield_(custom.Message(self, message, entities, entity))
|
||||
have += 1
|
||||
|
||||
if len(r.messages) < request.limit:
|
||||
break
|
||||
|
||||
request.offset_id = r.messages[-1].id
|
||||
if isinstance(request, functions.messages.GetHistoryRequest):
|
||||
request.offset_date = r.messages[-1].date
|
||||
else:
|
||||
request.max_date = r.messages[-1].date
|
||||
|
||||
await asyncio.sleep(max(wait_time - (time.time() - start), 0))
|
||||
|
||||
async def get_messages(self, *args, **kwargs):
|
||||
"""
|
||||
Same as :meth:`iter_messages`, but returns a list instead
|
||||
with an additional ``.total`` attribute on the list.
|
||||
|
||||
If the `limit` is not set, it will be 1 by default unless both
|
||||
`min_id` **and** `max_id` are set (as *named* arguments), in
|
||||
which case the entire range will be returned.
|
||||
|
||||
This is so because any integer limit would be rather arbitrary and
|
||||
it's common to only want to fetch one message, but if a range is
|
||||
specified it makes sense that it should return the entirety of it.
|
||||
|
||||
If `ids` is present in the *named* arguments and is not a list,
|
||||
a single :tl:`Message` will be returned for convenience instead
|
||||
of a list.
|
||||
"""
|
||||
total = [0]
|
||||
kwargs['_total'] = total
|
||||
if len(args) == 1 and 'limit' not in kwargs:
|
||||
if 'min_id' in kwargs and 'max_id' in kwargs:
|
||||
kwargs['limit'] = None
|
||||
else:
|
||||
kwargs['limit'] = 1
|
||||
|
||||
msgs = UserList()
|
||||
async for x in self.iter_messages(*args, **kwargs):
|
||||
msgs.append(x)
|
||||
msgs.total = total[0]
|
||||
if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']):
|
||||
return msgs[0]
|
||||
|
||||
return msgs
|
||||
|
||||
async def get_message_history(self, *args, **kwargs):
|
||||
"""Deprecated, see :meth:`get_messages`."""
|
||||
warnings.warn(
|
||||
'get_message_history is deprecated, use get_messages instead'
|
||||
)
|
||||
return await self.get_messages(*args, **kwargs)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Message sending/editing/deleting
|
||||
|
||||
async def send_message(
|
||||
self, entity, message='', reply_to=None,
|
||||
parse_mode=utils.Default, link_preview=True, file=None,
|
||||
force_document=False, clear_draft=False):
|
||||
"""
|
||||
Sends the given message to the specified entity (user/chat/channel).
|
||||
|
||||
The default parse mode is the same as the official applications
|
||||
(a custom flavour of markdown). ``**bold**, `code` or __italic__``
|
||||
are available. In addition you can send ``[links](https://example.com)``
|
||||
and ``[mentions](@username)`` (or using IDs like in the Bot API:
|
||||
``[mention](tg://user?id=123456789)``) and ``pre`` blocks with three
|
||||
backticks.
|
||||
|
||||
Sending a ``/start`` command with a parameter (like ``?start=data``)
|
||||
is also done through this method. Simply send ``'/start data'`` to
|
||||
the bot.
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
To who will it be sent.
|
||||
|
||||
message (`str` | :tl:`Message`):
|
||||
The message to be sent, or another message object to resend.
|
||||
|
||||
The maximum length for a message is 35,000 bytes or 4,096
|
||||
characters. Longer messages will not be sliced automatically,
|
||||
and you should slice them manually if the text to send is
|
||||
longer than said length.
|
||||
|
||||
reply_to (`int` | :tl:`Message`, optional):
|
||||
Whether to reply to a message or not. If an integer is provided,
|
||||
it should be the ID of the message that it should reply to.
|
||||
|
||||
parse_mode (`object`, optional):
|
||||
See the `TelegramClient.parse_mode` property for allowed
|
||||
values. Markdown parsing will be used by default.
|
||||
|
||||
link_preview (`bool`, optional):
|
||||
Should the link preview be shown?
|
||||
|
||||
file (`file`, optional):
|
||||
Sends a message with a file attached (e.g. a photo,
|
||||
video, audio or document). The ``message`` may be empty.
|
||||
|
||||
force_document (`bool`, optional):
|
||||
Whether to send the given file as a document or not.
|
||||
|
||||
clear_draft (`bool`, optional):
|
||||
Whether the existing draft should be cleared or not.
|
||||
Has no effect when sending a file.
|
||||
|
||||
Returns:
|
||||
The sent `telethon.tl.custom.message.Message`.
|
||||
"""
|
||||
if file is not None:
|
||||
return await self.send_file(
|
||||
entity, file, caption=message, reply_to=reply_to,
|
||||
parse_mode=parse_mode, force_document=force_document
|
||||
)
|
||||
elif not message:
|
||||
raise ValueError(
|
||||
'The message cannot be empty unless a file is provided'
|
||||
)
|
||||
|
||||
entity = await self.get_input_entity(entity)
|
||||
if isinstance(message, types.Message):
|
||||
if (message.media and not isinstance(
|
||||
message.media, types.MessageMediaWebPage)):
|
||||
return await self.send_file(
|
||||
entity, message.media, caption=message.message,
|
||||
entities=message.entities
|
||||
)
|
||||
|
||||
if reply_to is not None:
|
||||
reply_id = utils.get_message_id(reply_to)
|
||||
elif utils.get_peer_id(entity) == utils.get_peer_id(message.to_id):
|
||||
reply_id = message.reply_to_msg_id
|
||||
else:
|
||||
reply_id = None
|
||||
request = functions.messages.SendMessageRequest(
|
||||
peer=entity,
|
||||
message=message.message or '',
|
||||
silent=message.silent,
|
||||
reply_to_msg_id=reply_id,
|
||||
reply_markup=message.reply_markup,
|
||||
entities=message.entities,
|
||||
clear_draft=clear_draft,
|
||||
no_webpage=not isinstance(
|
||||
message.media, types.MessageMediaWebPage)
|
||||
)
|
||||
message = message.message
|
||||
else:
|
||||
message, msg_ent = await self._parse_message_text(message,
|
||||
parse_mode)
|
||||
request = functions.messages.SendMessageRequest(
|
||||
peer=entity,
|
||||
message=message,
|
||||
entities=msg_ent,
|
||||
no_webpage=not link_preview,
|
||||
reply_to_msg_id=utils.get_message_id(reply_to),
|
||||
clear_draft=clear_draft
|
||||
)
|
||||
|
||||
result = await self(request)
|
||||
if isinstance(result, types.UpdateShortSentMessage):
|
||||
to_id, cls = utils.resolve_id(utils.get_peer_id(entity))
|
||||
return custom.Message(self, types.Message(
|
||||
id=result.id,
|
||||
to_id=cls(to_id),
|
||||
message=message,
|
||||
date=result.date,
|
||||
out=result.out,
|
||||
media=result.media,
|
||||
entities=result.entities
|
||||
), {}, input_chat=entity)
|
||||
|
||||
return self._get_response_message(request, result, entity)
|
||||
|
||||
async def forward_messages(self, entity, messages, from_peer=None):
|
||||
"""
|
||||
Forwards the given message(s) to the specified entity.
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
To which entity the message(s) will be forwarded.
|
||||
|
||||
messages (`list` | `int` | :tl:`Message`):
|
||||
The message(s) to forward, or their integer IDs.
|
||||
|
||||
from_peer (`entity`):
|
||||
If the given messages are integer IDs and not instances
|
||||
of the ``Message`` class, this *must* be specified in
|
||||
order for the forward to work.
|
||||
|
||||
Returns:
|
||||
The list of forwarded `telethon.tl.custom.message.Message`,
|
||||
or a single one if a list wasn't provided as input.
|
||||
"""
|
||||
single = not utils.is_list_like(messages)
|
||||
if single:
|
||||
messages = (messages,)
|
||||
|
||||
if not from_peer:
|
||||
try:
|
||||
# On private chats (to_id = PeerUser), if the message is
|
||||
# not outgoing, we actually need to use "from_id" to get
|
||||
# the conversation on which the message was sent.
|
||||
from_peer = next(
|
||||
m.from_id
|
||||
if not m.out and isinstance(m.to_id, types.PeerUser)
|
||||
else m.to_id for m in messages
|
||||
if isinstance(m, types.Message)
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError(
|
||||
'from_chat must be given if integer IDs are used'
|
||||
)
|
||||
|
||||
req = functions.messages.ForwardMessagesRequest(
|
||||
from_peer=from_peer,
|
||||
id=[m if isinstance(m, int) else m.id for m in messages],
|
||||
to_peer=entity
|
||||
)
|
||||
result = await self(req)
|
||||
if isinstance(result, (types.Updates, types.UpdatesCombined)):
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(result.users, result.chats)}
|
||||
else:
|
||||
entities = {}
|
||||
|
||||
random_to_id = {}
|
||||
id_to_message = {}
|
||||
for update in result.updates:
|
||||
if isinstance(update, types.UpdateMessageID):
|
||||
random_to_id[update.random_id] = update.id
|
||||
elif isinstance(update, (
|
||||
types.UpdateNewMessage, types.UpdateNewChannelMessage)):
|
||||
id_to_message[update.message.id] = custom.Message(
|
||||
self, update.message, entities, input_chat=entity)
|
||||
|
||||
result = [id_to_message[random_to_id[rnd]] for rnd in req.random_id]
|
||||
return result[0] if single else result
|
||||
|
||||
async def edit_message(
|
||||
self, entity, message=None, text=None, parse_mode=utils.Default,
|
||||
link_preview=True, file=None):
|
||||
"""
|
||||
Edits the given message ID (to change its contents or disable preview).
|
||||
|
||||
Args:
|
||||
entity (`entity` | :tl:`Message`):
|
||||
From which chat to edit the message. This can also be
|
||||
the message to be edited, and the entity will be inferred
|
||||
from it, so the next parameter will be assumed to be the
|
||||
message text.
|
||||
|
||||
message (`int` | :tl:`Message` | `str`):
|
||||
The ID of the message (or :tl:`Message` itself) to be edited.
|
||||
If the `entity` was a :tl:`Message`, then this message will be
|
||||
treated as the new text.
|
||||
|
||||
text (`str`, optional):
|
||||
The new text of the message. Does nothing if the `entity`
|
||||
was a :tl:`Message`.
|
||||
|
||||
parse_mode (`object`, optional):
|
||||
See the `TelegramClient.parse_mode` property for allowed
|
||||
values. Markdown parsing will be used by default.
|
||||
|
||||
link_preview (`bool`, optional):
|
||||
Should the link preview be shown?
|
||||
|
||||
file (`str` | `bytes` | `file` | `media`, optional):
|
||||
The file object that should replace the existing media
|
||||
in the message.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> client = ...
|
||||
>>> message = client.send_message('username', 'hello')
|
||||
>>>
|
||||
>>> client.edit_message('username', message, 'hello!')
|
||||
>>> # or
|
||||
>>> client.edit_message('username', message.id, 'Hello')
|
||||
>>> # or
|
||||
>>> client.edit_message(message, 'Hello!')
|
||||
|
||||
Raises:
|
||||
``MessageAuthorRequiredError`` if you're not the author of the
|
||||
message but tried editing it anyway.
|
||||
|
||||
``MessageNotModifiedError`` if the contents of the message were
|
||||
not modified at all.
|
||||
|
||||
Returns:
|
||||
The edited `telethon.tl.custom.message.Message`.
|
||||
"""
|
||||
if isinstance(entity, types.Message):
|
||||
text = message # Shift the parameters to the right
|
||||
message = entity
|
||||
entity = entity.to_id
|
||||
|
||||
entity = await self.get_input_entity(entity)
|
||||
text, msg_entities = await self._parse_message_text(text, parse_mode)
|
||||
file_handle, media = await self._file_to_media(file)
|
||||
request = functions.messages.EditMessageRequest(
|
||||
peer=entity,
|
||||
id=utils.get_message_id(message),
|
||||
message=text,
|
||||
no_webpage=not link_preview,
|
||||
entities=msg_entities,
|
||||
media=media
|
||||
)
|
||||
msg = self._get_response_message(request, self(request), entity)
|
||||
self._cache_media(msg, file, file_handle)
|
||||
return msg
|
||||
|
||||
async def delete_messages(self, entity, message_ids, revoke=True):
|
||||
"""
|
||||
Deletes a message from a chat, optionally "for everyone".
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
From who the message will be deleted. This can actually
|
||||
be ``None`` for normal chats, but **must** be present
|
||||
for channels and megagroups.
|
||||
|
||||
message_ids (`list` | `int` | :tl:`Message`):
|
||||
The IDs (or ID) or messages to be deleted.
|
||||
|
||||
revoke (`bool`, optional):
|
||||
Whether the message should be deleted for everyone or not.
|
||||
By default it has the opposite behaviour of official clients,
|
||||
and it will delete the message for everyone.
|
||||
This has no effect on channels or megagroups.
|
||||
|
||||
Returns:
|
||||
A list of :tl:`AffectedMessages`, each item being the result
|
||||
for the delete calls of the messages in chunks of 100 each.
|
||||
"""
|
||||
if not utils.is_list_like(message_ids):
|
||||
message_ids = (message_ids,)
|
||||
|
||||
message_ids = (
|
||||
m.id if isinstance(m, (
|
||||
types.Message, types.MessageService, types.MessageEmpty))
|
||||
else int(m) for m in message_ids
|
||||
)
|
||||
|
||||
entity = await self.get_input_entity(entity) if entity else None
|
||||
if isinstance(entity, types.InputPeerChannel):
|
||||
return await self([functions.channels.DeleteMessagesRequest(
|
||||
entity, list(c)) for c in utils.chunks(message_ids)])
|
||||
else:
|
||||
return await self([functions.messages.DeleteMessagesRequest(
|
||||
list(c), revoke) for c in utils.chunks(message_ids)])
|
||||
|
||||
# endregion
|
||||
|
||||
# region Miscellaneous
|
||||
|
||||
async def send_read_acknowledge(self, entity, message=None, max_id=None,
|
||||
clear_mentions=False):
|
||||
"""
|
||||
Sends a "read acknowledge" (i.e., notifying the given peer that we've
|
||||
read their messages, also known as the "double check").
|
||||
|
||||
This effectively marks a message as read (or more than one) in the
|
||||
given conversation.
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
The chat where these messages are located.
|
||||
|
||||
message (`list` | :tl:`Message`):
|
||||
Either a list of messages or a single message.
|
||||
|
||||
max_id (`int`):
|
||||
Overrides messages, until which message should the
|
||||
acknowledge should be sent.
|
||||
|
||||
clear_mentions (`bool`):
|
||||
Whether the mention badge should be cleared (so that
|
||||
there are no more mentions) or not for the given entity.
|
||||
|
||||
If no message is provided, this will be the only action
|
||||
taken.
|
||||
"""
|
||||
if max_id is None:
|
||||
if message:
|
||||
if utils.is_list_like(message):
|
||||
max_id = max(msg.id for msg in message)
|
||||
else:
|
||||
max_id = message.id
|
||||
elif not clear_mentions:
|
||||
raise ValueError(
|
||||
'Either a message list or a max_id must be provided.')
|
||||
|
||||
entity = await self.get_input_entity(entity)
|
||||
if clear_mentions:
|
||||
await self(functions.messages.ReadMentionsRequest(entity))
|
||||
if max_id is None:
|
||||
return True
|
||||
|
||||
if max_id is not None:
|
||||
if isinstance(entity, types.InputPeerChannel):
|
||||
return await self(functions.channels.ReadHistoryRequest(
|
||||
entity, max_id=max_id))
|
||||
else:
|
||||
return await self(functions.messages.ReadHistoryRequest(
|
||||
entity, max_id=max_id))
|
||||
|
||||
return False
|
||||
|
||||
# endregion
|
||||
|
||||
# endregion
|
||||
|
||||
# region Private methods
|
||||
|
||||
@async_generator
|
||||
async def _iter_ids(self, entity, ids, total):
|
||||
"""
|
||||
Special case for `iter_messages` when it should only fetch some IDs.
|
||||
"""
|
||||
if total:
|
||||
total[0] = len(ids)
|
||||
|
||||
if isinstance(entity, types.InputPeerChannel):
|
||||
r = await self(functions.channels.GetMessagesRequest(entity, ids))
|
||||
else:
|
||||
r = await self(functions.messages.GetMessagesRequest(ids))
|
||||
|
||||
if isinstance(r, types.messages.MessagesNotModified):
|
||||
for _ in ids:
|
||||
await yield_(None)
|
||||
return
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
# Telegram seems to return the messages in the order in which
|
||||
# we asked them for, so we don't need to check it ourselves.
|
||||
for message in r.messages:
|
||||
if isinstance(message, types.MessageEmpty):
|
||||
await yield_(None)
|
||||
else:
|
||||
await yield_(custom.Message(self, message, entities, entity))
|
||||
|
||||
# endregion
|
370
telethon/client/telegrambaseclient.py
Normal file
370
telethon/client/telegrambaseclient.py
Normal file
|
@ -0,0 +1,370 @@
|
|||
import abc
|
||||
import logging
|
||||
import platform
|
||||
import warnings
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
from .. import version
|
||||
from ..crypto import rsa
|
||||
from ..extensions import markdown
|
||||
from ..network import MTProtoSender, ConnectionTcpFull
|
||||
from ..network.mtprotostate import MTProtoState
|
||||
from ..sessions import Session, SQLiteSession
|
||||
from ..tl import TLObject, functions
|
||||
from ..tl.all_tlobjects import LAYER
|
||||
|
||||
DEFAULT_DC_ID = 4
|
||||
DEFAULT_IPV4_IP = '149.154.167.51'
|
||||
DEFAULT_IPV6_IP = '[2001:67c:4e8:f002::a]'
|
||||
DEFAULT_PORT = 443
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramBaseClient(abc.ABC):
|
||||
"""
|
||||
This is the abstract base class for the client. It defines some
|
||||
basic stuff like connecting, switching data center, etc, and
|
||||
leaves the `__call__` unimplemented.
|
||||
|
||||
Args:
|
||||
session (`str` | `telethon.sessions.abstract.Session`, `None`):
|
||||
The file name of the session file to be used if a string is
|
||||
given (it may be a full path), or the Session instance to be
|
||||
used otherwise. If it's ``None``, the session will not be saved,
|
||||
and you should call :meth:`.log_out()` when you're done.
|
||||
|
||||
Note that if you pass a string it will be a file in the current
|
||||
working directory, although you can also pass absolute paths.
|
||||
|
||||
The session file contains enough information for you to login
|
||||
without re-sending the code, so if you have to enter the code
|
||||
more than once, maybe you're changing the working directory,
|
||||
renaming or removing the file, or using random names.
|
||||
|
||||
api_id (`int` | `str`):
|
||||
The API ID you obtained from https://my.telegram.org.
|
||||
|
||||
api_hash (`str`):
|
||||
The API ID you obtained from https://my.telegram.org.
|
||||
|
||||
connection (`telethon.network.connection.common.Connection`, optional):
|
||||
The connection instance to be used when creating a new connection
|
||||
to the servers. If it's a type, the `proxy` argument will be used.
|
||||
|
||||
Defaults to `telethon.network.connection.tcpfull.ConnectionTcpFull`.
|
||||
|
||||
use_ipv6 (`bool`, optional):
|
||||
Whether to connect to the servers through IPv6 or not.
|
||||
By default this is ``False`` as IPv6 support is not
|
||||
too widespread yet.
|
||||
|
||||
proxy (`tuple` | `dict`, optional):
|
||||
A tuple consisting of ``(socks.SOCKS5, 'host', port)``.
|
||||
See https://github.com/Anorov/PySocks#usage-1 for more.
|
||||
|
||||
timeout (`int` | `float` | `timedelta`, optional):
|
||||
The timeout to be used when receiving responses from
|
||||
the network. Defaults to 5 seconds.
|
||||
|
||||
report_errors (`bool`, optional):
|
||||
Whether to report RPC errors or not. Defaults to ``True``,
|
||||
see :ref:`api-status` for more information.
|
||||
|
||||
device_model (`str`, optional):
|
||||
"Device model" to be sent when creating the initial connection.
|
||||
Defaults to ``platform.node()``.
|
||||
|
||||
system_version (`str`, optional):
|
||||
"System version" to be sent when creating the initial connection.
|
||||
Defaults to ``platform.system()``.
|
||||
|
||||
app_version (`str`, optional):
|
||||
"App version" to be sent when creating the initial connection.
|
||||
Defaults to `telethon.version.__version__`.
|
||||
|
||||
lang_code (`str`, optional):
|
||||
"Language code" to be sent when creating the initial connection.
|
||||
Defaults to ``'en'``.
|
||||
|
||||
system_lang_code (`str`, optional):
|
||||
"System lang code" to be sent when creating the initial connection.
|
||||
Defaults to `lang_code`.
|
||||
"""
|
||||
|
||||
# Current TelegramClient version
|
||||
__version__ = version.__version__
|
||||
|
||||
# Cached server configuration (with .dc_options), can be "global"
|
||||
_config = None
|
||||
_cdn_config = None
|
||||
|
||||
# region Initialization
|
||||
|
||||
def __init__(self, session, api_id, api_hash,
|
||||
*,
|
||||
connection=ConnectionTcpFull,
|
||||
use_ipv6=False,
|
||||
proxy=None,
|
||||
timeout=timedelta(seconds=5),
|
||||
report_errors=True,
|
||||
device_model=None,
|
||||
system_version=None,
|
||||
app_version=None,
|
||||
lang_code='en',
|
||||
system_lang_code='en'):
|
||||
"""Refer to TelegramClient.__init__ for docs on this method"""
|
||||
if not api_id or not api_hash:
|
||||
raise ValueError(
|
||||
"Your API ID or Hash cannot be empty or None. "
|
||||
"Refer to telethon.rtfd.io for more information.")
|
||||
|
||||
self._use_ipv6 = use_ipv6
|
||||
|
||||
# Determine what session object we have
|
||||
if isinstance(session, str) or session is None:
|
||||
session = SQLiteSession(session)
|
||||
elif not isinstance(session, Session):
|
||||
raise TypeError(
|
||||
'The given session must be a str or a Session instance.'
|
||||
)
|
||||
|
||||
# ':' in session.server_address is True if it's an IPv6 address
|
||||
if (not session.server_address or
|
||||
(':' in session.server_address) != use_ipv6):
|
||||
session.set_dc(
|
||||
DEFAULT_DC_ID,
|
||||
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
|
||||
DEFAULT_PORT
|
||||
)
|
||||
|
||||
session.report_errors = report_errors
|
||||
self.session = session
|
||||
self.api_id = int(api_id)
|
||||
self.api_hash = api_hash
|
||||
|
||||
# This is the main sender, which will be used from the thread
|
||||
# that calls .connect(). Every other thread will spawn a new
|
||||
# temporary connection. The connection on this one is always
|
||||
# kept open so Telegram can send us updates.
|
||||
if isinstance(connection, type):
|
||||
connection = connection(proxy=proxy, timeout=timeout)
|
||||
|
||||
# Used on connection. Capture the variables in a lambda since
|
||||
# exporting clients need to create this InvokeWithLayerRequest.
|
||||
system = platform.uname()
|
||||
self._init_with = lambda x: functions.InvokeWithLayerRequest(
|
||||
LAYER, functions.InitConnectionRequest(
|
||||
api_id=self.api_id,
|
||||
device_model=device_model or system.system or 'Unknown',
|
||||
system_version=system_version or system.release or '1.0',
|
||||
app_version=app_version or self.__version__,
|
||||
lang_code=lang_code,
|
||||
system_lang_code=system_lang_code,
|
||||
lang_pack='', # "langPacks are for official apps only"
|
||||
query=x
|
||||
)
|
||||
)
|
||||
|
||||
state = MTProtoState(self.session.auth_key)
|
||||
self._connection = connection
|
||||
self._sender = MTProtoSender(
|
||||
state, connection,
|
||||
first_query=self._init_with(functions.help.GetConfigRequest()),
|
||||
update_callback=self._handle_update
|
||||
)
|
||||
|
||||
# Cache :tl:`ExportedAuthorization` as ``dc_id: MTProtoState``
|
||||
# to easily import them when getting an exported sender.
|
||||
self._exported_auths = {}
|
||||
|
||||
# Save whether the user is authorized here (a.k.a. logged in)
|
||||
self._authorized = None # None = We don't know yet
|
||||
|
||||
# Default PingRequest delay
|
||||
self._last_ping = datetime.now()
|
||||
self._ping_delay = timedelta(minutes=1)
|
||||
|
||||
# Also have another delay for GetStateRequest.
|
||||
#
|
||||
# If the connection is kept alive for long without invoking any
|
||||
# high level request the server simply stops sending updates.
|
||||
# TODO maybe we can have ._last_request instead if any req works?
|
||||
self._last_state = datetime.now()
|
||||
self._state_delay = timedelta(hours=1)
|
||||
|
||||
# Some further state for subclasses
|
||||
self._event_builders = []
|
||||
self._events_pending_resolve = []
|
||||
|
||||
# Default parse mode
|
||||
self._parse_mode = markdown
|
||||
|
||||
# Some fields to easy signing in. Let {phone: hash} be
|
||||
# a dictionary because the user may change their mind.
|
||||
self._phone_code_hash = {}
|
||||
self._phone = None
|
||||
self._tos = None
|
||||
|
||||
# Sometimes we need to know who we are, cache the self peer
|
||||
self._self_input_peer = None
|
||||
|
||||
# endregion
|
||||
|
||||
# region Connecting
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
Connects to Telegram.
|
||||
"""
|
||||
had_auth = self.session.auth_key is not None
|
||||
await self._sender.connect(
|
||||
self.session.server_address, self.session.port)
|
||||
|
||||
if not had_auth:
|
||||
self.session.auth_key = self._sender.state.auth_key
|
||||
self.session.save()
|
||||
|
||||
def is_connected(self):
|
||||
"""
|
||||
Returns ``True`` if the user has connected.
|
||||
"""
|
||||
return self._sender.is_connected()
|
||||
|
||||
async def disconnect(self):
|
||||
"""
|
||||
Disconnects from Telegram.
|
||||
"""
|
||||
await self._sender.disconnect()
|
||||
# TODO What to do with the update state? Does it belong here?
|
||||
# self.session.set_update_state(0, self.updates.get_update_state(0))
|
||||
self.session.close()
|
||||
|
||||
async def _switch_dc(self, new_dc):
|
||||
"""
|
||||
Permanently switches the current connection to the new data center.
|
||||
"""
|
||||
__log__.info('Reconnecting to new data center %s', new_dc)
|
||||
dc = await self._get_dc(new_dc)
|
||||
|
||||
self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
# auth_key's are associated with a server, which has now changed
|
||||
# so it's not valid anymore. Set to None to force recreating it.
|
||||
self.session.auth_key = self._sender.state.auth_key = None
|
||||
self.session.save()
|
||||
await self.disconnect()
|
||||
return await self.connect()
|
||||
|
||||
# endregion
|
||||
|
||||
# region Working with different connections/Data Centers
|
||||
|
||||
async def _get_dc(self, dc_id, cdn=False):
|
||||
"""Gets the Data Center (DC) associated to 'dc_id'"""
|
||||
cls = self.__class__
|
||||
if not cls._config:
|
||||
cls._config = await self(functions.help.GetConfigRequest())
|
||||
|
||||
if cdn and not self._cdn_config:
|
||||
cls._cdn_config = await self(functions.help.GetCdnConfigRequest())
|
||||
for pk in cls._cdn_config.public_keys:
|
||||
rsa.add_key(pk.public_key)
|
||||
|
||||
return next(
|
||||
dc for dc in cls._config.dc_options
|
||||
if dc.id == dc_id
|
||||
and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn
|
||||
)
|
||||
|
||||
async def _get_exported_sender(self, dc_id):
|
||||
"""
|
||||
Returns a cached `MTProtoSender` for the given `dc_id`, or creates
|
||||
a new one if it doesn't exist yet, and imports a freshly exported
|
||||
authorization key for it to be usable.
|
||||
"""
|
||||
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
||||
# for clearly showing how to export the authorization
|
||||
auth = self._exported_auths.get(dc_id)
|
||||
dc = await self._get_dc(dc_id)
|
||||
state = MTProtoState(auth)
|
||||
# TODO Don't hardcode ConnectionTcpFull()
|
||||
# Can't reuse self._sender._connection as it has its own seqno.
|
||||
#
|
||||
# If one were to do that, Telegram would reset the connection
|
||||
# with no further clues.
|
||||
sender = MTProtoSender(state, ConnectionTcpFull())
|
||||
await sender.connect(dc.ip_address, dc.port)
|
||||
if not auth:
|
||||
__log__.info('Exporting authorization for data center %s', dc)
|
||||
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
|
||||
req = self._init_with(functions.auth.ImportAuthorizationRequest(
|
||||
id=auth.id, bytes=auth.bytes
|
||||
))
|
||||
await sender.send(req)
|
||||
self._exported_auths[dc_id] = sender.state.auth_key
|
||||
|
||||
return sender
|
||||
|
||||
async def _get_cdn_client(self, cdn_redirect):
|
||||
"""Similar to ._get_exported_client, but for CDNs"""
|
||||
# TODO Implement
|
||||
raise NotImplementedError
|
||||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||
if not session:
|
||||
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||
session = self.session.clone()
|
||||
session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
self._exported_sessions[cdn_redirect.dc_id] = session
|
||||
|
||||
__log__.info('Creating new CDN client')
|
||||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
)
|
||||
|
||||
# This will make use of the new RSA keys for this specific CDN.
|
||||
#
|
||||
# We won't be calling GetConfigRequest because it's only called
|
||||
# when needed by ._get_dc, and also it's static so it's likely
|
||||
# set already. Avoid invoking non-CDN methods by not syncing updates.
|
||||
client.connect(_sync_updates=False)
|
||||
client._authorized = self._authorized
|
||||
return client
|
||||
|
||||
# endregion
|
||||
|
||||
# region Invoking Telegram requests
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, request, retries=5, ordered=False):
|
||||
"""
|
||||
Invokes (sends) one or more MTProtoRequests and returns (receives)
|
||||
their result.
|
||||
|
||||
Args:
|
||||
request (`TLObject` | `list`):
|
||||
The request or requests to be invoked.
|
||||
|
||||
ordered (`bool`, optional):
|
||||
Whether the requests (if more than one was given) should be
|
||||
executed sequentially on the server. They run in arbitrary
|
||||
order by default.
|
||||
|
||||
Returns:
|
||||
The result of the request (often a `TLObject`) or a list of
|
||||
results if more than one request was given.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Let people use client.invoke(SomeRequest()) instead client(...)
|
||||
async def invoke(self, *args, **kwargs):
|
||||
warnings.warn('client.invoke(...) is deprecated, '
|
||||
'use client(...) instead')
|
||||
return await self(*args, **kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _handle_update(self, update):
|
||||
raise NotImplementedError
|
||||
|
||||
# endregion
|
13
telethon/client/telegramclient.py
Normal file
13
telethon/client/telegramclient.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from . import (
|
||||
UpdateMethods, AuthMethods, DownloadMethods, DialogMethods,
|
||||
ChatMethods, MessageMethods, UploadMethods, MessageParseMethods,
|
||||
UserMethods
|
||||
)
|
||||
|
||||
|
||||
class TelegramClient(
|
||||
UpdateMethods, AuthMethods, DownloadMethods, DialogMethods,
|
||||
ChatMethods, MessageMethods, UploadMethods, MessageParseMethods,
|
||||
UserMethods
|
||||
):
|
||||
pass
|
179
telethon/client/updates.py
Normal file
179
telethon/client/updates.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from .users import UserMethods
|
||||
from .. import events, utils
|
||||
from ..tl import types, functions
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateMethods(UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
def on(self, event):
|
||||
"""
|
||||
Decorator helper method around add_event_handler().
|
||||
|
||||
Args:
|
||||
event (`_EventBuilder` | `type`):
|
||||
The event builder class or instance to be used,
|
||||
for instance ``events.NewMessage``.
|
||||
"""
|
||||
def decorator(f):
|
||||
self.add_event_handler(f, event)
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
def add_event_handler(self, callback, event=None):
|
||||
"""
|
||||
Registers the given callback to be called on the specified event.
|
||||
|
||||
Args:
|
||||
callback (`callable`):
|
||||
The callable function accepting one parameter to be used.
|
||||
|
||||
event (`_EventBuilder` | `type`, optional):
|
||||
The event builder class or instance to be used,
|
||||
for instance ``events.NewMessage``.
|
||||
|
||||
If left unspecified, `telethon.events.raw.Raw` (the
|
||||
:tl:`Update` objects with no further processing) will
|
||||
be passed instead.
|
||||
"""
|
||||
if isinstance(event, type):
|
||||
event = event()
|
||||
elif not event:
|
||||
event = events.Raw()
|
||||
|
||||
self._events_pending_resolve.append(event)
|
||||
self._event_builders.append((event, callback))
|
||||
|
||||
def remove_event_handler(self, callback, event=None):
|
||||
"""
|
||||
Inverse operation of :meth:`add_event_handler`.
|
||||
|
||||
If no event is given, all events for this callback are removed.
|
||||
Returns how many callbacks were removed.
|
||||
"""
|
||||
found = 0
|
||||
if event and not isinstance(event, type):
|
||||
event = type(event)
|
||||
|
||||
i = len(self._event_builders)
|
||||
while i:
|
||||
i -= 1
|
||||
ev, cb = self._event_builders[i]
|
||||
if cb == callback and (not event or isinstance(ev, event)):
|
||||
del self._event_builders[i]
|
||||
found += 1
|
||||
|
||||
return found
|
||||
|
||||
def list_event_handlers(self):
|
||||
"""
|
||||
Lists all added event handlers, returning a list of pairs
|
||||
consisting of (callback, event).
|
||||
"""
|
||||
return [(callback, event) for event, callback in self._event_builders]
|
||||
|
||||
def add_update_handler(self, handler):
|
||||
"""Deprecated, see :meth:`add_event_handler`."""
|
||||
warnings.warn(
|
||||
'add_update_handler is deprecated, use the @client.on syntax '
|
||||
'or add_event_handler(callback, events.Raw) instead (see '
|
||||
'https://telethon.rtfd.io/en/latest/extra/basic/working-'
|
||||
'with-updates.html)'
|
||||
)
|
||||
return self.add_event_handler(handler, events.Raw)
|
||||
|
||||
def remove_update_handler(self, handler):
|
||||
return self.remove_event_handler(handler)
|
||||
|
||||
def list_update_handlers(self):
|
||||
return [callback for callback, _ in self.list_event_handlers()]
|
||||
|
||||
async def catch_up(self):
|
||||
state = self.session.get_update_state(0)
|
||||
if not state or not state.pts:
|
||||
return
|
||||
|
||||
self.session.catching_up = True
|
||||
try:
|
||||
while True:
|
||||
d = await self(functions.updates.GetDifferenceRequest(
|
||||
state.pts, state.date, state.qts))
|
||||
if isinstance(d, types.updates.DifferenceEmpty):
|
||||
state.date = d.date
|
||||
state.seq = d.seq
|
||||
break
|
||||
elif isinstance(d, (types.updates.DifferenceSlice,
|
||||
types.updates.Difference)):
|
||||
if isinstance(d, types.updates.Difference):
|
||||
state = d.state
|
||||
elif d.intermediate_state.pts > state.pts:
|
||||
state = d.intermediate_state
|
||||
else:
|
||||
# TODO Figure out why other applications can rely on
|
||||
# using always the intermediate_state to eventually
|
||||
# reach a DifferenceEmpty, but that leads to an
|
||||
# infinite loop here (so check against old pts to stop)
|
||||
break
|
||||
|
||||
self._handle_update(types.Updates(
|
||||
users=d.users,
|
||||
chats=d.chats,
|
||||
date=state.date,
|
||||
seq=state.seq,
|
||||
updates=d.other_updates + [
|
||||
types.UpdateNewMessage(m, 0, 0)
|
||||
for m in d.new_messages
|
||||
]
|
||||
))
|
||||
elif isinstance(d, types.updates.DifferenceTooLong):
|
||||
break
|
||||
finally:
|
||||
self.session.set_update_state(0, state)
|
||||
self.session.catching_up = False
|
||||
|
||||
# endregion
|
||||
|
||||
# region Private methods
|
||||
|
||||
def _handle_update(self, update):
|
||||
asyncio.ensure_future(self._dispatch_update(update))
|
||||
|
||||
async def _dispatch_update(self, update):
|
||||
if self._events_pending_resolve:
|
||||
# TODO Add lock not to resolve them twice
|
||||
for event in self._events_pending_resolve:
|
||||
await event.resolve(self)
|
||||
self._events_pending_resolve.clear()
|
||||
|
||||
for builder, callback in self._event_builders:
|
||||
event = builder.build(update)
|
||||
if event:
|
||||
if hasattr(event, '_set_client'):
|
||||
event._set_client(self)
|
||||
else:
|
||||
event._client = self
|
||||
|
||||
event.original_update = update
|
||||
try:
|
||||
await callback(event)
|
||||
except events.StopPropagation:
|
||||
__log__.debug(
|
||||
"Event handler '{}' stopped chain of "
|
||||
"propagation for event {}."
|
||||
.format(callback.__name__,
|
||||
type(event).__name__)
|
||||
)
|
||||
break
|
||||
except:
|
||||
__log__.exception('Unhandled exception on {}'
|
||||
.format(callback.__name__))
|
||||
|
||||
# endregion
|
485
telethon/client/uploads.py
Normal file
485
telethon/client/uploads.py
Normal file
|
@ -0,0 +1,485 @@
|
|||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
from mimetypes import guess_type
|
||||
|
||||
from .messageparse import MessageParseMethods
|
||||
from .users import UserMethods
|
||||
from .. import utils, helpers
|
||||
from ..tl import types, functions, custom
|
||||
|
||||
try:
|
||||
import hachoir
|
||||
import hachoir.metadata
|
||||
import hachoir.parser
|
||||
except ImportError:
|
||||
hachoir = None
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UploadMethods(MessageParseMethods, UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
async def send_file(
|
||||
self, entity, file, caption='', force_document=False,
|
||||
progress_callback=None, reply_to=None, attributes=None,
|
||||
thumb=None, allow_cache=True, parse_mode=utils.Default,
|
||||
voice_note=False, video_note=False, **kwargs):
|
||||
"""
|
||||
Sends a file to the specified entity.
|
||||
|
||||
Args:
|
||||
entity (`entity`):
|
||||
Who will receive the file.
|
||||
|
||||
file (`str` | `bytes` | `file` | `media`):
|
||||
The path of the file, byte array, or stream that will be sent.
|
||||
Note that if a byte array or a stream is given, a filename
|
||||
or its type won't be inferred, and it will be sent as an
|
||||
"unnamed application/octet-stream".
|
||||
|
||||
Furthermore the file may be any media (a message, document,
|
||||
photo or similar) so that it can be resent without the need
|
||||
to download and re-upload it again.
|
||||
|
||||
If a list or similar is provided, the files in it will be
|
||||
sent as an album in the order in which they appear, sliced
|
||||
in chunks of 10 if more than 10 are given.
|
||||
|
||||
caption (`str`, optional):
|
||||
Optional caption for the sent media message.
|
||||
|
||||
force_document (`bool`, optional):
|
||||
If left to ``False`` and the file is a path that ends with
|
||||
the extension of an image file or a video file, it will be
|
||||
sent as such. Otherwise always as a document.
|
||||
|
||||
progress_callback (`callable`, optional):
|
||||
A callback function accepting two parameters:
|
||||
``(sent bytes, total)``.
|
||||
|
||||
reply_to (`int` | :tl:`Message`):
|
||||
Same as `reply_to` from `send_message`.
|
||||
|
||||
attributes (`list`, optional):
|
||||
Optional attributes that override the inferred ones, like
|
||||
:tl:`DocumentAttributeFilename` and so on.
|
||||
|
||||
thumb (`str` | `bytes` | `file`, optional):
|
||||
Optional thumbnail (for videos).
|
||||
|
||||
allow_cache (`bool`, optional):
|
||||
Whether to allow using the cached version stored in the
|
||||
database or not. Defaults to ``True`` to avoid re-uploads.
|
||||
Must be ``False`` if you wish to use different attributes
|
||||
or thumb than those that were used when the file was cached.
|
||||
|
||||
parse_mode (`object`, optional):
|
||||
See the `TelegramClient.parse_mode` property for allowed
|
||||
values. Markdown parsing will be used by default.
|
||||
|
||||
voice_note (`bool`, optional):
|
||||
If ``True`` the audio will be sent as a voice note.
|
||||
|
||||
Set `allow_cache` to ``False`` if you sent the same file
|
||||
without this setting before for it to work.
|
||||
|
||||
video_note (`bool`, optional):
|
||||
If ``True`` the video will be sent as a video note,
|
||||
also known as a round video message.
|
||||
|
||||
Set `allow_cache` to ``False`` if you sent the same file
|
||||
without this setting before for it to work.
|
||||
|
||||
Notes:
|
||||
If the ``hachoir3`` package (``hachoir`` module) is installed,
|
||||
it will be used to determine metadata from audio and video files.
|
||||
|
||||
Returns:
|
||||
The `telethon.tl.custom.message.Message` (or messages) containing
|
||||
the sent file, or messages if a list of them was passed.
|
||||
"""
|
||||
# First check if the user passed an iterable, in which case
|
||||
# we may want to send as an album if all are photo files.
|
||||
if utils.is_list_like(file):
|
||||
# TODO Fix progress_callback
|
||||
images = []
|
||||
if force_document:
|
||||
documents = file
|
||||
else:
|
||||
documents = []
|
||||
for x in file:
|
||||
if utils.is_image(x):
|
||||
images.append(x)
|
||||
else:
|
||||
documents.append(x)
|
||||
|
||||
result = []
|
||||
while images:
|
||||
result += await self._send_album(
|
||||
entity, images[:10], caption=caption,
|
||||
progress_callback=progress_callback, reply_to=reply_to,
|
||||
parse_mode=parse_mode
|
||||
)
|
||||
images = images[10:]
|
||||
|
||||
result.extend(
|
||||
await self.send_file(
|
||||
entity, x, allow_cache=allow_cache,
|
||||
caption=caption, force_document=force_document,
|
||||
progress_callback=progress_callback, reply_to=reply_to,
|
||||
attributes=attributes, thumb=thumb, voice_note=voice_note,
|
||||
video_note=video_note, **kwargs
|
||||
) for x in documents
|
||||
)
|
||||
return result
|
||||
|
||||
entity = await self.get_input_entity(entity)
|
||||
reply_to = utils.get_message_id(reply_to)
|
||||
|
||||
# Not document since it's subject to change.
|
||||
# Needed when a Message is passed to send_message and it has media.
|
||||
if 'entities' in kwargs:
|
||||
msg_entities = kwargs['entities']
|
||||
else:
|
||||
caption, msg_entities =\
|
||||
await self._parse_message_text(caption, parse_mode)
|
||||
|
||||
file_handle, media = await self._file_to_media(
|
||||
file, allow_cache=allow_cache)
|
||||
|
||||
request = functions.messages.SendMediaRequest(
|
||||
entity, media, reply_to_msg_id=reply_to, message=caption,
|
||||
entities=msg_entities
|
||||
)
|
||||
msg = self._get_response_message(request, await self(request), entity)
|
||||
self._cache_media(msg, file, file_handle, force_document=force_document)
|
||||
|
||||
return msg
|
||||
|
||||
async def send_voice_note(self, *args, **kwargs):
|
||||
"""Deprecated, see :meth:`send_file`."""
|
||||
warnings.warn('send_voice_note is deprecated, use '
|
||||
'send_file(..., voice_note=True) instead')
|
||||
kwargs['is_voice_note'] = True
|
||||
return await self.send_file(*args, **kwargs)
|
||||
|
||||
async def _send_album(self, entity, files, caption='',
|
||||
progress_callback=None, reply_to=None,
|
||||
parse_mode=utils.Default):
|
||||
"""Specialized version of .send_file for albums"""
|
||||
# We don't care if the user wants to avoid cache, we will use it
|
||||
# anyway. Why? The cached version will be exactly the same thing
|
||||
# we need to produce right now to send albums (uploadMedia), and
|
||||
# cache only makes a difference for documents where the user may
|
||||
# want the attributes used on them to change.
|
||||
#
|
||||
# In theory documents can be sent inside the albums but they appear
|
||||
# as different messages (not inside the album), and the logic to set
|
||||
# the attributes/avoid cache is already written in .send_file().
|
||||
entity = await self.get_input_entity(entity)
|
||||
if not utils.is_list_like(caption):
|
||||
caption = (caption,)
|
||||
captions = [
|
||||
await self._parse_message_text(caption or '', parse_mode)
|
||||
for caption in reversed(caption) # Pop from the end (so reverse)
|
||||
]
|
||||
reply_to = utils.get_message_id(reply_to)
|
||||
|
||||
# Need to upload the media first, but only if they're not cached yet
|
||||
media = []
|
||||
for file in files:
|
||||
# fh will either be InputPhoto or a modified InputFile
|
||||
fh = await self.upload_file(file, use_cache=types.InputPhoto)
|
||||
if not isinstance(fh, types.InputPhoto):
|
||||
r = await self(functions.messages.UploadMediaRequest(
|
||||
entity, media=types.InputMediaUploadedPhoto(fh)
|
||||
))
|
||||
input_photo = utils.get_input_photo(r.photo)
|
||||
self.session.cache_file(fh.md5, fh.size, input_photo)
|
||||
fh = input_photo
|
||||
|
||||
if captions:
|
||||
caption, msg_entities = captions.pop()
|
||||
else:
|
||||
caption, msg_entities = '', None
|
||||
media.append(types.InputSingleMedia(types.InputMediaPhoto(fh), message=caption,
|
||||
entities=msg_entities))
|
||||
|
||||
# Now we can construct the multi-media request
|
||||
result = await self(functions.messages.SendMultiMediaRequest(
|
||||
entity, reply_to_msg_id=reply_to, multi_media=media
|
||||
))
|
||||
return [
|
||||
self._get_response_message(update.id, result, entity)
|
||||
for update in result.updates
|
||||
if isinstance(update, types.UpdateMessageID)
|
||||
]
|
||||
|
||||
async def upload_file(
|
||||
self, file, part_size_kb=None, file_name=None, use_cache=None,
|
||||
progress_callback=None):
|
||||
"""
|
||||
Uploads the specified file and returns a handle (an instance of
|
||||
:tl:`InputFile` or :tl:`InputFileBig`, as required) which can be
|
||||
later used before it expires (they are usable during less than a day).
|
||||
|
||||
Uploading a file will simply return a "handle" to the file stored
|
||||
remotely in the Telegram servers, which can be later used on. This
|
||||
will **not** upload the file to your own chat or any chat at all.
|
||||
|
||||
Args:
|
||||
file (`str` | `bytes` | `file`):
|
||||
The path of the file, byte array, or stream that will be sent.
|
||||
Note that if a byte array or a stream is given, a filename
|
||||
or its type won't be inferred, and it will be sent as an
|
||||
"unnamed application/octet-stream".
|
||||
|
||||
part_size_kb (`int`, optional):
|
||||
Chunk size when uploading files. The larger, the less
|
||||
requests will be made (up to 512KB maximum).
|
||||
|
||||
file_name (`str`, optional):
|
||||
The file name which will be used on the resulting InputFile.
|
||||
If not specified, the name will be taken from the ``file``
|
||||
and if this is not a ``str``, it will be ``"unnamed"``.
|
||||
|
||||
use_cache (`type`, optional):
|
||||
The type of cache to use (currently either :tl:`InputDocument`
|
||||
or :tl:`InputPhoto`). If present and the file is small enough
|
||||
to need the MD5, it will be checked against the database,
|
||||
and if a match is found, the upload won't be made. Instead,
|
||||
an instance of type ``use_cache`` will be returned.
|
||||
|
||||
progress_callback (`callable`, optional):
|
||||
A callback function accepting two parameters:
|
||||
``(sent bytes, total)``.
|
||||
|
||||
Returns:
|
||||
:tl:`InputFileBig` if the file size is larger than 10MB,
|
||||
`telethon.tl.custom.input_sized_file.InputSizedFile`
|
||||
(subclass of :tl:`InputFile`) otherwise.
|
||||
"""
|
||||
if isinstance(file, (types.InputFile, types.InputFileBig)):
|
||||
return file # Already uploaded
|
||||
|
||||
if isinstance(file, str):
|
||||
file_size = os.path.getsize(file)
|
||||
elif isinstance(file, bytes):
|
||||
file_size = len(file)
|
||||
else:
|
||||
file = file.read()
|
||||
file_size = len(file)
|
||||
|
||||
# File will now either be a string or bytes
|
||||
if not part_size_kb:
|
||||
part_size_kb = utils.get_appropriated_part_size(file_size)
|
||||
|
||||
if part_size_kb > 512:
|
||||
raise ValueError('The part size must be less or equal to 512KB')
|
||||
|
||||
part_size = int(part_size_kb * 1024)
|
||||
if part_size % 1024 != 0:
|
||||
raise ValueError(
|
||||
'The part size must be evenly divisible by 1024')
|
||||
|
||||
# Set a default file name if None was specified
|
||||
file_id = helpers.generate_random_long()
|
||||
if not file_name:
|
||||
if isinstance(file, str):
|
||||
file_name = os.path.basename(file)
|
||||
else:
|
||||
file_name = str(file_id)
|
||||
|
||||
# Determine whether the file is too big (over 10MB) or not
|
||||
# Telegram does make a distinction between smaller or larger files
|
||||
is_large = file_size > 10 * 1024 * 1024
|
||||
hash_md5 = hashlib.md5()
|
||||
if not is_large:
|
||||
# Calculate the MD5 hash before anything else.
|
||||
# As this needs to be done always for small files,
|
||||
# might as well do it before anything else and
|
||||
# check the cache.
|
||||
if isinstance(file, str):
|
||||
with open(file, 'rb') as stream:
|
||||
file = stream.read()
|
||||
hash_md5.update(file)
|
||||
if use_cache:
|
||||
cached = self.session.get_file(
|
||||
hash_md5.digest(), file_size, cls=use_cache
|
||||
)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
part_count = (file_size + part_size - 1) // part_size
|
||||
__log__.info('Uploading file of %d bytes in %d chunks of %d',
|
||||
file_size, part_count, part_size)
|
||||
|
||||
with open(file, 'rb') if isinstance(file, str) else BytesIO(file)\
|
||||
as stream:
|
||||
for part_index in range(part_count):
|
||||
# Read the file by in chunks of size part_size
|
||||
part = stream.read(part_size)
|
||||
|
||||
# The SavePartRequest is different depending on whether
|
||||
# the file is too large or not (over or less than 10MB)
|
||||
if is_large:
|
||||
request = functions.upload.SaveBigFilePartRequest(
|
||||
file_id, part_index, part_count, part)
|
||||
else:
|
||||
request = functions.upload.SaveFilePartRequest(
|
||||
file_id, part_index, part)
|
||||
|
||||
result = await self(request)
|
||||
if result:
|
||||
__log__.debug('Uploaded %d/%d', part_index + 1,
|
||||
part_count)
|
||||
if progress_callback:
|
||||
progress_callback(stream.tell(), file_size)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Failed to upload file part {}.'.format(part_index))
|
||||
|
||||
if is_large:
|
||||
return types.InputFileBig(file_id, part_count, file_name)
|
||||
else:
|
||||
return custom.InputSizedFile(
|
||||
file_id, part_count, file_name, md5=hash_md5, size=file_size
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
||||
async def _file_to_media(
|
||||
self, file, force_document=False,
|
||||
progress_callback=None, attributes=None, thumb=None,
|
||||
allow_cache=True, voice_note=False, video_note=False):
|
||||
if not file:
|
||||
return None, None
|
||||
|
||||
if not isinstance(file, (str, bytes, io.IOBase)):
|
||||
# The user may pass a Message containing media (or the media,
|
||||
# or anything similar) that should be treated as a file. Try
|
||||
# getting the input media for whatever they passed and send it.
|
||||
try:
|
||||
return None, utils.get_input_media(file)
|
||||
except TypeError:
|
||||
return None, None # Can't turn whatever was given into media
|
||||
|
||||
as_image = utils.is_image(file) and not force_document
|
||||
use_cache = types.InputPhoto if as_image else types.InputDocument
|
||||
file_handle = await self.upload_file(
|
||||
file, progress_callback=progress_callback,
|
||||
use_cache=use_cache if allow_cache else None
|
||||
)
|
||||
|
||||
if isinstance(file_handle, use_cache):
|
||||
# File was cached, so an instance of use_cache was returned
|
||||
if as_image:
|
||||
media = types.InputMediaPhoto(file_handle)
|
||||
else:
|
||||
media = types.InputMediaDocument(file_handle)
|
||||
elif as_image:
|
||||
media = types.InputMediaUploadedPhoto(file_handle)
|
||||
else:
|
||||
mime_type = None
|
||||
if isinstance(file, str):
|
||||
# Determine mime-type and attributes
|
||||
# Take the first element by using [0] since it returns a tuple
|
||||
mime_type = guess_type(file)[0]
|
||||
attr_dict = {
|
||||
types.DocumentAttributeFilename:
|
||||
types.DocumentAttributeFilename(
|
||||
os.path.basename(file))
|
||||
}
|
||||
if utils.is_audio(file) and hachoir:
|
||||
m = hachoir.metadata.extractMetadata(
|
||||
hachoir.parser.createParser(file)
|
||||
)
|
||||
attr_dict[types.DocumentAttributeAudio] = \
|
||||
types.DocumentAttributeAudio(
|
||||
voice=voice_note,
|
||||
title=m.get('title') if m.has(
|
||||
'title') else None,
|
||||
performer=m.get('author') if m.has(
|
||||
'author') else None,
|
||||
duration=int(m.get('duration').seconds
|
||||
if m.has('duration') else 0)
|
||||
)
|
||||
|
||||
if not force_document and utils.is_video(file):
|
||||
if hachoir:
|
||||
m = hachoir.metadata.extractMetadata(
|
||||
hachoir.parser.createParser(file)
|
||||
)
|
||||
doc = types.DocumentAttributeVideo(
|
||||
round_message=video_note,
|
||||
w=m.get('width') if m.has('width') else 0,
|
||||
h=m.get('height') if m.has('height') else 0,
|
||||
duration=int(m.get('duration').seconds
|
||||
if m.has('duration') else 0)
|
||||
)
|
||||
else:
|
||||
doc = types.DocumentAttributeVideo(
|
||||
0, 1, 1, round_message=video_note)
|
||||
|
||||
attr_dict[types.DocumentAttributeVideo] = doc
|
||||
else:
|
||||
attr_dict = {
|
||||
types.DocumentAttributeFilename:
|
||||
types.DocumentAttributeFilename(
|
||||
os.path.basename(
|
||||
getattr(file, 'name',
|
||||
None) or 'unnamed'))
|
||||
}
|
||||
|
||||
if voice_note:
|
||||
if types.DocumentAttributeAudio in attr_dict:
|
||||
attr_dict[types.DocumentAttributeAudio].voice = True
|
||||
else:
|
||||
attr_dict[types.DocumentAttributeAudio] = \
|
||||
types.DocumentAttributeAudio(0, voice=True)
|
||||
|
||||
# Now override the attributes if any. As we have a dict of
|
||||
# {cls: instance}, we can override any class with the list
|
||||
# of attributes provided by the user easily.
|
||||
if attributes:
|
||||
for a in attributes:
|
||||
attr_dict[type(a)] = a
|
||||
|
||||
# Ensure we have a mime type, any; but it cannot be None
|
||||
# 'The "octet-stream" subtype is used to indicate that a body
|
||||
# contains arbitrary binary data.'
|
||||
if not mime_type:
|
||||
mime_type = 'application/octet-stream'
|
||||
|
||||
input_kw = {}
|
||||
if thumb:
|
||||
input_kw['thumb'] = await self.upload_file(thumb)
|
||||
|
||||
media = types.InputMediaUploadedDocument(
|
||||
file=file_handle,
|
||||
mime_type=mime_type,
|
||||
attributes=list(attr_dict.values()),
|
||||
**input_kw
|
||||
)
|
||||
return file_handle, media
|
||||
|
||||
def _cache_media(self, msg, file, file_handle,
|
||||
force_document=False):
|
||||
if file and msg and isinstance(file_handle,
|
||||
custom.InputSizedFile):
|
||||
# There was a response message and we didn't use cached
|
||||
# version, so cache whatever we just sent to the database.
|
||||
md5, size = file_handle.md5, file_handle.size
|
||||
if utils.is_image(file) and not force_document:
|
||||
to_cache = utils.get_input_photo(msg.media.photo)
|
||||
else:
|
||||
to_cache = utils.get_input_document(msg.media.document)
|
||||
self.session.cache_file(md5, size, to_cache)
|
||||
|
||||
# endregion
|
264
telethon/client/users.py
Normal file
264
telethon/client/users.py
Normal file
|
@ -0,0 +1,264 @@
|
|||
import asyncio
|
||||
import itertools
|
||||
|
||||
from .telegrambaseclient import TelegramBaseClient
|
||||
from .. import errors, utils
|
||||
from ..tl import TLObject, TLRequest, types, functions
|
||||
|
||||
|
||||
_NOT_A_REQUEST = TypeError('You can only invoke requests, not types!')
|
||||
|
||||
|
||||
class UserMethods(TelegramBaseClient):
|
||||
async def __call__(self, request, retries=5, ordered=False):
|
||||
for r in (request if utils.is_list_like(request) else (request,)):
|
||||
if not isinstance(r, TLRequest):
|
||||
raise _NOT_A_REQUEST
|
||||
await r.resolve(self, utils)
|
||||
|
||||
for _ in range(retries):
|
||||
try:
|
||||
future = self._sender.send(request, ordered=ordered)
|
||||
if isinstance(future, list):
|
||||
results = []
|
||||
for f in future:
|
||||
results.append(await f)
|
||||
return results
|
||||
else:
|
||||
return await future
|
||||
except (errors.ServerError, errors.RpcCallFailError):
|
||||
pass
|
||||
except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e:
|
||||
if e.seconds <= self.session.flood_sleep_threshold:
|
||||
await asyncio.sleep(e.seconds)
|
||||
else:
|
||||
raise
|
||||
except (errors.PhoneMigrateError, errors.NetworkMigrateError,
|
||||
errors.UserMigrateError) as e:
|
||||
await self._switch_dc(e.new_dc)
|
||||
|
||||
raise ValueError('Number of retries reached 0')
|
||||
|
||||
# region Public methods
|
||||
|
||||
async def get_me(self, input_peer=False):
|
||||
"""
|
||||
Gets "me" (the self user) which is currently authenticated,
|
||||
or None if the request fails (hence, not authenticated).
|
||||
|
||||
Args:
|
||||
input_peer (`bool`, optional):
|
||||
Whether to return the :tl:`InputPeerUser` version or the normal
|
||||
:tl:`User`. This can be useful if you just need to know the ID
|
||||
of yourself.
|
||||
|
||||
Returns:
|
||||
Your own :tl:`User`.
|
||||
"""
|
||||
if input_peer and self._self_input_peer:
|
||||
return self._self_input_peer
|
||||
|
||||
try:
|
||||
me = (await self(
|
||||
functions.users.GetUsersRequest([types.InputUserSelf()])))[0]
|
||||
|
||||
if not self._self_input_peer:
|
||||
self._self_input_peer = utils.get_input_peer(
|
||||
me, allow_self=False
|
||||
)
|
||||
|
||||
return self._self_input_peer if input_peer else me
|
||||
except errors.UnauthorizedError:
|
||||
return None
|
||||
|
||||
async def get_entity(self, entity):
|
||||
"""
|
||||
Turns the given entity into a valid Telegram user or chat.
|
||||
|
||||
entity (`str` | `int` | :tl:`Peer` | :tl:`InputPeer`):
|
||||
The entity (or iterable of entities) to be transformed.
|
||||
If it's a string which can be converted to an integer or starts
|
||||
with '+' it will be resolved as if it were a phone number.
|
||||
|
||||
If it doesn't start with '+' or starts with a '@' it will be
|
||||
be resolved from the username. If no exact match is returned,
|
||||
an error will be raised.
|
||||
|
||||
If the entity is an integer or a Peer, its information will be
|
||||
returned through a call to self.get_input_peer(entity).
|
||||
|
||||
If the entity is neither, and it's not a TLObject, an
|
||||
error will be raised.
|
||||
|
||||
Returns:
|
||||
:tl:`User`, :tl:`Chat` or :tl:`Channel` corresponding to the
|
||||
input entity. A list will be returned if more than one was given.
|
||||
"""
|
||||
single = not utils.is_list_like(entity)
|
||||
if single:
|
||||
entity = (entity,)
|
||||
|
||||
# Group input entities by string (resolve username),
|
||||
# input users (get users), input chat (get chats) and
|
||||
# input channels (get channels) to get the most entities
|
||||
# in the less amount of calls possible.
|
||||
inputs = [
|
||||
x if isinstance(x, str) else await self.get_input_entity(x)
|
||||
for x in entity
|
||||
]
|
||||
users = [x for x in inputs
|
||||
if isinstance(x, (types.InputPeerUser, types.InputPeerSelf))]
|
||||
chats = [x.chat_id for x in inputs
|
||||
if isinstance(x, types.InputPeerChat)]
|
||||
channels = [x for x in inputs
|
||||
if isinstance(x, types.InputPeerChannel)]
|
||||
if users:
|
||||
# GetUsersRequest has a limit of 200 per call
|
||||
tmp = []
|
||||
while users:
|
||||
curr, users = users[:200], users[200:]
|
||||
tmp.extend(await self(functions.users.GetUsersRequest(curr)))
|
||||
users = tmp
|
||||
if chats: # TODO Handle chats slice?
|
||||
chats = (await self(
|
||||
functions.messages.GetChatsRequest(chats))).chats
|
||||
if channels:
|
||||
channels = (await self(
|
||||
functions.channels.GetChannelsRequest(channels))).chats
|
||||
|
||||
# Merge users, chats and channels into a single dictionary
|
||||
id_entity = {
|
||||
utils.get_peer_id(x): x
|
||||
for x in itertools.chain(users, chats, channels)
|
||||
}
|
||||
|
||||
# We could check saved usernames and put them into the users,
|
||||
# chats and channels list from before. While this would reduce
|
||||
# the amount of ResolveUsername calls, it would fail to catch
|
||||
# username changes.
|
||||
result = [
|
||||
await self._get_entity_from_string(x) if isinstance(x, str)
|
||||
else (
|
||||
id_entity[utils.get_peer_id(x)]
|
||||
if not isinstance(x, types.InputPeerSelf)
|
||||
else next(u for u in id_entity.values()
|
||||
if isinstance(u, types.User) and u.is_self)
|
||||
)
|
||||
for x in inputs
|
||||
]
|
||||
return result[0] if single else result
|
||||
|
||||
async def get_input_entity(self, peer):
|
||||
"""
|
||||
Turns the given peer into its input entity version. Most requests
|
||||
use this kind of InputUser, InputChat and so on, so this is the
|
||||
most suitable call to make for those cases.
|
||||
|
||||
entity (`str` | `int` | :tl:`Peer` | :tl:`InputPeer`):
|
||||
The integer ID of an user or otherwise either of a
|
||||
:tl:`PeerUser`, :tl:`PeerChat` or :tl:`PeerChannel`, for
|
||||
which to get its ``Input*`` version.
|
||||
|
||||
If this ``Peer`` hasn't been seen before by the library, the top
|
||||
dialogs will be loaded and their entities saved to the session
|
||||
file (unless this feature was disabled explicitly).
|
||||
|
||||
If in the end the access hash required for the peer was not found,
|
||||
a ValueError will be raised.
|
||||
|
||||
Returns:
|
||||
:tl:`InputPeerUser`, :tl:`InputPeerChat` or :tl:`InputPeerChannel`
|
||||
or :tl:`InputPeerSelf` if the parameter is ``'me'`` or ``'self'``.
|
||||
|
||||
If you need to get the ID of yourself, you should use
|
||||
`get_me` with ``input_peer=True``) instead.
|
||||
"""
|
||||
if peer in ('me', 'self'):
|
||||
return types.InputPeerSelf()
|
||||
|
||||
try:
|
||||
# First try to get the entity from cache, otherwise figure it out
|
||||
return self.session.get_input_entity(peer)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if isinstance(peer, str):
|
||||
return utils.get_input_peer(
|
||||
await self._get_entity_from_string(peer))
|
||||
|
||||
if not isinstance(peer, int) and (not isinstance(peer, TLObject)
|
||||
or peer.SUBCLASS_OF_ID != 0x2d45687):
|
||||
# Try casting the object into an input peer. Might TypeError.
|
||||
# Don't do it if a not-found ID was given (instead ValueError).
|
||||
# Also ignore Peer (0x2d45687 == crc32(b'Peer'))'s, lacking hash.
|
||||
return utils.get_input_peer(peer)
|
||||
|
||||
raise ValueError(
|
||||
'Could not find the input entity for "{}". Please read https://'
|
||||
'telethon.readthedocs.io/en/latest/extra/basic/entities.html to'
|
||||
' find out more details.'
|
||||
.format(peer)
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Private methods
|
||||
|
||||
async def _get_entity_from_string(self, string):
|
||||
"""
|
||||
Gets a full entity from the given string, which may be a phone or
|
||||
an username, and processes all the found entities on the session.
|
||||
The string may also be a user link, or a channel/chat invite link.
|
||||
|
||||
This method has the side effect of adding the found users to the
|
||||
session database, so it can be queried later without API calls,
|
||||
if this option is enabled on the session.
|
||||
|
||||
Returns the found entity, or raises TypeError if not found.
|
||||
"""
|
||||
phone = utils.parse_phone(string)
|
||||
if phone:
|
||||
for user in (await self(
|
||||
functions.contacts.GetContactsRequest(0))).users:
|
||||
if user.phone == phone:
|
||||
return user
|
||||
else:
|
||||
username, is_join_chat = utils.parse_username(string)
|
||||
if is_join_chat:
|
||||
invite = await self(
|
||||
functions.messages.CheckChatInviteRequest(username))
|
||||
|
||||
if isinstance(invite, types.ChatInvite):
|
||||
raise ValueError(
|
||||
'Cannot get entity from a channel (or group) '
|
||||
'that you are not part of. Join the group and retry'
|
||||
)
|
||||
elif isinstance(invite, types.ChatInviteAlready):
|
||||
return invite.chat
|
||||
elif username:
|
||||
if username in ('me', 'self'):
|
||||
return await self.get_me()
|
||||
|
||||
try:
|
||||
result = await self(
|
||||
functions.contacts.ResolveUsernameRequest(username))
|
||||
except errors.UsernameNotOccupiedError as e:
|
||||
raise ValueError('No user has "{}" as username'
|
||||
.format(username)) from e
|
||||
|
||||
for entity in itertools.chain(result.users, result.chats):
|
||||
if getattr(entity, 'username', None) or '' \
|
||||
.lower() == username:
|
||||
return entity
|
||||
try:
|
||||
# Nobody with this username, maybe it's an exact name/title
|
||||
return await self.get_entity(
|
||||
self.session.get_input_entity(string))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
raise ValueError(
|
||||
'Cannot find any entity corresponding to "{}"'.format(string)
|
||||
)
|
||||
|
||||
# endregion
|
|
@ -37,4 +37,4 @@ class AuthKey:
|
|||
data = new_nonce + struct.pack('<BQ', number, self.aux_hash)
|
||||
|
||||
# Calculates the message key from the given data
|
||||
return sha1(data).digest()[4:20]
|
||||
return int.from_bytes(sha1(data).digest()[4:20], 'little', signed=True)
|
||||
|
|
|
@ -30,7 +30,7 @@ class CdnDecrypter:
|
|||
self.cdn_file_hashes = cdn_file_hashes
|
||||
|
||||
@staticmethod
|
||||
def prepare_decrypter(client, cdn_client, cdn_redirect):
|
||||
async def prepare_decrypter(client, cdn_client, cdn_redirect):
|
||||
"""
|
||||
Prepares a new CDN decrypter.
|
||||
|
||||
|
@ -52,14 +52,14 @@ class CdnDecrypter:
|
|||
cdn_aes, cdn_redirect.cdn_file_hashes
|
||||
)
|
||||
|
||||
cdn_file = cdn_client(GetCdnFileRequest(
|
||||
cdn_file = await cdn_client(GetCdnFileRequest(
|
||||
file_token=cdn_redirect.file_token,
|
||||
offset=cdn_redirect.cdn_file_hashes[0].offset,
|
||||
limit=cdn_redirect.cdn_file_hashes[0].limit
|
||||
))
|
||||
if isinstance(cdn_file, CdnFileReuploadNeeded):
|
||||
# We need to use the original client here
|
||||
client(ReuploadCdnFileRequest(
|
||||
await client(ReuploadCdnFileRequest(
|
||||
file_token=cdn_redirect.file_token,
|
||||
request_token=cdn_file.request_token
|
||||
))
|
||||
|
|
|
@ -40,49 +40,31 @@ def report_error(code, message, report_method):
|
|||
"We really don't want to crash when just reporting an error"
|
||||
|
||||
|
||||
def rpc_message_to_error(code, message, report_method=None):
|
||||
def rpc_message_to_error(rpc_error, report_method=None):
|
||||
"""
|
||||
Converts a Telegram's RPC Error to a Python error.
|
||||
|
||||
:param code: the integer code of the error (like 400).
|
||||
:param message: the message representing the error.
|
||||
:param rpc_error: the RpcError instance.
|
||||
:param report_method: if present, the ID of the method that caused it.
|
||||
:return: the RPCError as a Python exception that represents this error.
|
||||
"""
|
||||
if report_method is not None:
|
||||
Thread(
|
||||
target=report_error,
|
||||
args=(code, message, report_method)
|
||||
args=(rpc_error.error_code, rpc_error.error_message, report_method)
|
||||
).start()
|
||||
|
||||
# Try to get the error by direct look-up, otherwise regex
|
||||
# TODO Maybe regexes could live in a separate dictionary?
|
||||
cls = rpc_errors_all.get(message, None)
|
||||
cls = rpc_errors_all.get(rpc_error.error_message, None)
|
||||
if cls:
|
||||
return cls()
|
||||
|
||||
for msg_regex, cls in rpc_errors_all.items():
|
||||
m = re.match(msg_regex, message)
|
||||
m = re.match(msg_regex, rpc_error.error_message)
|
||||
if m:
|
||||
capture = int(m.group(1)) if m.groups() else None
|
||||
return cls(capture=capture)
|
||||
|
||||
if code == 400:
|
||||
return BadRequestError(message)
|
||||
|
||||
if code == 401:
|
||||
return UnauthorizedError(message)
|
||||
|
||||
if code == 403:
|
||||
return ForbiddenError(message)
|
||||
|
||||
if code == 404:
|
||||
return NotFoundError(message)
|
||||
|
||||
if code == 406:
|
||||
return AuthKeyError(message)
|
||||
|
||||
if code == 500:
|
||||
return ServerError(message)
|
||||
|
||||
return RPCError('{} (code {})'.format(message, code))
|
||||
cls = base_errors.get(rpc_error.error_code, RPCError)
|
||||
return cls(rpc_error.error_message)
|
||||
|
|
|
@ -12,14 +12,15 @@ class TypeNotFoundError(Exception):
|
|||
Occurs when a type is not found, for example,
|
||||
when trying to read a TLObject with an invalid constructor code.
|
||||
"""
|
||||
def __init__(self, invalid_constructor_id):
|
||||
def __init__(self, invalid_constructor_id, remaining):
|
||||
super().__init__(
|
||||
'Could not find a matching Constructor ID for the TLObject '
|
||||
'that was supposed to be read with ID {}. Most likely, a TLObject '
|
||||
'was trying to be read when it should not be read.'
|
||||
.format(hex(invalid_constructor_id)))
|
||||
'that was supposed to be read with ID {:08x}. Most likely, '
|
||||
'a TLObject was trying to be read when it should not be read. '
|
||||
'Remaining bytes: {!r}'.format(invalid_constructor_id, remaining))
|
||||
|
||||
self.invalid_constructor_id = invalid_constructor_id
|
||||
self.remaining = remaining
|
||||
|
||||
|
||||
class InvalidChecksumError(Exception):
|
||||
|
|
|
@ -97,6 +97,19 @@ class ServerError(RPCError):
|
|||
self.message = message
|
||||
|
||||
|
||||
class BotTimeout(RPCError):
|
||||
"""
|
||||
Clicking the inline buttons of bots that never (or take to long to)
|
||||
call ``answerCallbackQuery`` will result in this "special" RPCError.
|
||||
"""
|
||||
code = -503
|
||||
message = 'Timeout'
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BadMessageError(Exception):
|
||||
"""Occurs when handling a bad_message_notification."""
|
||||
ErrorMessages = {
|
||||
|
@ -142,3 +155,9 @@ class BadMessageError(Exception):
|
|||
'Unknown error code (this should not happen): {}.'.format(code)))
|
||||
|
||||
self.code = code
|
||||
|
||||
|
||||
base_errors = {x.code: x for x in (
|
||||
InvalidDCError, BadRequestError, UnauthorizedError, ForbiddenError,
|
||||
NotFoundError, AuthKeyError, FloodError, ServerError, BotTimeout
|
||||
)}
|
||||
|
|
|
@ -7,7 +7,7 @@ from ..errors import RPCError
|
|||
from ..tl import TLObject, types, functions
|
||||
|
||||
|
||||
def _into_id_set(client, chats):
|
||||
async def _into_id_set(client, chats):
|
||||
"""Helper util to turn the input chat or chats into a set of IDs."""
|
||||
if chats is None:
|
||||
return None
|
||||
|
@ -30,9 +30,9 @@ def _into_id_set(client, chats):
|
|||
# 0x2d45687 == crc32(b'Peer')
|
||||
result.add(utils.get_peer_id(chat))
|
||||
else:
|
||||
chat = client.get_input_entity(chat)
|
||||
chat = await client.get_input_entity(chat)
|
||||
if isinstance(chat, types.InputPeerSelf):
|
||||
chat = client.get_me(input_peer=True)
|
||||
chat = await client.get_me(input_peer=True)
|
||||
result.add(utils.get_peer_id(chat))
|
||||
|
||||
return result
|
||||
|
@ -62,10 +62,10 @@ class EventBuilder(abc.ABC):
|
|||
def build(self, update):
|
||||
"""Builds an event for the given update if possible, or returns None"""
|
||||
|
||||
def resolve(self, client):
|
||||
async def resolve(self, client):
|
||||
"""Helper method to allow event builders to be resolved before usage"""
|
||||
self.chats = _into_id_set(client, self.chats)
|
||||
self._self_id = client.get_me(input_peer=True).user_id
|
||||
self.chats = await _into_id_set(client, self.chats)
|
||||
self._self_id = await client.get_me(input_peer=True).user_id
|
||||
|
||||
def _filter_event(self, event):
|
||||
"""
|
||||
|
|
|
@ -22,7 +22,7 @@ class Raw(EventBuilder):
|
|||
assert all(isinstance(x, type) for x in types)
|
||||
self.types = tuple(types)
|
||||
|
||||
def resolve(self, client):
|
||||
async def resolve(self, client):
|
||||
pass
|
||||
|
||||
def build(self, update):
|
||||
|
|
|
@ -8,6 +8,7 @@ from struct import unpack
|
|||
|
||||
from ..errors import TypeNotFoundError
|
||||
from ..tl.all_tlobjects import tlobjects
|
||||
from ..tl.core import core_objects
|
||||
|
||||
|
||||
class BinaryReader:
|
||||
|
@ -136,9 +137,14 @@ class BinaryReader:
|
|||
elif value == 0x1cb5c415: # Vector
|
||||
return [self.tgread_object() for _ in range(self.read_int())]
|
||||
|
||||
# If there was still no luck, give up
|
||||
self.seek(-4) # Go back
|
||||
raise TypeNotFoundError(constructor_id)
|
||||
clazz = core_objects.get(constructor_id, None)
|
||||
if clazz is None:
|
||||
# If there was still no luck, give up
|
||||
self.seek(-4) # Go back
|
||||
pos = self.tell_position()
|
||||
error = TypeNotFoundError(constructor_id, self.read())
|
||||
self.set_position(pos)
|
||||
raise error
|
||||
|
||||
return clazz.from_reader(self)
|
||||
|
||||
|
|
|
@ -1,31 +1,44 @@
|
|||
"""
|
||||
This module holds a rough implementation of the C# TCP client.
|
||||
|
||||
This class is **not** safe across several tasks since partial reads
|
||||
may be ``await``'ed before being able to return the exact byte count.
|
||||
|
||||
This class is also not concerned about disconnections or retries of
|
||||
any sort, nor any other kind of errors such as connecting twice.
|
||||
"""
|
||||
import asyncio
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from io import BytesIO, BufferedWriter
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
|
||||
CONN_RESET_ERRNOS = {
|
||||
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
||||
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
|
||||
errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
|
||||
errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED,
|
||||
errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN
|
||||
}
|
||||
# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH
|
||||
# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET
|
||||
|
||||
try:
|
||||
import socks
|
||||
except ImportError:
|
||||
socks = None
|
||||
|
||||
MAX_TIMEOUT = 15 # in seconds
|
||||
CONN_RESET_ERRNOS = {
|
||||
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
||||
errno.EINVAL, errno.ENOTCONN
|
||||
}
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TcpClient:
|
||||
"""A simple TCP client to ease the work with sockets and proxies."""
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
||||
|
||||
class SocketClosed(ConnectionError):
|
||||
pass
|
||||
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
|
||||
"""
|
||||
Initializes the TCP client.
|
||||
|
||||
|
@ -34,31 +47,34 @@ class TcpClient:
|
|||
"""
|
||||
self.proxy = proxy
|
||||
self._socket = None
|
||||
self._closing_lock = Lock()
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._closed = asyncio.Event(loop=self._loop)
|
||||
self._closed.set()
|
||||
|
||||
if isinstance(timeout, timedelta):
|
||||
self.timeout = timeout.seconds
|
||||
elif isinstance(timeout, (int, float)):
|
||||
if isinstance(timeout, (int, float)):
|
||||
self.timeout = float(timeout)
|
||||
elif hasattr(timeout, 'seconds'):
|
||||
self.timeout = float(timeout.seconds)
|
||||
else:
|
||||
raise TypeError('Invalid timeout type: {}'.format(type(timeout)))
|
||||
|
||||
def _recreate_socket(self, mode):
|
||||
if self.proxy is None:
|
||||
self._socket = socket.socket(mode, socket.SOCK_STREAM)
|
||||
@staticmethod
|
||||
def _create_socket(mode, proxy):
|
||||
if proxy is None:
|
||||
s = socket.socket(mode, socket.SOCK_STREAM)
|
||||
else:
|
||||
import socks
|
||||
self._socket = socks.socksocket(mode, socket.SOCK_STREAM)
|
||||
if type(self.proxy) is dict:
|
||||
self._socket.set_proxy(**self.proxy)
|
||||
s = socks.socksocket(mode, socket.SOCK_STREAM)
|
||||
if isinstance(proxy, dict):
|
||||
s.set_proxy(**proxy)
|
||||
else: # tuple, list, etc.
|
||||
self._socket.set_proxy(*self.proxy)
|
||||
s.set_proxy(*proxy)
|
||||
s.setblocking(False)
|
||||
return s
|
||||
|
||||
self._socket.settimeout(self.timeout)
|
||||
|
||||
def connect(self, ip, port):
|
||||
async def connect(self, ip, port):
|
||||
"""
|
||||
Tries connecting forever to IP:port unless an OSError is raised.
|
||||
Tries connecting to IP:port unless an OSError is raised.
|
||||
|
||||
:param ip: the IP to connect to.
|
||||
:param port: the port to connect to.
|
||||
|
@ -69,136 +85,162 @@ class TcpClient:
|
|||
else:
|
||||
mode, address = socket.AF_INET, (ip, port)
|
||||
|
||||
timeout = 1
|
||||
while True:
|
||||
try:
|
||||
while not self._socket:
|
||||
self._recreate_socket(mode)
|
||||
|
||||
self._socket.connect(address)
|
||||
break # Successful connection, stop retrying to connect
|
||||
except OSError as e:
|
||||
__log__.info('OSError "%s" raised while connecting', e)
|
||||
# Stop retrying to connect if proxy connection error occurred
|
||||
if socks and isinstance(e, socks.ProxyConnectionError):
|
||||
raise
|
||||
# There are some errors that we know how to handle, and
|
||||
# the loop will allow us to retry
|
||||
if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL,
|
||||
errno.ECONNREFUSED, # Windows-specific follow
|
||||
getattr(errno, 'WSAEACCES', None)):
|
||||
# Bad file descriptor, i.e. socket was closed, set it
|
||||
# to none to recreate it on the next iteration
|
||||
self._socket = None
|
||||
time.sleep(timeout)
|
||||
timeout *= 2
|
||||
if timeout > MAX_TIMEOUT:
|
||||
raise
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_connected(self):
|
||||
"""Determines whether the client is connected or not."""
|
||||
return self._socket is not None and self._socket.fileno() >= 0
|
||||
|
||||
connected = property(fget=_get_connected)
|
||||
|
||||
def close(self):
|
||||
"""Closes the connection."""
|
||||
if self._closing_lock.locked():
|
||||
# Already closing, no need to close again (avoid None.close())
|
||||
return
|
||||
|
||||
with self._closing_lock:
|
||||
try:
|
||||
if self._socket is not None:
|
||||
self._socket.shutdown(socket.SHUT_RDWR)
|
||||
self._socket.close()
|
||||
except OSError:
|
||||
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
||||
finally:
|
||||
self._socket = None
|
||||
|
||||
def write(self, data):
|
||||
"""
|
||||
Writes (sends) the specified bytes to the connected peer.
|
||||
|
||||
:param data: the data to send.
|
||||
"""
|
||||
if self._socket is None:
|
||||
self._raise_connection_reset(None)
|
||||
|
||||
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
|
||||
# The socket timeout is now the maximum total duration to send all data.
|
||||
try:
|
||||
self._socket.sendall(data)
|
||||
except socket.timeout as e:
|
||||
__log__.debug('socket.timeout "%s" while writing data', e)
|
||||
raise TimeoutError() from e
|
||||
except ConnectionError as e:
|
||||
__log__.info('ConnectionError "%s" while writing data', e)
|
||||
self._raise_connection_reset(e)
|
||||
if self._socket is None:
|
||||
self._socket = self._create_socket(mode, self.proxy)
|
||||
|
||||
await asyncio.wait_for(
|
||||
self._loop.sock_connect(self._socket, address),
|
||||
timeout=self.timeout,
|
||||
loop=self._loop
|
||||
)
|
||||
self._closed.clear()
|
||||
except OSError as e:
|
||||
__log__.info('OSError "%s" while writing data', e)
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._raise_connection_reset(e)
|
||||
raise ConnectionResetError() from e
|
||||
else:
|
||||
raise
|
||||
|
||||
def read(self, size):
|
||||
@property
|
||||
def is_connected(self):
|
||||
"""Determines whether the client is connected or not."""
|
||||
return not self._closed.is_set()
|
||||
|
||||
def close(self):
|
||||
"""Closes the connection."""
|
||||
try:
|
||||
if self._socket is not None:
|
||||
if self.is_connected:
|
||||
self._socket.shutdown(socket.SHUT_RDWR)
|
||||
self._socket.close()
|
||||
except OSError:
|
||||
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
||||
finally:
|
||||
self._socket = None
|
||||
self._closed.set()
|
||||
|
||||
async def _wait_timeout_or_close(self, coro):
|
||||
"""
|
||||
Waits for the given coroutine to complete unless
|
||||
the socket is closed or `self.timeout` expires.
|
||||
"""
|
||||
done, running = await asyncio.wait(
|
||||
[coro, self._closed.wait()],
|
||||
timeout=self.timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
loop=self._loop
|
||||
)
|
||||
for r in running:
|
||||
r.cancel()
|
||||
if not self.is_connected:
|
||||
raise self.SocketClosed()
|
||||
if not done:
|
||||
raise asyncio.TimeoutError()
|
||||
return done.pop().result()
|
||||
|
||||
async def write(self, data):
|
||||
"""
|
||||
Writes (sends) the specified bytes to the connected peer.
|
||||
:param data: the data to send.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise ConnectionResetError('Not connected')
|
||||
|
||||
try:
|
||||
await self._wait_timeout_or_close(self.sock_sendall(data))
|
||||
except OSError as e:
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
raise ConnectionResetError() from e
|
||||
else:
|
||||
raise
|
||||
|
||||
async def read(self, size):
|
||||
"""
|
||||
Reads (receives) a whole block of size bytes from the connected peer.
|
||||
|
||||
:param size: the size of the block to be read.
|
||||
:return: the read data with len(data) == size.
|
||||
"""
|
||||
if self._socket is None:
|
||||
self._raise_connection_reset(None)
|
||||
if not self.is_connected:
|
||||
raise ConnectionResetError('Not connected')
|
||||
|
||||
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
|
||||
with BytesIO() as buffer:
|
||||
bytes_left = size
|
||||
while bytes_left != 0:
|
||||
try:
|
||||
partial = self._socket.recv(bytes_left)
|
||||
except socket.timeout as e:
|
||||
# These are somewhat common if the server has nothing
|
||||
# to send to us, so use a lower logging priority.
|
||||
partial = await self._wait_timeout_or_close(
|
||||
self.sock_recv(bytes_left)
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if bytes_left < size:
|
||||
__log__.warning(
|
||||
'socket.timeout "%s" when %d/%d had been received',
|
||||
e, size - bytes_left, size
|
||||
'Timeout when partial %d/%d had been received',
|
||||
size - bytes_left, size
|
||||
)
|
||||
else:
|
||||
__log__.debug(
|
||||
'socket.timeout "%s" while reading data', e
|
||||
)
|
||||
|
||||
raise TimeoutError() from e
|
||||
except ConnectionError as e:
|
||||
__log__.info('ConnectionError "%s" while reading data', e)
|
||||
self._raise_connection_reset(e)
|
||||
raise
|
||||
except OSError as e:
|
||||
if e.errno != errno.EBADF and self._closing_lock.locked():
|
||||
# Ignore bad file descriptor while closing
|
||||
__log__.info('OSError "%s" while reading data', e)
|
||||
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._raise_connection_reset(e)
|
||||
raise ConnectionResetError() from e
|
||||
else:
|
||||
raise
|
||||
|
||||
if len(partial) == 0:
|
||||
self._raise_connection_reset(None)
|
||||
if not partial:
|
||||
raise ConnectionResetError()
|
||||
|
||||
buffer.write(partial)
|
||||
bytes_left -= len(partial)
|
||||
|
||||
# If everything went fine, return the read bytes
|
||||
buffer.flush()
|
||||
return buffer.raw.getvalue()
|
||||
return buffer.getvalue()
|
||||
|
||||
def _raise_connection_reset(self, original):
|
||||
"""Disconnects the client and raises ConnectionResetError."""
|
||||
self.close() # Connection reset -> flag as socket closed
|
||||
raise ConnectionResetError('The server has closed the connection.')\
|
||||
from original
|
||||
# Due to recent https://github.com/python/cpython/pull/4386
|
||||
# Credit to @andr-04 for his original implementation
|
||||
def sock_recv(self, n):
|
||||
fut = self._loop.create_future()
|
||||
self._sock_recv(fut, None, n)
|
||||
return fut
|
||||
|
||||
def _sock_recv(self, fut, registered_fd, n):
|
||||
if registered_fd is not None:
|
||||
self._loop.remove_reader(registered_fd)
|
||||
if fut.cancelled() or self._socket is None:
|
||||
return
|
||||
|
||||
try:
|
||||
data = self._socket.recv(n)
|
||||
except (BlockingIOError, InterruptedError):
|
||||
fd = self._socket.fileno()
|
||||
self._loop.add_reader(fd, self._sock_recv, fut, fd, n)
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
else:
|
||||
fut.set_result(data)
|
||||
|
||||
def sock_sendall(self, data):
|
||||
fut = self._loop.create_future()
|
||||
if data:
|
||||
self._sock_sendall(fut, None, data)
|
||||
else:
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
|
||||
def _sock_sendall(self, fut, registered_fd, data):
|
||||
if registered_fd:
|
||||
self._loop.remove_writer(registered_fd)
|
||||
if fut.cancelled() or self._socket is None:
|
||||
return
|
||||
|
||||
try:
|
||||
n = self._socket.send(data)
|
||||
except (BlockingIOError, InterruptedError):
|
||||
n = 0
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
return
|
||||
|
||||
if n == len(data):
|
||||
fut.set_result(None)
|
||||
else:
|
||||
if n:
|
||||
data = data[n:]
|
||||
fd = self._socket.fileno()
|
||||
self._loop.add_writer(fd, self._sock_sendall, fut, fd, data)
|
||||
|
|
|
@ -1,12 +1,7 @@
|
|||
"""Various helpers not related to the Telegram API itself"""
|
||||
import os
|
||||
import struct
|
||||
from hashlib import sha1, sha256
|
||||
|
||||
from telethon.crypto import AES
|
||||
from telethon.errors import SecurityError
|
||||
from telethon.extensions import BinaryReader
|
||||
|
||||
|
||||
# region Multiple utilities
|
||||
|
||||
|
@ -27,70 +22,6 @@ def ensure_parent_dir_exists(file_path):
|
|||
# region Cryptographic related utils
|
||||
|
||||
|
||||
def pack_message(session, message):
|
||||
"""Packs a message following MtProto 2.0 guidelines"""
|
||||
# See https://core.telegram.org/mtproto/description
|
||||
data = struct.pack('<qq', session.salt, session.id) + bytes(message)
|
||||
padding = os.urandom(-(len(data) + 12) % 16 + 12)
|
||||
|
||||
# Being substr(what, offset, length); x = 0 for client
|
||||
# "msg_key_large = SHA256(substr(auth_key, 88+x, 32) + pt + padding)"
|
||||
msg_key_large = sha256(
|
||||
session.auth_key.key[88:88 + 32] + data + padding).digest()
|
||||
|
||||
# "msg_key = substr (msg_key_large, 8, 16)"
|
||||
msg_key = msg_key_large[8:24]
|
||||
aes_key, aes_iv = calc_key(session.auth_key.key, msg_key, True)
|
||||
|
||||
key_id = struct.pack('<Q', session.auth_key.key_id)
|
||||
return key_id + msg_key + AES.encrypt_ige(data + padding, aes_key, aes_iv)
|
||||
|
||||
|
||||
def unpack_message(session, reader):
|
||||
"""Unpacks a message following MtProto 2.0 guidelines"""
|
||||
# See https://core.telegram.org/mtproto/description
|
||||
if reader.read_long(signed=False) != session.auth_key.key_id:
|
||||
raise SecurityError('Server replied with an invalid auth key')
|
||||
|
||||
msg_key = reader.read(16)
|
||||
aes_key, aes_iv = calc_key(session.auth_key.key, msg_key, False)
|
||||
data = BinaryReader(AES.decrypt_ige(reader.read(), aes_key, aes_iv))
|
||||
|
||||
data.read_long() # remote_salt
|
||||
if data.read_long() != session.id:
|
||||
raise SecurityError('Server replied with a wrong session ID')
|
||||
|
||||
remote_msg_id = data.read_long()
|
||||
remote_sequence = data.read_int()
|
||||
msg_len = data.read_int()
|
||||
message = data.read(msg_len)
|
||||
|
||||
# https://core.telegram.org/mtproto/security_guidelines
|
||||
# Sections "checking sha256 hash" and "message length"
|
||||
if msg_key != sha256(
|
||||
session.auth_key.key[96:96 + 32] + data.get_bytes()).digest()[8:24]:
|
||||
raise SecurityError("Received msg_key doesn't match with expected one")
|
||||
|
||||
return message, remote_msg_id, remote_sequence
|
||||
|
||||
|
||||
def calc_key(auth_key, msg_key, client):
|
||||
"""
|
||||
Calculate the key based on Telegram guidelines
|
||||
for MtProto 2, specifying whether it's the client or not.
|
||||
"""
|
||||
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
||||
x = 0 if client else 8
|
||||
|
||||
sha256a = sha256(msg_key + auth_key[x: x + 36]).digest()
|
||||
sha256b = sha256(auth_key[x + 40:x + 76] + msg_key).digest()
|
||||
|
||||
aes_key = sha256a[:8] + sha256b[8:24] + sha256a[24:32]
|
||||
aes_iv = sha256b[:8] + sha256a[8:24] + sha256b[24:32]
|
||||
|
||||
return aes_key, aes_iv
|
||||
|
||||
|
||||
def generate_key_data_from_nonce(server_nonce, new_nonce):
|
||||
"""Generates the key data corresponding to the given nonce"""
|
||||
server_nonce = server_nonce.to_bytes(16, 'little', signed=True)
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
This module contains several classes regarding network, low level connection
|
||||
with Telegram's servers and the protocol used (TCP full, abridged, etc.).
|
||||
"""
|
||||
from .mtproto_plain_sender import MtProtoPlainSender
|
||||
from .mtprotoplainsender import MTProtoPlainSender
|
||||
from .authenticator import do_authentication
|
||||
from .mtproto_sender import MtProtoSender
|
||||
from .mtprotosender import MTProtoSender
|
||||
from .connection import (
|
||||
ConnectionTcpFull, ConnectionTcpAbridged, ConnectionTcpObfuscated,
|
||||
ConnectionTcpIntermediate
|
||||
|
|
|
@ -14,56 +14,24 @@ from .. import helpers as utils
|
|||
from ..crypto import AES, AuthKey, Factorization, rsa
|
||||
from ..errors import SecurityError
|
||||
from ..extensions import BinaryReader
|
||||
from ..network import MtProtoPlainSender
|
||||
from ..tl.functions import (
|
||||
ReqPqMultiRequest, ReqDHParamsRequest, SetClientDHParamsRequest
|
||||
)
|
||||
|
||||
|
||||
def do_authentication(connection, retries=5):
|
||||
"""
|
||||
Performs the authentication steps on the given connection.
|
||||
Raises an error if all attempts fail.
|
||||
|
||||
:param connection: the connection to be used (must be connected).
|
||||
:param retries: how many times should we retry on failure.
|
||||
:return:
|
||||
"""
|
||||
if not retries or retries < 0:
|
||||
retries = 1
|
||||
|
||||
last_error = None
|
||||
while retries:
|
||||
try:
|
||||
return _do_authentication(connection)
|
||||
except (SecurityError, AssertionError, NotImplementedError) as e:
|
||||
last_error = e
|
||||
retries -= 1
|
||||
raise last_error
|
||||
|
||||
|
||||
def _do_authentication(connection):
|
||||
async def do_authentication(sender):
|
||||
"""
|
||||
Executes the authentication process with the Telegram servers.
|
||||
|
||||
:param connection: the connection to be used (must be connected).
|
||||
:param sender: a connected `MTProtoPlainSender`.
|
||||
:return: returns a (authorization key, time offset) tuple.
|
||||
"""
|
||||
sender = MtProtoPlainSender(connection)
|
||||
|
||||
# Step 1 sending: PQ Request, endianness doesn't matter since it's random
|
||||
req_pq_request = ReqPqMultiRequest(
|
||||
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
|
||||
)
|
||||
sender.send(bytes(req_pq_request))
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
req_pq_request.on_response(reader)
|
||||
nonce = int.from_bytes(os.urandom(16), 'big', signed=True)
|
||||
res_pq = await sender.send(ReqPqMultiRequest(nonce))
|
||||
assert isinstance(res_pq, ResPQ)
|
||||
|
||||
res_pq = req_pq_request.result
|
||||
if not isinstance(res_pq, ResPQ):
|
||||
raise AssertionError(res_pq)
|
||||
|
||||
if res_pq.nonce != req_pq_request.nonce:
|
||||
if res_pq.nonce != nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
|
||||
pq = get_int(res_pq.pq)
|
||||
|
@ -96,25 +64,15 @@ def _do_authentication(connection):
|
|||
)
|
||||
)
|
||||
|
||||
req_dh_params = ReqDHParamsRequest(
|
||||
server_dh_params = await sender.send(ReqDHParamsRequest(
|
||||
nonce=res_pq.nonce,
|
||||
server_nonce=res_pq.server_nonce,
|
||||
p=p, q=q,
|
||||
public_key_fingerprint=target_fingerprint,
|
||||
encrypted_data=cipher_text
|
||||
)
|
||||
sender.send(bytes(req_dh_params))
|
||||
))
|
||||
|
||||
# Step 2 response: DH Exchange
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
req_dh_params.on_response(reader)
|
||||
|
||||
server_dh_params = req_dh_params.result
|
||||
if isinstance(server_dh_params, ServerDHParamsFail):
|
||||
raise SecurityError('Server DH params fail: TODO')
|
||||
|
||||
if not isinstance(server_dh_params, ServerDHParamsOk):
|
||||
raise AssertionError(server_dh_params)
|
||||
assert isinstance(server_dh_params, (ServerDHParamsOk, ServerDHParamsFail))
|
||||
|
||||
if server_dh_params.nonce != res_pq.nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
|
@ -122,6 +80,16 @@ def _do_authentication(connection):
|
|||
if server_dh_params.server_nonce != res_pq.server_nonce:
|
||||
raise SecurityError('Invalid server nonce from server')
|
||||
|
||||
if isinstance(server_dh_params, ServerDHParamsFail):
|
||||
nnh = int.from_bytes(
|
||||
sha1(new_nonce.to_bytes(32, 'little', signed=True)).digest()[4:20],
|
||||
'little', signed=True
|
||||
)
|
||||
if server_dh_params.new_nonce_hash != nnh:
|
||||
raise SecurityError('Invalid DH fail nonce from server')
|
||||
|
||||
assert isinstance(server_dh_params, ServerDHParamsOk)
|
||||
|
||||
# Step 3 sending: Complete DH Exchange
|
||||
key, iv = utils.generate_key_data_from_nonce(
|
||||
res_pq.server_nonce, new_nonce
|
||||
|
@ -137,8 +105,7 @@ def _do_authentication(connection):
|
|||
with BinaryReader(plain_text_answer) as reader:
|
||||
reader.read(20) # hash sum
|
||||
server_dh_inner = reader.tgread_object()
|
||||
if not isinstance(server_dh_inner, ServerDHInnerData):
|
||||
raise AssertionError(server_dh_inner)
|
||||
assert isinstance(server_dh_inner, ServerDHInnerData)
|
||||
|
||||
if server_dh_inner.nonce != res_pq.nonce:
|
||||
raise SecurityError('Invalid nonce in encrypted answer')
|
||||
|
@ -168,43 +135,31 @@ def _do_authentication(connection):
|
|||
client_dh_encrypted = AES.encrypt_ige(client_dh_inner_hashed, key, iv)
|
||||
|
||||
# Prepare Set client DH params
|
||||
set_client_dh = SetClientDHParamsRequest(
|
||||
dh_gen = await sender.send(SetClientDHParamsRequest(
|
||||
nonce=res_pq.nonce,
|
||||
server_nonce=res_pq.server_nonce,
|
||||
encrypted_data=client_dh_encrypted,
|
||||
)
|
||||
sender.send(bytes(set_client_dh))
|
||||
))
|
||||
|
||||
# Step 3 response: Complete DH Exchange
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
set_client_dh.on_response(reader)
|
||||
nonce_types = (DhGenOk, DhGenRetry, DhGenFail)
|
||||
assert isinstance(dh_gen, nonce_types)
|
||||
name = dh_gen.__class__.__name__
|
||||
if dh_gen.nonce != res_pq.nonce:
|
||||
raise SecurityError('Invalid {} nonce from server'.format(name))
|
||||
|
||||
dh_gen = set_client_dh.result
|
||||
if isinstance(dh_gen, DhGenOk):
|
||||
if dh_gen.nonce != res_pq.nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
if dh_gen.server_nonce != res_pq.server_nonce:
|
||||
raise SecurityError('Invalid {} server nonce from server'.format(name))
|
||||
|
||||
if dh_gen.server_nonce != res_pq.server_nonce:
|
||||
raise SecurityError('Invalid server nonce from server')
|
||||
auth_key = AuthKey(rsa.get_byte_array(gab))
|
||||
nonce_number = 1 + nonce_types.index(type(dh_gen))
|
||||
new_nonce_hash = auth_key.calc_new_nonce_hash(new_nonce, nonce_number)
|
||||
|
||||
auth_key = AuthKey(rsa.get_byte_array(gab))
|
||||
new_nonce_hash = int.from_bytes(
|
||||
auth_key.calc_new_nonce_hash(new_nonce, 1), 'little', signed=True
|
||||
)
|
||||
dh_hash = getattr(dh_gen, 'new_nonce_hash{}'.format(nonce_number))
|
||||
if dh_hash != new_nonce_hash:
|
||||
raise SecurityError('Invalid new nonce hash')
|
||||
|
||||
if dh_gen.new_nonce_hash1 != new_nonce_hash:
|
||||
raise SecurityError('Invalid new nonce hash')
|
||||
|
||||
return auth_key, time_offset
|
||||
|
||||
elif isinstance(dh_gen, DhGenRetry):
|
||||
raise NotImplementedError('DhGenRetry')
|
||||
|
||||
elif isinstance(dh_gen, DhGenFail):
|
||||
raise NotImplementedError('DhGenFail')
|
||||
|
||||
else:
|
||||
raise NotImplementedError('DH Gen unknown: {}'.format(dh_gen))
|
||||
assert isinstance(dh_gen, DhGenOk)
|
||||
return auth_key, time_offset
|
||||
|
||||
|
||||
def get_int(byte_array, signed=True):
|
||||
|
|
|
@ -1,5 +1,14 @@
|
|||
"""
|
||||
This module holds the abstract `Connection` class.
|
||||
|
||||
The `Connection.send` and `Connection.recv` methods need **not** to be
|
||||
safe across several tasks and may use any amount of ``await`` keywords.
|
||||
|
||||
The code using these `Connection`'s should be responsible for using
|
||||
an ``async with asyncio.Lock:`` block when calling said methods.
|
||||
|
||||
Said subclasses need not to worry about reconnecting either, and
|
||||
should let the errors propagate instead.
|
||||
"""
|
||||
import abc
|
||||
from datetime import timedelta
|
||||
|
@ -23,7 +32,7 @@ class Connection(abc.ABC):
|
|||
self._timeout = timeout
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self, ip, port):
|
||||
async def connect(self, ip, port):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -41,7 +50,7 @@ class Connection(abc.ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def close(self):
|
||||
async def close(self):
|
||||
"""Closes the connection."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -51,11 +60,11 @@ class Connection(abc.ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def recv(self):
|
||||
async def recv(self):
|
||||
"""Receives and unpacks a message"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, message):
|
||||
async def send(self, message):
|
||||
"""Encapsulates and sends the given message"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -20,9 +20,9 @@ class ConnectionTcpFull(Connection):
|
|||
self.read = self.conn.read
|
||||
self.write = self.conn.write
|
||||
|
||||
def connect(self, ip, port):
|
||||
async def connect(self, ip, port):
|
||||
try:
|
||||
self.conn.connect(ip, port)
|
||||
await self.conn.connect(ip, port)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EISCONN:
|
||||
return # Already connected, no need to re-set everything up
|
||||
|
@ -35,19 +35,20 @@ class ConnectionTcpFull(Connection):
|
|||
return self.conn.timeout
|
||||
|
||||
def is_connected(self):
|
||||
return self.conn.connected
|
||||
return self.conn.is_connected
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def clone(self):
|
||||
return ConnectionTcpFull(self._proxy, self._timeout)
|
||||
|
||||
def recv(self):
|
||||
packet_len_seq = self.read(8) # 4 and 4
|
||||
async def recv(self):
|
||||
packet_len_seq = await self.read(8) # 4 and 4
|
||||
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
||||
body = self.read(packet_len - 12)
|
||||
checksum = struct.unpack('<I', self.read(4))[0]
|
||||
body = await self.read(packet_len - 8)
|
||||
checksum = struct.unpack('<I', body[-4:])[0]
|
||||
body = body[:-4]
|
||||
|
||||
valid_checksum = crc32(packet_len_seq + body)
|
||||
if checksum != valid_checksum:
|
||||
|
@ -55,11 +56,11 @@ class ConnectionTcpFull(Connection):
|
|||
|
||||
return body
|
||||
|
||||
def send(self, message):
|
||||
async def send(self, message):
|
||||
# https://core.telegram.org/mtproto#tcp-transport
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
length = len(message) + 12
|
||||
data = struct.pack('<ii', length, self._send_counter) + message
|
||||
crc = struct.pack('<I', crc32(data))
|
||||
self._send_counter += 1
|
||||
self.write(data + crc)
|
||||
await self.write(data + crc)
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
"""
|
||||
This module contains the class used to communicate with Telegram's servers
|
||||
in plain text, when no authorization key has been created yet.
|
||||
"""
|
||||
import struct
|
||||
import time
|
||||
|
||||
from ..errors import BrokenAuthKeyError
|
||||
from ..extensions import BinaryReader
|
||||
|
||||
|
||||
class MtProtoPlainSender:
|
||||
"""
|
||||
MTProto Mobile Protocol plain sender
|
||||
(https://core.telegram.org/mtproto/description#unencrypted-messages)
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
"""
|
||||
Initializes the MTProto plain sender.
|
||||
|
||||
:param connection: the Connection to be used.
|
||||
"""
|
||||
self._sequence = 0
|
||||
self._time_offset = 0
|
||||
self._last_msg_id = 0
|
||||
self._connection = connection
|
||||
|
||||
def connect(self):
|
||||
"""Connects to Telegram's servers."""
|
||||
self._connection.connect()
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnects from Telegram's servers."""
|
||||
self._connection.close()
|
||||
|
||||
def send(self, data):
|
||||
"""
|
||||
Sends a plain packet (auth_key_id = 0) containing the
|
||||
given message body (data).
|
||||
|
||||
:param data: the data to be sent.
|
||||
"""
|
||||
self._connection.send(
|
||||
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
|
||||
)
|
||||
|
||||
def receive(self):
|
||||
"""
|
||||
Receives a plain packet from the network.
|
||||
|
||||
:return: the response body.
|
||||
"""
|
||||
body = self._connection.recv()
|
||||
if body == b'l\xfe\xff\xff': # -404 little endian signed
|
||||
# Broken authorization, must reset the auth key
|
||||
raise BrokenAuthKeyError()
|
||||
|
||||
with BinaryReader(body) as reader:
|
||||
reader.read_long() # auth_key_id
|
||||
reader.read_long() # msg_id
|
||||
message_length = reader.read_int()
|
||||
|
||||
response = reader.read(message_length)
|
||||
return response
|
||||
|
||||
def _get_new_msg_id(self):
|
||||
"""Generates a new message ID based on the current time since epoch."""
|
||||
# See core.telegram.org/mtproto/description#message-identifier-msg-id
|
||||
now = time.time()
|
||||
nanoseconds = int((now - int(now)) * 1e+9)
|
||||
# "message identifiers are divisible by 4"
|
||||
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
|
||||
if self._last_msg_id >= new_msg_id:
|
||||
new_msg_id = self._last_msg_id + 4
|
||||
|
||||
self._last_msg_id = new_msg_id
|
||||
return new_msg_id
|
|
@ -1,590 +0,0 @@
|
|||
"""
|
||||
This module contains the class used to communicate with Telegram's servers
|
||||
encrypting every packet, and relies on a valid AuthKey in the used Session.
|
||||
"""
|
||||
import logging
|
||||
from threading import Lock
|
||||
|
||||
from .. import helpers, utils
|
||||
from ..errors import (
|
||||
BadMessageError, InvalidChecksumError, BrokenAuthKeyError,
|
||||
rpc_message_to_error
|
||||
)
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl import TLMessage, MessageContainer, GzipPacked
|
||||
from ..tl.all_tlobjects import tlobjects
|
||||
from ..tl.functions import InvokeAfterMsgRequest
|
||||
from ..tl.functions.auth import LogOutRequest
|
||||
from ..tl.types import (
|
||||
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
|
||||
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
|
||||
)
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MtProtoSender:
|
||||
"""
|
||||
MTProto Mobile Protocol sender
|
||||
(https://core.telegram.org/mtproto/description).
|
||||
|
||||
Note that this class is not thread-safe, and calling send/receive
|
||||
from two or more threads at the same time is undefined behaviour.
|
||||
Rationale:
|
||||
a new connection should be spawned to send/receive requests
|
||||
in parallel, so thread-safety (hence locking) isn't needed.
|
||||
"""
|
||||
|
||||
def __init__(self, session, connection):
|
||||
"""
|
||||
Initializes a new MTProto sender.
|
||||
|
||||
:param session:
|
||||
the Session to be used with this sender. Must contain the IP and
|
||||
port of the server, salt, ID, and AuthKey,
|
||||
:param connection:
|
||||
the Connection to be used.
|
||||
"""
|
||||
self.session = session
|
||||
self.connection = connection
|
||||
|
||||
# Message IDs that need confirmation
|
||||
self._need_confirmation = set()
|
||||
|
||||
# Requests (as msg_id: Message) sent waiting to be received
|
||||
self._pending_receive = {}
|
||||
|
||||
# Multithreading
|
||||
self._send_lock = Lock()
|
||||
|
||||
# If we're invoking something from an update thread but we're also
|
||||
# receiving other request from the main thread (e.g. an update arrives
|
||||
# and we need to process it) we must ensure that only one is calling
|
||||
# receive at a given moment, since the receive step is fragile.
|
||||
self._recv_lock = Lock()
|
||||
|
||||
def connect(self):
|
||||
"""Connects to the server."""
|
||||
self.connection.connect(self.session.server_address, self.session.port)
|
||||
|
||||
def is_connected(self):
|
||||
"""
|
||||
Determines whether the sender is connected or not.
|
||||
|
||||
:return: true if the sender is connected.
|
||||
"""
|
||||
return self.connection.is_connected()
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnects from the server."""
|
||||
__log__.info('Disconnecting MtProtoSender...')
|
||||
self.connection.close()
|
||||
self._clear_all_pending()
|
||||
|
||||
# region Send and receive
|
||||
|
||||
def send(self, requests, ordered=False):
|
||||
"""
|
||||
Sends the specified TLObject(s) (which must be requests),
|
||||
and acknowledging any message which needed confirmation.
|
||||
|
||||
:param requests: the requests to be sent.
|
||||
:param ordered: whether the requests should be invoked in the
|
||||
order in which they appear or they can be executed
|
||||
in arbitrary order in the server.
|
||||
"""
|
||||
if not utils.is_list_like(requests):
|
||||
requests = (requests,)
|
||||
|
||||
if ordered:
|
||||
requests = iter(requests)
|
||||
messages = [TLMessage(self.session, next(requests))]
|
||||
for r in requests:
|
||||
messages.append(TLMessage(self.session, r,
|
||||
after_id=messages[-1].msg_id))
|
||||
else:
|
||||
messages = [TLMessage(self.session, r) for r in requests]
|
||||
|
||||
self._pending_receive.update({m.msg_id: m for m in messages})
|
||||
|
||||
__log__.debug('Sending requests with IDs: %s', ', '.join(
|
||||
'{}: {}'.format(m.request.__class__.__name__, m.msg_id)
|
||||
for m in messages
|
||||
))
|
||||
|
||||
# Pack everything in the same container if we need to send AckRequests
|
||||
if self._need_confirmation:
|
||||
messages.append(
|
||||
TLMessage(self.session, MsgsAck(list(self._need_confirmation)))
|
||||
)
|
||||
self._need_confirmation.clear()
|
||||
|
||||
if len(messages) == 1:
|
||||
message = messages[0]
|
||||
else:
|
||||
message = TLMessage(self.session, MessageContainer(messages))
|
||||
# On bad_msg_salt errors, Telegram will reply with the ID of
|
||||
# the container and not the requests it contains, so in case
|
||||
# this happens we need to know to which container they belong.
|
||||
for m in messages:
|
||||
m.container_msg_id = message.msg_id
|
||||
|
||||
self._send_message(message)
|
||||
|
||||
def _send_acknowledge(self, msg_id):
|
||||
"""Sends a message acknowledge for the given msg_id."""
|
||||
self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
||||
|
||||
def receive(self, update_state):
|
||||
"""
|
||||
Receives a single message from the connected endpoint.
|
||||
|
||||
This method returns nothing, and will only affect other parts
|
||||
of the MtProtoSender such as the updates callback being fired
|
||||
or a pending request being confirmed.
|
||||
|
||||
Any unhandled object (likely updates) will be passed to
|
||||
update_state.process(TLObject).
|
||||
|
||||
:param update_state:
|
||||
the UpdateState that will process all the received
|
||||
Update and Updates objects.
|
||||
"""
|
||||
if self._recv_lock.locked():
|
||||
with self._recv_lock:
|
||||
# Don't busy wait, acquire it but return because there's
|
||||
# already a receive running and we don't want another one.
|
||||
# It would lock until Telegram sent another update even if
|
||||
# the current receive already received the expected response.
|
||||
return
|
||||
|
||||
try:
|
||||
with self._recv_lock:
|
||||
body = self.connection.recv()
|
||||
except (BufferError, InvalidChecksumError):
|
||||
# TODO BufferError, we should spot the cause...
|
||||
# "No more bytes left"; something wrong happened, clear
|
||||
# everything to be on the safe side, or:
|
||||
#
|
||||
# "This packet should be skipped"; since this may have
|
||||
# been a result for a request, invalidate every request
|
||||
# and just re-invoke them to avoid problems
|
||||
__log__.exception('Error while receiving server response. '
|
||||
'%d pending request(s) will be ignored',
|
||||
len(self._pending_receive))
|
||||
self._clear_all_pending()
|
||||
return
|
||||
|
||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||
with BinaryReader(message) as reader:
|
||||
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Low level processing
|
||||
|
||||
def _send_message(self, message):
|
||||
"""
|
||||
Sends the given encrypted through the network.
|
||||
|
||||
:param message: the TLMessage to be sent.
|
||||
"""
|
||||
with self._send_lock:
|
||||
self.connection.send(helpers.pack_message(self.session, message))
|
||||
|
||||
def _decode_msg(self, body):
|
||||
"""
|
||||
Decodes the body of the payload received from the network.
|
||||
|
||||
:param body: the body to be decoded.
|
||||
:return: a tuple of (decoded message, remote message id, remote seq).
|
||||
"""
|
||||
if len(body) < 8:
|
||||
if body == b'l\xfe\xff\xff':
|
||||
raise BrokenAuthKeyError()
|
||||
else:
|
||||
raise BufferError("Can't decode packet ({})".format(body))
|
||||
|
||||
with BinaryReader(body) as reader:
|
||||
return helpers.unpack_message(self.session, reader)
|
||||
|
||||
def _process_msg(self, msg_id, sequence, reader, state):
|
||||
"""
|
||||
Processes the message read from the network inside reader.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the BinaryReader that contains the message.
|
||||
:param state: the current UpdateState.
|
||||
:return: true if the message was handled correctly, false otherwise.
|
||||
"""
|
||||
# TODO Check salt, session_id and sequence_number
|
||||
self._need_confirmation.add(msg_id)
|
||||
|
||||
code = reader.read_int(signed=False)
|
||||
reader.seek(-4)
|
||||
|
||||
# These are a bit of special case, not yet generated by the code gen
|
||||
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
|
||||
__log__.debug('Processing Remote Procedure Call result')
|
||||
return self._handle_rpc_result(msg_id, sequence, reader)
|
||||
|
||||
if code == MessageContainer.CONSTRUCTOR_ID:
|
||||
__log__.debug('Processing container result')
|
||||
return self._handle_container(msg_id, sequence, reader, state)
|
||||
|
||||
if code == GzipPacked.CONSTRUCTOR_ID:
|
||||
__log__.debug('Processing gzipped result')
|
||||
return self._handle_gzip_packed(msg_id, sequence, reader, state)
|
||||
|
||||
if code not in tlobjects:
|
||||
__log__.warning(
|
||||
'Unknown message with ID %d, data left in the buffer %s',
|
||||
hex(code), repr(reader.get_bytes()[reader.tell_position():])
|
||||
)
|
||||
return False
|
||||
|
||||
obj = reader.tgread_object()
|
||||
__log__.debug('Processing %s result', type(obj).__name__)
|
||||
|
||||
if isinstance(obj, Pong):
|
||||
return self._handle_pong(msg_id, sequence, obj)
|
||||
|
||||
if isinstance(obj, BadServerSalt):
|
||||
return self._handle_bad_server_salt(msg_id, sequence, obj)
|
||||
|
||||
if isinstance(obj, BadMsgNotification):
|
||||
return self._handle_bad_msg_notification(msg_id, sequence, obj)
|
||||
|
||||
if isinstance(obj, MsgDetailedInfo):
|
||||
return self._handle_msg_detailed_info(msg_id, sequence, obj)
|
||||
|
||||
if isinstance(obj, MsgNewDetailedInfo):
|
||||
return self._handle_msg_new_detailed_info(msg_id, sequence, obj)
|
||||
|
||||
if isinstance(obj, NewSessionCreated):
|
||||
return self._handle_new_session_created(msg_id, sequence, obj)
|
||||
|
||||
if isinstance(obj, MsgsAck): # may handle the request we wanted
|
||||
# Ignore every ack request *unless* when logging out, when it's
|
||||
# when it seems to only make sense. We also need to set a non-None
|
||||
# result since Telegram doesn't send the response for these.
|
||||
for msg_id in obj.msg_ids:
|
||||
r = self._pop_request_of_type(msg_id, LogOutRequest)
|
||||
if r:
|
||||
r.result = True # Telegram won't send this value
|
||||
r.confirm_received.set()
|
||||
__log__.debug('Confirmed %s through ack', type(r).__name__)
|
||||
|
||||
return True
|
||||
|
||||
if isinstance(obj, FutureSalts):
|
||||
r = self._pop_request(obj.req_msg_id)
|
||||
if r:
|
||||
r.result = obj
|
||||
r.confirm_received.set()
|
||||
__log__.debug('Confirmed %s through salt', type(r).__name__)
|
||||
|
||||
# If the object isn't any of the above, then it should be an Update.
|
||||
self.session.process_entities(obj)
|
||||
if state:
|
||||
state.process(obj)
|
||||
|
||||
return True
|
||||
|
||||
# endregion
|
||||
|
||||
# region Message handling
|
||||
|
||||
def _pop_request(self, msg_id):
|
||||
"""
|
||||
Pops a pending **request** from self._pending_receive.
|
||||
|
||||
:param msg_id: the ID of the message that belongs to the request.
|
||||
:return: the request, or None if it wasn't found.
|
||||
"""
|
||||
message = self._pending_receive.pop(msg_id, None)
|
||||
if message:
|
||||
return message.request
|
||||
|
||||
def _pop_request_of_type(self, msg_id, t):
|
||||
"""
|
||||
Pops a pending **request** from self._pending_receive.
|
||||
|
||||
:param msg_id: the ID of the message that belongs to the request.
|
||||
:param t: the type of the desired request.
|
||||
:return: the request matching the type t, or None if it wasn't found.
|
||||
"""
|
||||
message = self._pending_receive.get(msg_id, None)
|
||||
if message and isinstance(message.request, t):
|
||||
return self._pending_receive.pop(msg_id).request
|
||||
|
||||
def _pop_requests_of_container(self, container_msg_id):
|
||||
"""
|
||||
Pops pending **requests** from self._pending_receive.
|
||||
|
||||
:param container_msg_id: the ID of the container.
|
||||
:return: the requests that belong to the given container. May be empty.
|
||||
"""
|
||||
msgs = [msg for msg in self._pending_receive.values()
|
||||
if msg.container_msg_id == container_msg_id]
|
||||
|
||||
requests = [msg.request for msg in msgs]
|
||||
for msg in msgs:
|
||||
self._pending_receive.pop(msg.msg_id, None)
|
||||
return requests
|
||||
|
||||
def _clear_all_pending(self):
|
||||
"""
|
||||
Clears all pending requests, and flags them all as received.
|
||||
"""
|
||||
for r in self._pending_receive.values():
|
||||
r.request.confirm_received.set()
|
||||
__log__.info('Abruptly confirming %s', type(r).__name__)
|
||||
self._pending_receive.clear()
|
||||
|
||||
def _resend_request(self, msg_id):
|
||||
"""
|
||||
Re-sends the request that belongs to a certain msg_id. This may
|
||||
also be the msg_id of a container if they were sent in one.
|
||||
|
||||
:param msg_id: the ID of the request to be resent.
|
||||
"""
|
||||
request = self._pop_request(msg_id)
|
||||
if request:
|
||||
return self.send(request)
|
||||
requests = self._pop_requests_of_container(msg_id)
|
||||
if requests:
|
||||
return self.send(*requests)
|
||||
|
||||
def _handle_pong(self, msg_id, sequence, pong):
|
||||
"""
|
||||
Handles a Pong response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the Pong.
|
||||
:return: true, as it always succeeds.
|
||||
"""
|
||||
request = self._pop_request(pong.msg_id)
|
||||
if request:
|
||||
request.result = pong
|
||||
request.confirm_received.set()
|
||||
__log__.debug('Confirmed %s through pong', type(request).__name__)
|
||||
|
||||
return True
|
||||
|
||||
def _handle_container(self, msg_id, sequence, reader, state):
|
||||
"""
|
||||
Handles a MessageContainer response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the MessageContainer.
|
||||
:return: true, as it always succeeds.
|
||||
"""
|
||||
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
|
||||
begin_position = reader.tell_position()
|
||||
|
||||
# Note that this code is IMPORTANT for skipping RPC results of
|
||||
# lost requests (i.e., ones from the previous connection session)
|
||||
try:
|
||||
if not self._process_msg(inner_msg_id, sequence, reader, state):
|
||||
reader.set_position(begin_position + inner_len)
|
||||
except:
|
||||
# If any error is raised, something went wrong; skip the packet
|
||||
reader.set_position(begin_position + inner_len)
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
def _handle_bad_server_salt(self, msg_id, sequence, bad_salt):
|
||||
"""
|
||||
Handles a BadServerSalt response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the BadServerSalt.
|
||||
:return: true, as it always succeeds.
|
||||
"""
|
||||
self.session.salt = bad_salt.new_server_salt
|
||||
self.session.save()
|
||||
|
||||
# "the bad_server_salt response is received with the
|
||||
# correct salt, and the message is to be re-sent with it"
|
||||
self._resend_request(bad_salt.bad_msg_id)
|
||||
return True
|
||||
|
||||
def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
|
||||
"""
|
||||
Handles a BadMessageError response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the BadMessageError.
|
||||
:return: true, as it always succeeds.
|
||||
"""
|
||||
error = BadMessageError(bad_msg.error_code)
|
||||
__log__.warning('Read bad msg notification %s: %s', bad_msg, error)
|
||||
if bad_msg.error_code in (16, 17):
|
||||
# sent msg_id too low or too high (respectively).
|
||||
# Use the current msg_id to determine the right time offset.
|
||||
self.session.update_time_offset(correct_msg_id=msg_id)
|
||||
__log__.info('Attempting to use the correct time offset')
|
||||
self._resend_request(bad_msg.bad_msg_id)
|
||||
return True
|
||||
elif bad_msg.error_code == 32:
|
||||
# msg_seqno too low, so just pump it up by some "large" amount
|
||||
# TODO A better fix would be to start with a new fresh session ID
|
||||
self.session.sequence += 64
|
||||
__log__.info('Attempting to set the right higher sequence')
|
||||
self._resend_request(bad_msg.bad_msg_id)
|
||||
return True
|
||||
elif bad_msg.error_code == 33:
|
||||
# msg_seqno too high never seems to happen but just in case
|
||||
self.session.sequence -= 16
|
||||
__log__.info('Attempting to set the right lower sequence')
|
||||
self._resend_request(bad_msg.bad_msg_id)
|
||||
return True
|
||||
else:
|
||||
raise error
|
||||
|
||||
def _handle_msg_detailed_info(self, msg_id, sequence, msg_new):
|
||||
"""
|
||||
Handles a MsgDetailedInfo response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the MsgDetailedInfo.
|
||||
:return: true, as it always succeeds.
|
||||
"""
|
||||
# TODO For now, simply ack msg_new.answer_msg_id
|
||||
# Relevant tdesktop source code: https://goo.gl/VvpCC6
|
||||
self._send_acknowledge(msg_new.answer_msg_id)
|
||||
return True
|
||||
|
||||
def _handle_msg_new_detailed_info(self, msg_id, sequence, msg_new):
|
||||
"""
|
||||
Handles a MsgNewDetailedInfo response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the MsgNewDetailedInfo.
|
||||
:return: true, as it always succeeds.
|
||||
"""
|
||||
# TODO For now, simply ack msg_new.answer_msg_id
|
||||
# Relevant tdesktop source code: https://goo.gl/G7DPsR
|
||||
self._send_acknowledge(msg_new.answer_msg_id)
|
||||
return True
|
||||
|
||||
def _handle_new_session_created(self, msg_id, sequence, new_session):
|
||||
"""
|
||||
Handles a NewSessionCreated response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the NewSessionCreated.
|
||||
:return: true, as it always succeeds.
|
||||
"""
|
||||
self.session.salt = new_session.server_salt
|
||||
# TODO https://goo.gl/LMyN7A
|
||||
return True
|
||||
|
||||
def _handle_rpc_result(self, msg_id, sequence, reader):
|
||||
"""
|
||||
Handles a RPCResult response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the RPCResult.
|
||||
:return: true if the request ID to which this result belongs is found,
|
||||
false otherwise (meaning nothing was read).
|
||||
"""
|
||||
reader.read_int(signed=False) # code
|
||||
request_id = reader.read_long()
|
||||
inner_code = reader.read_int(signed=False)
|
||||
reader.seek(-4)
|
||||
|
||||
__log__.debug('Received response for request with ID %d', request_id)
|
||||
request = self._pop_request(request_id)
|
||||
|
||||
if inner_code == 0x2144ca19: # RPC Error
|
||||
reader.seek(4)
|
||||
if self.session.report_errors and request:
|
||||
error = rpc_message_to_error(
|
||||
reader.read_int(), reader.tgread_string(),
|
||||
report_method=type(request).CONSTRUCTOR_ID
|
||||
)
|
||||
else:
|
||||
error = rpc_message_to_error(
|
||||
reader.read_int(), reader.tgread_string()
|
||||
)
|
||||
|
||||
# Acknowledge that we received the error
|
||||
self._send_acknowledge(request_id)
|
||||
|
||||
if request:
|
||||
request.rpc_error = error
|
||||
request.confirm_received.set()
|
||||
|
||||
__log__.debug('Confirmed %s through error %s',
|
||||
type(request).__name__, error)
|
||||
# else TODO Where should this error be reported?
|
||||
# Read may be async. Can an error not-belong to a request?
|
||||
return True # All contents were read okay
|
||||
|
||||
elif request:
|
||||
if inner_code == GzipPacked.CONSTRUCTOR_ID:
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
request.on_response(compressed_reader)
|
||||
else:
|
||||
request.on_response(reader)
|
||||
|
||||
self.session.process_entities(request.result)
|
||||
request.confirm_received.set()
|
||||
__log__.debug(
|
||||
'Confirmed %s through normal result %s',
|
||||
type(request).__name__, type(request.result).__name__
|
||||
)
|
||||
return True
|
||||
|
||||
# If it's really a result for RPC from previous connection
|
||||
# session, it will be skipped by the handle_container().
|
||||
# For some reason this also seems to happen when downloading
|
||||
# photos, where the server responds with FileJpeg().
|
||||
def _try_read(r):
|
||||
try:
|
||||
return r.tgread_object()
|
||||
except Exception as e:
|
||||
return '(failed to read: {})'.format(e)
|
||||
|
||||
if inner_code == GzipPacked.CONSTRUCTOR_ID:
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
obj = _try_read(compressed_reader)
|
||||
else:
|
||||
obj = _try_read(reader)
|
||||
|
||||
__log__.warning(
|
||||
'Lost request (ID %d) with code %s will be skipped, contents: %s',
|
||||
request_id, hex(inner_code), obj
|
||||
)
|
||||
return False
|
||||
|
||||
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
"""
|
||||
Handles a GzipPacked response.
|
||||
|
||||
:param msg_id: the ID of the message.
|
||||
:param sequence: the sequence of the message.
|
||||
:param reader: the reader containing the GzipPacked.
|
||||
:return: the result of processing the packed message.
|
||||
"""
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
# We are reentering process_msg, which seemingly the same msg_id
|
||||
# to the self._need_confirmation set. Remove it from there first
|
||||
# to avoid any future conflicts (i.e. if we "ignore" messages
|
||||
# that we are already aware of, see 1a91c02 and old 63dfb1e)
|
||||
self._need_confirmation -= {msg_id}
|
||||
return self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
|
||||
# endregion
|
46
telethon/network/mtprotoplainsender.py
Normal file
46
telethon/network/mtprotoplainsender.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
"""
|
||||
This module contains the class used to communicate with Telegram's servers
|
||||
in plain text, when no authorization key has been created yet.
|
||||
"""
|
||||
import struct
|
||||
|
||||
from .mtprotostate import MTProtoState
|
||||
from ..errors import BrokenAuthKeyError
|
||||
from ..extensions import BinaryReader
|
||||
|
||||
|
||||
class MTProtoPlainSender:
|
||||
"""
|
||||
MTProto Mobile Protocol plain sender
|
||||
(https://core.telegram.org/mtproto/description#unencrypted-messages)
|
||||
"""
|
||||
def __init__(self, connection):
|
||||
"""
|
||||
Initializes the MTProto plain sender.
|
||||
|
||||
:param connection: the Connection to be used.
|
||||
"""
|
||||
self._state = MTProtoState(auth_key=None)
|
||||
self._connection = connection
|
||||
|
||||
async def send(self, request):
|
||||
"""
|
||||
Sends and receives the result for the given request.
|
||||
"""
|
||||
body = bytes(request)
|
||||
msg_id = self._state._get_new_msg_id()
|
||||
await self._connection.send(
|
||||
struct.pack('<QQi', 0, msg_id, len(body)) + body
|
||||
)
|
||||
|
||||
body = await self._connection.recv()
|
||||
if body == b'l\xfe\xff\xff': # -404 little endian signed
|
||||
# Broken authorization, must reset the auth key
|
||||
raise BrokenAuthKeyError()
|
||||
|
||||
with BinaryReader(body) as reader:
|
||||
assert reader.read_long() == 0 # auth_key_id
|
||||
assert reader.read_long() > msg_id # msg_id
|
||||
assert reader.read_int() # length
|
||||
# No need to read "length" bytes first, just read the object
|
||||
return reader.tgread_object()
|
633
telethon/network/mtprotosender.py
Normal file
633
telethon/network/mtprotosender.py
Normal file
|
@ -0,0 +1,633 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from . import MTProtoPlainSender, authenticator
|
||||
from .. import utils
|
||||
from ..errors import (
|
||||
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
|
||||
rpc_message_to_error
|
||||
)
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl.core import RpcResult, MessageContainer, GzipPacked
|
||||
from ..tl.functions.auth import LogOutRequest
|
||||
from ..tl.types import (
|
||||
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
|
||||
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
|
||||
MsgsStateInfo, MsgsAllInfo, MsgResendReq
|
||||
)
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO Create some kind of "ReconnectionPolicy" that allows specifying
|
||||
# what should be done in case of some errors, with some sane defaults.
|
||||
# For instance, should all messages be set with an error upon network
|
||||
# loss? Should we try reconnecting forever? A certain amount of times?
|
||||
# A timeout? What about recoverable errors, like connection reset?
|
||||
class MTProtoSender:
|
||||
"""
|
||||
MTProto Mobile Protocol sender
|
||||
(https://core.telegram.org/mtproto/description).
|
||||
|
||||
This class is responsible for wrapping requests into `TLMessage`'s,
|
||||
sending them over the network and receiving them in a safe manner.
|
||||
|
||||
Automatic reconnection due to temporary network issues is a concern
|
||||
for this class as well, including retry of messages that could not
|
||||
be sent successfully.
|
||||
|
||||
A new authorization key will be generated on connection if no other
|
||||
key exists yet.
|
||||
"""
|
||||
def __init__(self, state, connection, *, retries=5,
|
||||
first_query=None, update_callback=None):
|
||||
self.state = state
|
||||
self._connection = connection
|
||||
self._ip = None
|
||||
self._port = None
|
||||
self._retries = retries
|
||||
self._first_query = first_query
|
||||
self._is_first_query = bool(first_query)
|
||||
self._update_callback = update_callback
|
||||
|
||||
# Whether the user has explicitly connected or disconnected.
|
||||
#
|
||||
# If a disconnection happens for any other reason and it
|
||||
# was *not* user action then the pending messages won't
|
||||
# be cleared but on explicit user disconnection all the
|
||||
# pending futures should be cancelled.
|
||||
self._user_connected = False
|
||||
self._reconnecting = False
|
||||
|
||||
# We need to join the loops upon disconnection
|
||||
self._send_loop_handle = None
|
||||
self._recv_loop_handle = None
|
||||
|
||||
# Sending something shouldn't block
|
||||
self._send_queue = _ContainerQueue()
|
||||
|
||||
# Telegram responds to messages out of order. Keep
|
||||
# {id: Message} to set their Future result upon arrival.
|
||||
self._pending_messages = {}
|
||||
|
||||
# Containers are accepted or rejected as a whole when any of
|
||||
# its inner requests are acknowledged. For this purpose we
|
||||
# all the sent containers here.
|
||||
self._pending_containers = []
|
||||
|
||||
# We need to acknowledge every response from Telegram
|
||||
self._pending_ack = set()
|
||||
|
||||
# Jump table from response ID to method that handles it
|
||||
self._handlers = {
|
||||
RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result,
|
||||
MessageContainer.CONSTRUCTOR_ID: self._handle_container,
|
||||
GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed,
|
||||
Pong.CONSTRUCTOR_ID: self._handle_pong,
|
||||
BadServerSalt.CONSTRUCTOR_ID: self._handle_bad_server_salt,
|
||||
BadMsgNotification.CONSTRUCTOR_ID: self._handle_bad_notification,
|
||||
MsgDetailedInfo.CONSTRUCTOR_ID: self._handle_detailed_info,
|
||||
MsgNewDetailedInfo.CONSTRUCTOR_ID: self._handle_new_detailed_info,
|
||||
NewSessionCreated.CONSTRUCTOR_ID: self._handle_new_session_created,
|
||||
MsgsAck.CONSTRUCTOR_ID: self._handle_ack,
|
||||
FutureSalts.CONSTRUCTOR_ID: self._handle_future_salts,
|
||||
MsgsStateReq.CONSTRUCTOR_ID: self._handle_state_forgotten,
|
||||
MsgResendReq.CONSTRUCTOR_ID: self._handle_state_forgotten,
|
||||
MsgsAllInfo.CONSTRUCTOR_ID: self._handle_msg_all,
|
||||
}
|
||||
|
||||
# Public API
|
||||
|
||||
async def connect(self, ip, port):
|
||||
"""
|
||||
Connects to the specified ``ip:port``, and generates a new
|
||||
authorization key for the `MTProtoSender.session` if it does
|
||||
not exist yet.
|
||||
"""
|
||||
if self._user_connected:
|
||||
__log__.info('User is already connected!')
|
||||
return
|
||||
|
||||
self._ip = ip
|
||||
self._port = port
|
||||
self._user_connected = True
|
||||
await self._connect()
|
||||
|
||||
def is_connected(self):
|
||||
return self._user_connected
|
||||
|
||||
async def disconnect(self):
|
||||
"""
|
||||
Cleanly disconnects the instance from the network, cancels
|
||||
all pending requests, and closes the send and receive loops.
|
||||
"""
|
||||
if not self._user_connected:
|
||||
__log__.info('User is already disconnected!')
|
||||
return
|
||||
|
||||
__log__.info('Disconnecting from {}...'.format(self._ip))
|
||||
self._user_connected = False
|
||||
try:
|
||||
__log__.debug('Closing current connection...')
|
||||
await self._connection.close()
|
||||
finally:
|
||||
__log__.debug('Cancelling {} pending message(s)...'
|
||||
.format(len(self._pending_messages)))
|
||||
for message in self._pending_messages.values():
|
||||
message.future.cancel()
|
||||
|
||||
self._pending_messages.clear()
|
||||
self._pending_ack.clear()
|
||||
|
||||
__log__.debug('Cancelling the send loop...')
|
||||
self._send_loop_handle.cancel()
|
||||
|
||||
__log__.debug('Cancelling the receive loop...')
|
||||
self._recv_loop_handle.cancel()
|
||||
|
||||
__log__.info('Disconnection from {} complete!'.format(self._ip))
|
||||
|
||||
def send(self, request, ordered=False):
|
||||
"""
|
||||
This method enqueues the given request to be sent.
|
||||
|
||||
The request will be wrapped inside a `TLMessage` until its
|
||||
response arrives, and the `Future` response of the `TLMessage`
|
||||
is immediately returned so that one can further ``await`` it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async def method():
|
||||
# Sending (enqueued for the send loop)
|
||||
future = sender.send(request)
|
||||
# Receiving (waits for the receive loop to read the result)
|
||||
result = await future
|
||||
|
||||
Designed like this because Telegram may send the response at
|
||||
any point, and it can send other items while one waits for it.
|
||||
Once the response for this future arrives, it is set with the
|
||||
received result, quite similar to how a ``receive()`` call
|
||||
would otherwise work.
|
||||
|
||||
Since the receiving part is "built in" the future, it's
|
||||
impossible to await receive a result that was never sent.
|
||||
"""
|
||||
if utils.is_list_like(request):
|
||||
result = []
|
||||
after = None
|
||||
for r in request:
|
||||
message = self.state.create_message(r, after=after)
|
||||
self._pending_messages[message.msg_id] = message
|
||||
self._send_queue.put_nowait(message)
|
||||
result.append(message.future)
|
||||
after = ordered and message
|
||||
return result
|
||||
else:
|
||||
message = self.state.create_message(request)
|
||||
self._pending_messages[message.msg_id] = message
|
||||
self._send_queue.put_nowait(message)
|
||||
return message.future
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _connect(self):
|
||||
"""
|
||||
Performs the actual connection, retrying, generating the
|
||||
authorization key if necessary, and starting the send and
|
||||
receive loops.
|
||||
"""
|
||||
__log__.info('Connecting to {}:{}...'.format(self._ip, self._port))
|
||||
_last_error = ConnectionError()
|
||||
for retry in range(1, self._retries + 1):
|
||||
try:
|
||||
__log__.debug('Connection attempt {}...'.format(retry))
|
||||
await self._connection.connect(self._ip, self._port)
|
||||
except OSError as e:
|
||||
_last_error = e
|
||||
__log__.warning('Attempt {} at connecting failed: {}'
|
||||
.format(retry, e))
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise _last_error
|
||||
|
||||
__log__.debug('Connection success!')
|
||||
if self.state.auth_key is None:
|
||||
self._is_first_query = bool(self._first_query)
|
||||
_last_error = SecurityError()
|
||||
plain = MTProtoPlainSender(self._connection)
|
||||
for retry in range(1, self._retries + 1):
|
||||
try:
|
||||
__log__.debug('New auth_key attempt {}...'.format(retry))
|
||||
self.state.auth_key, self.state.time_offset =\
|
||||
await authenticator.do_authentication(plain)
|
||||
except (SecurityError, AssertionError) as e:
|
||||
_last_error = e
|
||||
__log__.warning('Attempt {} at new auth_key failed: {}'
|
||||
.format(retry, e))
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise _last_error
|
||||
|
||||
__log__.debug('Starting send loop')
|
||||
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
|
||||
__log__.debug('Starting receive loop')
|
||||
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
|
||||
if self._is_first_query:
|
||||
__log__.debug('Running first query')
|
||||
self._is_first_query = False
|
||||
await self.send(self._first_query)
|
||||
|
||||
__log__.info('Connection to {} complete!'.format(self._ip))
|
||||
|
||||
async def _reconnect(self):
|
||||
"""
|
||||
Cleanly disconnects and then reconnects.
|
||||
"""
|
||||
self._reconnecting = True
|
||||
|
||||
__log__.debug('Awaiting for the send loop before reconnecting...')
|
||||
await self._send_loop_handle
|
||||
|
||||
__log__.debug('Awaiting for the receive loop before reconnecting...')
|
||||
await self._recv_loop_handle
|
||||
|
||||
__log__.debug('Closing current connection...')
|
||||
await self._connection.close()
|
||||
|
||||
self._reconnecting = False
|
||||
await self._connect()
|
||||
|
||||
def _clean_containers(self, msg_ids):
|
||||
"""
|
||||
Helper method to clean containers from the pending messages
|
||||
once a wrapped msg_id of them has been acknowledged.
|
||||
|
||||
This is the only way we can resend TLMessage(MessageContainer)
|
||||
on bad notifications and also mark them as received once any
|
||||
of their inner TLMessage is acknowledged.
|
||||
"""
|
||||
for i in reversed(range(len(self._pending_containers))):
|
||||
message = self._pending_containers[i]
|
||||
for msg in message.obj.messages:
|
||||
if msg.msg_id in msg_ids:
|
||||
del self._pending_containers[i]
|
||||
del self._pending_messages[message.msg_id]
|
||||
break
|
||||
|
||||
# Loops
|
||||
|
||||
async def _send_loop(self):
|
||||
"""
|
||||
This loop is responsible for popping items off the send
|
||||
queue, encrypting them, and sending them over the network.
|
||||
|
||||
Besides `connect`, only this method ever sends data.
|
||||
"""
|
||||
while self._user_connected and not self._reconnecting:
|
||||
if self._pending_ack:
|
||||
self._send_queue.put_nowait(self.state.create_message(
|
||||
MsgsAck(list(self._pending_ack))
|
||||
))
|
||||
self._pending_ack.clear()
|
||||
|
||||
messages = await self._send_queue.get()
|
||||
if isinstance(messages, list):
|
||||
message = self.state.create_message(MessageContainer(messages))
|
||||
self._pending_messages[message.msg_id] = message
|
||||
self._pending_containers.append(message)
|
||||
else:
|
||||
message = messages
|
||||
messages = [message]
|
||||
|
||||
__log__.debug('Packing {} outgoing message(s)...'
|
||||
.format(len(messages)))
|
||||
body = self.state.pack_message(message)
|
||||
|
||||
while not any(m.future.cancelled() for m in messages):
|
||||
try:
|
||||
__log__.debug('Sending {} bytes...'.format(len(body)))
|
||||
await self._connection.send(body)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except OSError as e:
|
||||
__log__.warning('OSError while sending %s', e)
|
||||
else:
|
||||
# Remove the cancelled messages from pending
|
||||
__log__.info('Some futures were cancelled, aborted send')
|
||||
self._clean_containers([m.msg_id for m in messages])
|
||||
for m in messages:
|
||||
if m.future.cancelled():
|
||||
self._pending_messages.pop(m.msg_id, None)
|
||||
else:
|
||||
self._send_queue.put_nowait(m)
|
||||
|
||||
__log__.debug('Outgoing messages {} sent!'
|
||||
.format(', '.join(str(m.msg_id) for m in messages)))
|
||||
|
||||
async def _recv_loop(self):
|
||||
"""
|
||||
This loop is responsible for reading all incoming responses
|
||||
from the network, decrypting and handling or dispatching them.
|
||||
|
||||
Besides `connect`, only this method ever receives data.
|
||||
"""
|
||||
while self._user_connected and not self._reconnecting:
|
||||
# TODO Are there more exceptions besides timeout?
|
||||
# Disconnecting or switching off WiFi only resulted in
|
||||
# timeouts, and once the network was back it continued
|
||||
# on its own after a short delay.
|
||||
try:
|
||||
__log__.debug('Receiving items from the network...')
|
||||
body = await self._connection.recv()
|
||||
except asyncio.TimeoutError:
|
||||
# TODO If nothing is received for a minute, send a request
|
||||
continue
|
||||
except ConnectionError as e:
|
||||
__log__.info('Connection reset while receiving %s', e)
|
||||
asyncio.ensure_future(self._reconnect())
|
||||
break
|
||||
except OSError as e:
|
||||
__log__.warning('OSError while receiving %s', e)
|
||||
asyncio.ensure_future(self._reconnect())
|
||||
break
|
||||
|
||||
# TODO Check salt, session_id and sequence_number
|
||||
__log__.debug('Decoding packet of %d bytes...', len(body))
|
||||
try:
|
||||
message = self.state.unpack_message(body)
|
||||
except (BrokenAuthKeyError, BufferError) as e:
|
||||
# The authorization key may be broken if a message was
|
||||
# sent malformed, or if the authkey truly is corrupted.
|
||||
#
|
||||
# There may be a buffer error if Telegram's response was too
|
||||
# short and hence not understood. Reset the authorization key
|
||||
# and try again in either case.
|
||||
#
|
||||
# TODO Is it possible to detect malformed messages vs
|
||||
# an actually broken authkey?
|
||||
__log__.warning('Broken authorization key?: {}'.format(e))
|
||||
self.state.auth_key = None
|
||||
asyncio.ensure_future(self._reconnect())
|
||||
break
|
||||
except SecurityError as e:
|
||||
# A step while decoding had the incorrect data. This message
|
||||
# should not be considered safe and it should be ignored.
|
||||
__log__.warning('Security error while unpacking a '
|
||||
'received message:'.format(e))
|
||||
continue
|
||||
except TypeNotFoundError as e:
|
||||
# The payload inside the message was not a known TLObject.
|
||||
__log__.info('Server replied with an unknown type {:08x}: {!r}'
|
||||
.format(e.invalid_constructor_id, e.remaining))
|
||||
else:
|
||||
await self._process_message(message)
|
||||
|
||||
# Response Handlers
|
||||
|
||||
async def _process_message(self, message):
|
||||
"""
|
||||
Adds the given message to the list of messages that must be
|
||||
acknowledged and dispatches control to different ``_handle_*``
|
||||
method based on its type.
|
||||
"""
|
||||
self._pending_ack.add(message.msg_id)
|
||||
handler = self._handlers.get(message.obj.CONSTRUCTOR_ID,
|
||||
self._handle_update)
|
||||
await handler(message)
|
||||
|
||||
async def _handle_rpc_result(self, message):
|
||||
"""
|
||||
Handles the result for Remote Procedure Calls:
|
||||
|
||||
rpc_result#f35c6d01 req_msg_id:long result:bytes = RpcResult;
|
||||
|
||||
This is where the future results for sent requests are set.
|
||||
"""
|
||||
rpc_result = message.obj
|
||||
message = self._pending_messages.pop(rpc_result.req_msg_id, None)
|
||||
__log__.debug('Handling RPC result for message {}'
|
||||
.format(rpc_result.req_msg_id))
|
||||
|
||||
if rpc_result.error:
|
||||
# TODO Report errors if possible/enabled
|
||||
error = rpc_message_to_error(rpc_result.error)
|
||||
self._send_queue.put_nowait(self.state.create_message(
|
||||
MsgsAck([message.msg_id])
|
||||
))
|
||||
|
||||
if not message.future.cancelled():
|
||||
message.future.set_exception(error)
|
||||
return
|
||||
elif message:
|
||||
with BinaryReader(rpc_result.body) as reader:
|
||||
result = message.obj.read_result(reader)
|
||||
|
||||
# TODO Process entities
|
||||
if not message.future.cancelled():
|
||||
message.future.set_result(result)
|
||||
return
|
||||
else:
|
||||
# TODO We should not get responses to things we never sent
|
||||
__log__.info('Received response without parent request: {}'
|
||||
.format(rpc_result.body))
|
||||
|
||||
async def _handle_container(self, message):
|
||||
"""
|
||||
Processes the inner messages of a container with many of them:
|
||||
|
||||
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
|
||||
"""
|
||||
__log__.debug('Handling container')
|
||||
for inner_message in message.obj.messages:
|
||||
await self._process_message(inner_message)
|
||||
|
||||
async def _handle_gzip_packed(self, message):
|
||||
"""
|
||||
Unpacks the data from a gzipped object and processes it:
|
||||
|
||||
gzip_packed#3072cfa1 packed_data:bytes = Object;
|
||||
"""
|
||||
__log__.debug('Handling gzipped data')
|
||||
with BinaryReader(message.obj.data) as reader:
|
||||
message.obj = reader.tgread_object()
|
||||
await self._process_message(message)
|
||||
|
||||
async def _handle_update(self, message):
|
||||
__log__.debug('Handling update {}'
|
||||
.format(message.obj.__class__.__name__))
|
||||
if self._update_callback:
|
||||
self._update_callback(message.obj)
|
||||
|
||||
async def _handle_pong(self, message):
|
||||
"""
|
||||
Handles pong results, which don't come inside a ``rpc_result``
|
||||
but are still sent through a request:
|
||||
|
||||
pong#347773c5 msg_id:long ping_id:long = Pong;
|
||||
"""
|
||||
__log__.debug('Handling pong')
|
||||
pong = message.obj
|
||||
message = self._pending_messages.pop(pong.msg_id, None)
|
||||
if message:
|
||||
message.future.set_result(pong)
|
||||
|
||||
async def _handle_bad_server_salt(self, message):
|
||||
"""
|
||||
Corrects the currently used server salt to use the right value
|
||||
before enqueuing the rejected message to be re-sent:
|
||||
|
||||
bad_server_salt#edab447b bad_msg_id:long bad_msg_seqno:int
|
||||
error_code:int new_server_salt:long = BadMsgNotification;
|
||||
"""
|
||||
__log__.debug('Handling bad salt')
|
||||
bad_salt = message.obj
|
||||
self.state.salt = bad_salt.new_server_salt
|
||||
self._send_queue.put_nowait(self._pending_messages[bad_salt.bad_msg_id])
|
||||
|
||||
async def _handle_bad_notification(self, message):
|
||||
"""
|
||||
Adjusts the current state to be correct based on the
|
||||
received bad message notification whenever possible:
|
||||
|
||||
bad_msg_notification#a7eff811 bad_msg_id:long bad_msg_seqno:int
|
||||
error_code:int = BadMsgNotification;
|
||||
"""
|
||||
__log__.debug('Handling bad message')
|
||||
bad_msg = message.obj
|
||||
if bad_msg.error_code in (16, 17):
|
||||
# Sent msg_id too low or too high (respectively).
|
||||
# Use the current msg_id to determine the right time offset.
|
||||
self.state.update_time_offset(correct_msg_id=message.msg_id)
|
||||
elif bad_msg.error_code == 32:
|
||||
# msg_seqno too low, so just pump it up by some "large" amount
|
||||
# TODO A better fix would be to start with a new fresh session ID
|
||||
self.state._sequence += 64
|
||||
elif bad_msg.error_code == 33:
|
||||
# msg_seqno too high never seems to happen but just in case
|
||||
self.state._sequence -= 16
|
||||
else:
|
||||
msg = self._pending_messages.pop(bad_msg.bad_msg_id, None)
|
||||
if msg:
|
||||
msg.future.set_exception(BadMessageError(bad_msg.error_code))
|
||||
return
|
||||
|
||||
# Messages are to be re-sent once we've corrected the issue
|
||||
self._send_queue.put_nowait(self._pending_messages[bad_msg.bad_msg_id])
|
||||
|
||||
async def _handle_detailed_info(self, message):
|
||||
"""
|
||||
Updates the current status with the received detailed information:
|
||||
|
||||
msg_detailed_info#276d3ec6 msg_id:long answer_msg_id:long
|
||||
bytes:int status:int = MsgDetailedInfo;
|
||||
"""
|
||||
# TODO https://goo.gl/VvpCC6
|
||||
__log__.debug('Handling detailed info')
|
||||
self._pending_ack.add(message.obj.answer_msg_id)
|
||||
|
||||
async def _handle_new_detailed_info(self, message):
|
||||
"""
|
||||
Updates the current status with the received detailed information:
|
||||
|
||||
msg_new_detailed_info#809db6df answer_msg_id:long
|
||||
bytes:int status:int = MsgDetailedInfo;
|
||||
"""
|
||||
# TODO https://goo.gl/G7DPsR
|
||||
__log__.debug('Handling new detailed info')
|
||||
self._pending_ack.add(message.obj.answer_msg_id)
|
||||
|
||||
async def _handle_new_session_created(self, message):
|
||||
"""
|
||||
Updates the current status with the received session information:
|
||||
|
||||
new_session_created#9ec20908 first_msg_id:long unique_id:long
|
||||
server_salt:long = NewSession;
|
||||
"""
|
||||
# TODO https://goo.gl/LMyN7A
|
||||
__log__.debug('Handling new session created')
|
||||
self.state.salt = message.obj.server_salt
|
||||
|
||||
async def _handle_ack(self, message):
|
||||
"""
|
||||
Handles a server acknowledge about our messages. Normally
|
||||
these can be ignored except in the case of ``auth.logOut``:
|
||||
|
||||
auth.logOut#5717da40 = Bool;
|
||||
|
||||
Telegram doesn't seem to send its result so we need to confirm
|
||||
it manually. No other request is known to have this behaviour.
|
||||
|
||||
Since the ID of sent messages consisting of a container is
|
||||
never returned (unless on a bad notification), this method
|
||||
also removes containers messages when any of their inner
|
||||
messages are acknowledged.
|
||||
"""
|
||||
__log__.debug('Handling acknowledge')
|
||||
ack = message.obj
|
||||
if self._pending_containers:
|
||||
self._clean_containers(ack.msg_ids)
|
||||
|
||||
for msg_id in ack.msg_ids:
|
||||
msg = self._pending_messages.get(msg_id, None)
|
||||
if msg and isinstance(msg.obj, LogOutRequest):
|
||||
del self._pending_messages[msg_id]
|
||||
msg.future.set_result(True)
|
||||
|
||||
async def _handle_future_salts(self, message):
|
||||
"""
|
||||
Handles future salt results, which don't come inside a
|
||||
``rpc_result`` but are still sent through a request:
|
||||
|
||||
future_salts#ae500895 req_msg_id:long now:int
|
||||
salts:vector<future_salt> = FutureSalts;
|
||||
"""
|
||||
# TODO save these salts and automatically adjust to the
|
||||
# correct one whenever the salt in use expires.
|
||||
__log__.debug('Handling future salts')
|
||||
msg = self._pending_messages.pop(message.msg_id, None)
|
||||
if msg:
|
||||
msg.future.set_result(message.obj)
|
||||
|
||||
async def _handle_state_forgotten(self, message):
|
||||
"""
|
||||
Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by
|
||||
enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
|
||||
"""
|
||||
self.send(MsgsStateInfo(req_msg_id=message.msg_id,
|
||||
info=chr(1) * len(message.obj.msg_ids)))
|
||||
|
||||
async def _handle_msg_all(self, message):
|
||||
"""
|
||||
Handles :tl:`MsgsAllInfo` by doing nothing (yet).
|
||||
"""
|
||||
|
||||
|
||||
class _ContainerQueue(asyncio.Queue):
|
||||
"""
|
||||
An asyncio queue that's aware of `MessageContainer` instances.
|
||||
|
||||
The `get` method returns either a single `TLMessage` or a list
|
||||
of them that should be turned into a new `MessageContainer`.
|
||||
|
||||
Instances of this class can be replaced with the simpler
|
||||
``asyncio.Queue`` when needed for testing purposes, and
|
||||
a list won't be returned in said case.
|
||||
"""
|
||||
async def get(self):
|
||||
result = await super().get()
|
||||
if self.empty() or isinstance(result.obj, MessageContainer):
|
||||
return result
|
||||
|
||||
result = [result]
|
||||
while not self.empty():
|
||||
item = self.get_nowait()
|
||||
if isinstance(item.obj, MessageContainer):
|
||||
self.put_nowait(item)
|
||||
break
|
||||
else:
|
||||
result.append(item)
|
||||
|
||||
return result
|
163
telethon/network/mtprotostate.py
Normal file
163
telethon/network/mtprotostate.py
Normal file
|
@ -0,0 +1,163 @@
|
|||
import logging
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
from hashlib import sha256
|
||||
|
||||
from ..crypto import AES
|
||||
from ..errors import SecurityError, BrokenAuthKeyError
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl.core import TLMessage
|
||||
from ..tl.tlobject import TLRequest
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MTProtoState:
|
||||
"""
|
||||
`telethon.network.mtprotosender.MTProtoSender` needs to hold a state
|
||||
in order to be able to encrypt and decrypt incoming/outgoing messages,
|
||||
as well as generating the message IDs. Instances of this class hold
|
||||
together all the required information.
|
||||
|
||||
It doesn't make sense to use `telethon.sessions.abstract.Session` for
|
||||
the sender because the sender should *not* be concerned about storing
|
||||
this information to disk, as one may create as many senders as they
|
||||
desire to any other data center, or some CDN. Using the same session
|
||||
for all these is not a good idea as each need their own authkey, and
|
||||
the concept of "copying" sessions with the unnecessary entities or
|
||||
updates state for these connections doesn't make sense.
|
||||
"""
|
||||
def __init__(self, auth_key):
|
||||
# Session IDs can be random on every connection
|
||||
self.id = struct.unpack('q', os.urandom(8))[0]
|
||||
self.auth_key = auth_key
|
||||
self.time_offset = 0
|
||||
self.salt = 0
|
||||
self._sequence = 0
|
||||
self._last_msg_id = 0
|
||||
|
||||
def create_message(self, obj, after=None):
|
||||
"""
|
||||
Creates a new `telethon.tl.tl_message.TLMessage` from
|
||||
the given `telethon.tl.tlobject.TLObject` instance.
|
||||
"""
|
||||
return TLMessage(
|
||||
msg_id=self._get_new_msg_id(),
|
||||
seq_no=self._get_seq_no(isinstance(obj, TLRequest)),
|
||||
obj=obj,
|
||||
after_id=after.msg_id if after else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _calc_key(auth_key, msg_key, client):
|
||||
"""
|
||||
Calculate the key based on Telegram guidelines for MTProto 2,
|
||||
specifying whether it's the client or not. See
|
||||
https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
||||
"""
|
||||
x = 0 if client else 8
|
||||
sha256a = sha256(msg_key + auth_key[x: x + 36]).digest()
|
||||
sha256b = sha256(auth_key[x + 40:x + 76] + msg_key).digest()
|
||||
|
||||
aes_key = sha256a[:8] + sha256b[8:24] + sha256a[24:32]
|
||||
aes_iv = sha256b[:8] + sha256a[8:24] + sha256b[24:32]
|
||||
|
||||
return aes_key, aes_iv
|
||||
|
||||
def pack_message(self, message):
|
||||
"""
|
||||
Packs the given `telethon.tl.tl_message.TLMessage` using the
|
||||
current authorization key following MTProto 2.0 guidelines.
|
||||
|
||||
See https://core.telegram.org/mtproto/description.
|
||||
"""
|
||||
data = struct.pack('<qq', self.salt, self.id) + bytes(message)
|
||||
padding = os.urandom(-(len(data) + 12) % 16 + 12)
|
||||
|
||||
# Being substr(what, offset, length); x = 0 for client
|
||||
# "msg_key_large = SHA256(substr(auth_key, 88+x, 32) + pt + padding)"
|
||||
msg_key_large = sha256(
|
||||
self.auth_key.key[88:88 + 32] + data + padding).digest()
|
||||
|
||||
# "msg_key = substr (msg_key_large, 8, 16)"
|
||||
msg_key = msg_key_large[8:24]
|
||||
aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, True)
|
||||
|
||||
key_id = struct.pack('<Q', self.auth_key.key_id)
|
||||
return (key_id + msg_key +
|
||||
AES.encrypt_ige(data + padding, aes_key, aes_iv))
|
||||
|
||||
def unpack_message(self, body):
|
||||
"""
|
||||
Inverse of `pack_message` for incoming server messages.
|
||||
"""
|
||||
if len(body) < 8:
|
||||
if body == b'l\xfe\xff\xff':
|
||||
raise BrokenAuthKeyError()
|
||||
else:
|
||||
raise BufferError("Can't decode packet ({})".format(body))
|
||||
|
||||
key_id = struct.unpack('<Q', body[:8])[0]
|
||||
if key_id != self.auth_key.key_id:
|
||||
raise SecurityError('Server replied with an invalid auth key')
|
||||
|
||||
msg_key = body[8:24]
|
||||
aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, False)
|
||||
body = AES.decrypt_ige(body[24:], aes_key, aes_iv)
|
||||
|
||||
# https://core.telegram.org/mtproto/security_guidelines
|
||||
# Sections "checking sha256 hash" and "message length"
|
||||
our_key = sha256(self.auth_key.key[96:96 + 32] + body)
|
||||
if msg_key != our_key.digest()[8:24]:
|
||||
raise SecurityError(
|
||||
"Received msg_key doesn't match with expected one")
|
||||
|
||||
reader = BinaryReader(body)
|
||||
reader.read_long() # remote_salt
|
||||
if reader.read_long() != self.id:
|
||||
raise SecurityError('Server replied with a wrong session ID')
|
||||
|
||||
remote_msg_id = reader.read_long()
|
||||
remote_sequence = reader.read_int()
|
||||
reader.read_int() # msg_len for the inner object, padding ignored
|
||||
obj = reader.tgread_object()
|
||||
|
||||
return TLMessage(remote_msg_id, remote_sequence, obj)
|
||||
|
||||
def _get_new_msg_id(self):
|
||||
"""
|
||||
Generates a new unique message ID based on the current
|
||||
time (in ms) since epoch, applying a known time offset.
|
||||
"""
|
||||
now = time.time() + self.time_offset
|
||||
nanoseconds = int((now - int(now)) * 1e+9)
|
||||
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
|
||||
|
||||
if self._last_msg_id >= new_msg_id:
|
||||
new_msg_id = self._last_msg_id + 4
|
||||
|
||||
self._last_msg_id = new_msg_id
|
||||
return new_msg_id
|
||||
|
||||
def update_time_offset(self, correct_msg_id):
|
||||
"""
|
||||
Updates the time offset to the correct
|
||||
one given a known valid message ID.
|
||||
"""
|
||||
now = int(time.time())
|
||||
correct = correct_msg_id >> 32
|
||||
self.time_offset = correct - now
|
||||
self._last_msg_id = 0
|
||||
|
||||
def _get_seq_no(self, content_related):
|
||||
"""
|
||||
Generates the next sequence number depending on whether
|
||||
it should be for a content-related query or not.
|
||||
"""
|
||||
if content_related:
|
||||
result = self._sequence * 2 + 1
|
||||
self._sequence += 1
|
||||
return result
|
||||
else:
|
||||
return self._sequence * 2
|
|
@ -1,757 +0,0 @@
|
|||
import logging
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
from datetime import timedelta, datetime
|
||||
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
from . import version, utils
|
||||
from .crypto import rsa
|
||||
from .errors import (
|
||||
RPCError, BrokenAuthKeyError, ServerError, FloodWaitError,
|
||||
FloodTestPhoneWaitError, TypeNotFoundError, UnauthorizedError,
|
||||
PhoneMigrateError, NetworkMigrateError, UserMigrateError, AuthKeyError,
|
||||
RpcCallFailError
|
||||
)
|
||||
from .network import authenticator, MtProtoSender, ConnectionTcpFull
|
||||
from .sessions import Session, SQLiteSession
|
||||
from .tl import TLObject
|
||||
from .tl.all_tlobjects import LAYER
|
||||
from .tl.functions import (
|
||||
InitConnectionRequest, InvokeWithLayerRequest, PingRequest
|
||||
)
|
||||
from .tl.functions.auth import (
|
||||
ImportAuthorizationRequest, ExportAuthorizationRequest
|
||||
)
|
||||
from .tl.functions.help import (
|
||||
GetCdnConfigRequest, GetConfigRequest
|
||||
)
|
||||
from .tl.functions.updates import GetStateRequest
|
||||
from .tl.types.auth import ExportedAuthorization
|
||||
from .update_state import UpdateState
|
||||
|
||||
DEFAULT_DC_ID = 4
|
||||
DEFAULT_IPV4_IP = '149.154.167.51'
|
||||
DEFAULT_IPV6_IP = '[2001:67c:4e8:f002::a]'
|
||||
DEFAULT_PORT = 443
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramBareClient:
|
||||
"""Bare Telegram Client with just the minimum -
|
||||
|
||||
The reason to distinguish between a MtProtoSender and a
|
||||
TelegramClient itself is because the sender is just that,
|
||||
a sender, which should know nothing about Telegram but
|
||||
rather how to handle this specific connection.
|
||||
|
||||
The TelegramClient itself should know how to initialize
|
||||
a proper connection to the servers, as well as other basic
|
||||
methods such as disconnection and reconnection.
|
||||
|
||||
This distinction between a bare client and a full client
|
||||
makes it possible to create clones of the bare version
|
||||
(by using the same session, IP address and port) to be
|
||||
able to execute queries on either, without the additional
|
||||
cost that would involve having the methods for signing in,
|
||||
logging out, and such.
|
||||
"""
|
||||
|
||||
# Current TelegramClient version
|
||||
__version__ = version.__version__
|
||||
|
||||
# TODO Make this thread-safe, all connections share the same DC
|
||||
_config = None # Server configuration (with .dc_options)
|
||||
|
||||
# region Initialization
|
||||
|
||||
def __init__(self, session, api_id, api_hash,
|
||||
*,
|
||||
connection=ConnectionTcpFull,
|
||||
use_ipv6=False,
|
||||
proxy=None,
|
||||
update_workers=None,
|
||||
spawn_read_thread=False,
|
||||
timeout=timedelta(seconds=5),
|
||||
report_errors=True,
|
||||
device_model=None,
|
||||
system_version=None,
|
||||
app_version=None,
|
||||
lang_code='en',
|
||||
system_lang_code='en'):
|
||||
"""Refer to TelegramClient.__init__ for docs on this method"""
|
||||
if not api_id or not api_hash:
|
||||
raise ValueError(
|
||||
"Your API ID or Hash cannot be empty or None. "
|
||||
"Refer to telethon.rtfd.io for more information.")
|
||||
|
||||
self._use_ipv6 = use_ipv6
|
||||
|
||||
# Determine what session object we have
|
||||
if isinstance(session, str) or session is None:
|
||||
session = SQLiteSession(session)
|
||||
elif not isinstance(session, Session):
|
||||
raise TypeError(
|
||||
'The given session must be a str or a Session instance.'
|
||||
)
|
||||
|
||||
# ':' in session.server_address is True if it's an IPv6 address
|
||||
if (not session.server_address or
|
||||
(':' in session.server_address) != use_ipv6):
|
||||
session.set_dc(
|
||||
DEFAULT_DC_ID,
|
||||
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
|
||||
DEFAULT_PORT
|
||||
)
|
||||
|
||||
session.report_errors = report_errors
|
||||
self.session = session
|
||||
self.api_id = int(api_id)
|
||||
self.api_hash = api_hash
|
||||
|
||||
# This is the main sender, which will be used from the thread
|
||||
# that calls .connect(). Every other thread will spawn a new
|
||||
# temporary connection. The connection on this one is always
|
||||
# kept open so Telegram can send us updates.
|
||||
if isinstance(connection, type):
|
||||
connection = connection(proxy=proxy, timeout=timeout)
|
||||
|
||||
self._sender = MtProtoSender(self.session, connection)
|
||||
|
||||
# Two threads may be calling reconnect() when the connection is lost,
|
||||
# we only want one to actually perform the reconnection.
|
||||
self._reconnect_lock = Lock()
|
||||
|
||||
# Cache "exported" sessions as 'dc_id: Session' not to recreate
|
||||
# them all the time since generating a new key is a relatively
|
||||
# expensive operation.
|
||||
self._exported_sessions = {}
|
||||
|
||||
# This member will process updates if enabled.
|
||||
# One may change self.updates.enabled at any later point.
|
||||
self.updates = UpdateState(workers=update_workers)
|
||||
|
||||
# Used on connection - the user may modify these and reconnect
|
||||
system = platform.uname()
|
||||
self.device_model = device_model or system.system or 'Unknown'
|
||||
self.system_version = system_version or system.release or '1.0'
|
||||
self.app_version = app_version or self.__version__
|
||||
self.lang_code = lang_code
|
||||
self.system_lang_code = system_lang_code
|
||||
|
||||
# Despite the state of the real connection, keep track of whether
|
||||
# the user has explicitly called .connect() or .disconnect() here.
|
||||
# This information is required by the read thread, who will be the
|
||||
# one attempting to reconnect on the background *while* the user
|
||||
# doesn't explicitly call .disconnect(), thus telling it to stop
|
||||
# retrying. The main thread, knowing there is a background thread
|
||||
# attempting reconnection as soon as it happens, will just sleep.
|
||||
self._user_connected = False
|
||||
|
||||
# Save whether the user is authorized here (a.k.a. logged in)
|
||||
self._authorized = None # None = We don't know yet
|
||||
|
||||
# The first request must be in invokeWithLayer(initConnection(X)).
|
||||
# See https://core.telegram.org/api/invoking#saving-client-info.
|
||||
self._first_request = True
|
||||
|
||||
# Constantly read for results and updates from within the main client,
|
||||
# if the user has left enabled such option.
|
||||
self._spawn_read_thread = spawn_read_thread
|
||||
self._recv_thread = None
|
||||
self._idling = threading.Event()
|
||||
|
||||
# Default PingRequest delay
|
||||
self._last_ping = datetime.now()
|
||||
self._ping_delay = timedelta(minutes=1)
|
||||
|
||||
# Also have another delay for GetStateRequest.
|
||||
#
|
||||
# If the connection is kept alive for long without invoking any
|
||||
# high level request the server simply stops sending updates.
|
||||
# TODO maybe we can have ._last_request instead if any req works?
|
||||
self._last_state = datetime.now()
|
||||
self._state_delay = timedelta(hours=1)
|
||||
|
||||
# Some errors are known but there's nothing we can do from the
|
||||
# background thread. If any of these happens, call .disconnect(),
|
||||
# and raise them next time .invoke() is tried to be called.
|
||||
self._background_error = None
|
||||
|
||||
# endregion
|
||||
|
||||
# region Connecting
|
||||
|
||||
def connect(self, _sync_updates=True):
|
||||
"""Connects to the Telegram servers, executing authentication if
|
||||
required. Note that authenticating to the Telegram servers is
|
||||
not the same as authenticating the desired user itself, which
|
||||
may require a call (or several) to 'sign_in' for the first time.
|
||||
|
||||
Note that the optional parameters are meant for internal use.
|
||||
|
||||
If '_sync_updates', sync_updates() will be called and a
|
||||
second thread will be started if necessary. Note that this
|
||||
will FAIL if the client is not connected to the user's
|
||||
native data center, raising a "UserMigrateError", and
|
||||
calling .disconnect() in the process.
|
||||
"""
|
||||
__log__.info('Connecting to %s:%d...',
|
||||
self.session.server_address, self.session.port)
|
||||
|
||||
self._background_error = None # Clear previous errors
|
||||
|
||||
try:
|
||||
self._sender.connect()
|
||||
__log__.info('Connection success!')
|
||||
|
||||
# Connection was successful! Try syncing the update state
|
||||
# UNLESS '_sync_updates' is False (we probably are in
|
||||
# another data center and this would raise UserMigrateError)
|
||||
# to also assert whether the user is logged in or not.
|
||||
self._user_connected = True
|
||||
if self._authorized is None and _sync_updates:
|
||||
try:
|
||||
self.sync_updates()
|
||||
self._set_connected_and_authorized()
|
||||
except UnauthorizedError:
|
||||
self._authorized = False
|
||||
elif self._authorized:
|
||||
self._set_connected_and_authorized()
|
||||
|
||||
return True
|
||||
|
||||
except TypeNotFoundError as e:
|
||||
# This is fine, probably layer migration
|
||||
__log__.warning('Connection failed, got unexpected type with ID '
|
||||
'%s. Migrating?', hex(e.invalid_constructor_id))
|
||||
self.disconnect()
|
||||
return self.connect(_sync_updates=_sync_updates)
|
||||
|
||||
except AuthKeyError as e:
|
||||
# As of late March 2018 there were two AUTH_KEY_DUPLICATED
|
||||
# reports. Retrying with a clean auth_key should fix this.
|
||||
__log__.warning('Auth key error %s. Clearing it and retrying.', e)
|
||||
self.disconnect()
|
||||
self.session.auth_key = None
|
||||
self.session.save()
|
||||
return self.connect(_sync_updates=_sync_updates)
|
||||
|
||||
except (RPCError, ConnectionError) as e:
|
||||
# Probably errors from the previous session, ignore them
|
||||
__log__.error('Connection failed due to %s', e)
|
||||
self.disconnect()
|
||||
return False
|
||||
|
||||
def is_connected(self):
|
||||
return self._sender.is_connected()
|
||||
|
||||
def _wrap_init_connection(self, query):
|
||||
"""Wraps query around InvokeWithLayerRequest(InitConnectionRequest())"""
|
||||
return InvokeWithLayerRequest(LAYER, InitConnectionRequest(
|
||||
api_id=self.api_id,
|
||||
device_model=self.device_model,
|
||||
system_version=self.system_version,
|
||||
app_version=self.app_version,
|
||||
lang_code=self.lang_code,
|
||||
system_lang_code=self.system_lang_code,
|
||||
lang_pack='', # "langPacks are for official apps only"
|
||||
query=query
|
||||
))
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnects from the Telegram server
|
||||
and stops all the spawned threads"""
|
||||
__log__.info('Disconnecting...')
|
||||
self._user_connected = False # This will stop recv_thread's loop
|
||||
|
||||
__log__.debug('Stopping all workers...')
|
||||
self.updates.stop_workers()
|
||||
|
||||
# This will trigger a "ConnectionResetError" on the recv_thread,
|
||||
# which won't attempt reconnecting as ._user_connected is False.
|
||||
__log__.debug('Disconnecting the socket...')
|
||||
self._sender.disconnect()
|
||||
|
||||
# TODO Shall we clear the _exported_sessions, or may be reused?
|
||||
self._first_request = True # On reconnect it will be first again
|
||||
self.session.set_update_state(0, self.updates.get_update_state(0))
|
||||
self.session.close()
|
||||
|
||||
def _reconnect(self, new_dc=None):
|
||||
"""If 'new_dc' is not set, only a call to .connect() will be made
|
||||
since it's assumed that the connection has been lost and the
|
||||
library is reconnecting.
|
||||
|
||||
If 'new_dc' is set, the client is first disconnected from the
|
||||
current data center, clears the auth key for the old DC, and
|
||||
connects to the new data center.
|
||||
"""
|
||||
if new_dc is None:
|
||||
if self.is_connected():
|
||||
__log__.info('Reconnection aborted: already connected')
|
||||
return True
|
||||
|
||||
try:
|
||||
__log__.info('Attempting reconnection...')
|
||||
return self.connect()
|
||||
except ConnectionResetError as e:
|
||||
__log__.warning('Reconnection failed due to %s', e)
|
||||
return False
|
||||
else:
|
||||
# Since we're reconnecting possibly due to a UserMigrateError,
|
||||
# we need to first know the Data Centers we can connect to. Do
|
||||
# that before disconnecting.
|
||||
dc = self._get_dc(new_dc)
|
||||
__log__.info('Reconnecting to new data center %s', dc)
|
||||
|
||||
self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
# auth_key's are associated with a server, which has now changed
|
||||
# so it's not valid anymore. Set to None to force recreating it.
|
||||
self.session.auth_key = None
|
||||
self.session.save()
|
||||
self.disconnect()
|
||||
return self.connect()
|
||||
|
||||
def set_proxy(self, proxy):
|
||||
"""Change the proxy used by the connections.
|
||||
"""
|
||||
if self.is_connected():
|
||||
raise RuntimeError("You can't change the proxy while connected.")
|
||||
self._sender.connection.conn.proxy = proxy
|
||||
|
||||
# endregion
|
||||
|
||||
# region Working with different connections/Data Centers
|
||||
|
||||
def _on_read_thread(self):
|
||||
return self._recv_thread is not None and \
|
||||
threading.get_ident() == self._recv_thread.ident
|
||||
|
||||
def _get_dc(self, dc_id, cdn=False):
|
||||
"""Gets the Data Center (DC) associated to 'dc_id'"""
|
||||
if not TelegramBareClient._config:
|
||||
TelegramBareClient._config = self(GetConfigRequest())
|
||||
|
||||
try:
|
||||
if cdn:
|
||||
# Ensure we have the latest keys for the CDNs
|
||||
for pk in self(GetCdnConfigRequest()).public_keys:
|
||||
rsa.add_key(pk.public_key)
|
||||
|
||||
return next(
|
||||
dc for dc in TelegramBareClient._config.dc_options
|
||||
if dc.id == dc_id and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn
|
||||
)
|
||||
except StopIteration:
|
||||
if not cdn:
|
||||
raise
|
||||
|
||||
# New configuration, perhaps a new CDN was added?
|
||||
TelegramBareClient._config = self(GetConfigRequest())
|
||||
return self._get_dc(dc_id, cdn=cdn)
|
||||
|
||||
def _get_exported_client(self, dc_id):
|
||||
"""Creates and connects a new TelegramBareClient for the desired DC.
|
||||
|
||||
If it's the first time calling the method with a given dc_id,
|
||||
a new session will be first created, and its auth key generated.
|
||||
Exporting/Importing the authorization will also be done so that
|
||||
the auth is bound with the key.
|
||||
"""
|
||||
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
||||
# for clearly showing how to export the authorization! ^^
|
||||
session = self._exported_sessions.get(dc_id)
|
||||
if session:
|
||||
export_auth = None # Already bound with the auth key
|
||||
else:
|
||||
# TODO Add a lock, don't allow two threads to create an auth key
|
||||
# (when calling .connect() if there wasn't a previous session).
|
||||
# for the same data center.
|
||||
dc = self._get_dc(dc_id)
|
||||
|
||||
# Export the current authorization to the new DC.
|
||||
__log__.info('Exporting authorization for data center %s', dc)
|
||||
export_auth = self(ExportAuthorizationRequest(dc_id))
|
||||
|
||||
# Create a temporary session for this IP address, which needs
|
||||
# to be different because each auth_key is unique per DC.
|
||||
#
|
||||
# Construct this session with the connection parameters
|
||||
# (system version, device model...) from the current one.
|
||||
session = self.session.clone()
|
||||
session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
self._exported_sessions[dc_id] = session
|
||||
|
||||
__log__.info('Creating exported new client')
|
||||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
)
|
||||
client.connect(_sync_updates=False)
|
||||
if isinstance(export_auth, ExportedAuthorization):
|
||||
client(ImportAuthorizationRequest(
|
||||
id=export_auth.id, bytes=export_auth.bytes
|
||||
))
|
||||
elif export_auth is not None:
|
||||
__log__.warning('Unknown export auth type %s', export_auth)
|
||||
|
||||
client._authorized = True # We exported the auth, so we got auth
|
||||
return client
|
||||
|
||||
def _get_cdn_client(self, cdn_redirect):
|
||||
"""Similar to ._get_exported_client, but for CDNs"""
|
||||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||
if not session:
|
||||
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||
session = self.session.clone()
|
||||
session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
self._exported_sessions[cdn_redirect.dc_id] = session
|
||||
|
||||
__log__.info('Creating new CDN client')
|
||||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
)
|
||||
|
||||
# This will make use of the new RSA keys for this specific CDN.
|
||||
#
|
||||
# We won't be calling GetConfigRequest because it's only called
|
||||
# when needed by ._get_dc, and also it's static so it's likely
|
||||
# set already. Avoid invoking non-CDN methods by not syncing updates.
|
||||
client.connect(_sync_updates=False)
|
||||
client._authorized = self._authorized
|
||||
return client
|
||||
|
||||
# endregion
|
||||
|
||||
# region Invoking Telegram requests
|
||||
|
||||
def __call__(self, request, retries=5, ordered=False):
|
||||
"""
|
||||
Invokes (sends) one or more MTProtoRequests and returns (receives)
|
||||
their result.
|
||||
|
||||
Args:
|
||||
request (`TLObject` | `list`):
|
||||
The request or requests to be invoked.
|
||||
|
||||
retries (`bool`, optional):
|
||||
How many times the request should be retried automatically
|
||||
in case it fails with a non-RPC error.
|
||||
|
||||
The invoke will be retried up to 'retries' times before raising
|
||||
``RuntimeError``.
|
||||
|
||||
ordered (`bool`, optional):
|
||||
Whether the requests (if more than one was given) should be
|
||||
executed sequentially on the server. They run in arbitrary
|
||||
order by default.
|
||||
|
||||
Returns:
|
||||
The result of the request (often a `TLObject`) or a list of
|
||||
results if more than one request was given.
|
||||
"""
|
||||
single = not utils.is_list_like(request)
|
||||
if single:
|
||||
request = (request,)
|
||||
|
||||
if not all(isinstance(x, TLObject) and
|
||||
x.content_related for x in request):
|
||||
raise TypeError('You can only invoke requests, not types!')
|
||||
|
||||
if self._background_error:
|
||||
raise self._background_error
|
||||
|
||||
for r in request:
|
||||
r.resolve(self, utils)
|
||||
|
||||
# For logging purposes
|
||||
if single:
|
||||
which = type(request[0]).__name__
|
||||
else:
|
||||
which = '{} requests ({})'.format(
|
||||
len(request), [type(x).__name__ for x in request])
|
||||
|
||||
# Determine the sender to be used (main or a new connection)
|
||||
__log__.debug('Invoking %s', which)
|
||||
call_receive = \
|
||||
not self._idling.is_set() or self._reconnect_lock.locked()
|
||||
|
||||
for retry in range(retries):
|
||||
result = self._invoke(call_receive, request, ordered=ordered)
|
||||
if result is not None:
|
||||
return result[0] if single else result
|
||||
|
||||
log = __log__.info if retry == 0 else __log__.warning
|
||||
log('Invoking %s failed %d times, connecting again and retrying',
|
||||
which, retry + 1)
|
||||
|
||||
sleep(1)
|
||||
# The ReadThread has priority when attempting reconnection,
|
||||
# since this thread is constantly running while __call__ is
|
||||
# only done sometimes. Here try connecting only once/retry.
|
||||
if not self._reconnect_lock.locked():
|
||||
with self._reconnect_lock:
|
||||
self._reconnect()
|
||||
|
||||
raise RuntimeError('Number of retries reached 0 for {}.'.format(
|
||||
which
|
||||
))
|
||||
|
||||
# Let people use client.invoke(SomeRequest()) instead client(...)
|
||||
invoke = __call__
|
||||
|
||||
def _invoke(self, call_receive, requests, ordered=False):
|
||||
try:
|
||||
# Ensure that we start with no previous errors (i.e. resending)
|
||||
for x in requests:
|
||||
x.confirm_received.clear()
|
||||
x.rpc_error = None
|
||||
|
||||
if not self.session.auth_key:
|
||||
__log__.info('Need to generate new auth key before invoking')
|
||||
self._first_request = True
|
||||
self.session.auth_key, self.session.time_offset = \
|
||||
authenticator.do_authentication(self._sender.connection)
|
||||
|
||||
if self._first_request:
|
||||
__log__.info('Initializing a new connection while invoking')
|
||||
if len(requests) == 1:
|
||||
requests = [self._wrap_init_connection(requests[0])]
|
||||
else:
|
||||
# We need a SINGLE request (like GetConfig) to init conn.
|
||||
# Once that's done, the N original requests will be
|
||||
# invoked.
|
||||
TelegramBareClient._config = self(
|
||||
self._wrap_init_connection(GetConfigRequest())
|
||||
)
|
||||
|
||||
self._sender.send(requests, ordered=ordered)
|
||||
|
||||
if not call_receive:
|
||||
# TODO This will be slightly troublesome if we allow
|
||||
# switching between constant read or not on the fly.
|
||||
# Must also watch out for calling .read() from two places,
|
||||
# in which case a Lock would be required for .receive().
|
||||
for x in requests:
|
||||
x.confirm_received.wait(
|
||||
self._sender.connection.get_timeout()
|
||||
)
|
||||
else:
|
||||
while not all(x.confirm_received.is_set() for x in requests):
|
||||
self._sender.receive(update_state=self.updates)
|
||||
|
||||
except BrokenAuthKeyError:
|
||||
__log__.error('Authorization key seems broken and was invalid!')
|
||||
self.session.auth_key = None
|
||||
|
||||
except TypeNotFoundError as e:
|
||||
# Only occurs when we call receive. May happen when
|
||||
# we need to reconnect to another DC on login and
|
||||
# Telegram somehow sends old objects (like configOld)
|
||||
self._first_request = True
|
||||
__log__.warning('Read unknown TLObject code ({}). '
|
||||
'Setting again first_request flag.'
|
||||
.format(hex(e.invalid_constructor_id)))
|
||||
|
||||
except TimeoutError:
|
||||
__log__.warning('Invoking timed out') # We will just retry
|
||||
|
||||
except ConnectionResetError as e:
|
||||
__log__.warning('Connection was reset while invoking')
|
||||
if self._user_connected:
|
||||
# Server disconnected us, __call__ will try reconnecting.
|
||||
try:
|
||||
self._sender.disconnect()
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
else:
|
||||
# User never called .connect(), so raise this error.
|
||||
raise RuntimeError('Tried to invoke without .connect()') from e
|
||||
|
||||
# Clear the flag if we got this far
|
||||
self._first_request = False
|
||||
|
||||
try:
|
||||
raise next(x.rpc_error for x in requests if x.rpc_error)
|
||||
except StopIteration:
|
||||
if any(x.result is None for x in requests):
|
||||
# "A container may only be accepted or
|
||||
# rejected by the other party as a whole."
|
||||
return None
|
||||
|
||||
return [x.result for x in requests]
|
||||
|
||||
except (PhoneMigrateError, NetworkMigrateError,
|
||||
UserMigrateError) as e:
|
||||
|
||||
# TODO What happens with the background thread here?
|
||||
# For normal use cases, this won't happen, because this will only
|
||||
# be on the very first connection (not authorized, not running),
|
||||
# but may be an issue for people who actually travel?
|
||||
self._reconnect(new_dc=e.new_dc)
|
||||
return self._invoke(call_receive, requests)
|
||||
|
||||
except (ServerError, RpcCallFailError) as e:
|
||||
# Telegram is having some issues, just retry
|
||||
__log__.warning('Telegram is having internal issues: %s', e)
|
||||
|
||||
except (FloodWaitError, FloodTestPhoneWaitError) as e:
|
||||
__log__.warning('Request invoked too often, wait %ds', e.seconds)
|
||||
if e.seconds > self.session.flood_sleep_threshold | 0:
|
||||
raise
|
||||
|
||||
sleep(e.seconds)
|
||||
|
||||
# Some really basic functionality
|
||||
|
||||
def is_user_authorized(self):
|
||||
"""Has the user been authorized yet
|
||||
(code request sent and confirmed)?"""
|
||||
return self._authorized
|
||||
|
||||
def get_input_entity(self, peer):
|
||||
"""
|
||||
Stub method, no functionality so that calling
|
||||
``.get_input_entity()`` from ``.resolve()`` doesn't fail.
|
||||
"""
|
||||
return peer
|
||||
|
||||
# endregion
|
||||
|
||||
# region Updates handling
|
||||
|
||||
def sync_updates(self):
|
||||
"""Synchronizes self.updates to their initial state. Will be
|
||||
called automatically on connection if self.updates.enabled = True,
|
||||
otherwise it should be called manually after enabling updates.
|
||||
"""
|
||||
self.updates.process(self(GetStateRequest()))
|
||||
self._last_state = datetime.now()
|
||||
|
||||
# endregion
|
||||
|
||||
# region Constant read
|
||||
|
||||
def _set_connected_and_authorized(self):
|
||||
self._authorized = True
|
||||
self.updates.setup_workers()
|
||||
if self._spawn_read_thread and self._recv_thread is None:
|
||||
self._recv_thread = threading.Thread(
|
||||
name='ReadThread', daemon=True,
|
||||
target=self._recv_thread_impl
|
||||
)
|
||||
self._recv_thread.start()
|
||||
|
||||
def _signal_handler(self, signum, frame):
|
||||
if self._user_connected:
|
||||
self.disconnect()
|
||||
else:
|
||||
os._exit(1)
|
||||
|
||||
def idle(self, stop_signals=(SIGINT, SIGTERM, SIGABRT)):
|
||||
"""
|
||||
Idles the program by looping forever and listening for updates
|
||||
until one of the signals are received, which breaks the loop.
|
||||
|
||||
:param stop_signals:
|
||||
Iterable containing signals from the signal module that will
|
||||
be subscribed to TelegramClient.disconnect() (effectively
|
||||
stopping the idle loop), which will be called on receiving one
|
||||
of those signals.
|
||||
:return:
|
||||
"""
|
||||
if self._spawn_read_thread and not self._on_read_thread():
|
||||
raise RuntimeError('Can only idle if spawn_read_thread=False')
|
||||
|
||||
self._idling.set()
|
||||
for sig in stop_signals:
|
||||
signal(sig, self._signal_handler)
|
||||
|
||||
if self._on_read_thread():
|
||||
__log__.info('Starting to wait for items from the network')
|
||||
else:
|
||||
__log__.info('Idling to receive items from the network')
|
||||
|
||||
while self._user_connected:
|
||||
try:
|
||||
if datetime.now() > self._last_ping + self._ping_delay:
|
||||
self._sender.send(PingRequest(
|
||||
int.from_bytes(os.urandom(8), 'big', signed=True)
|
||||
))
|
||||
self._last_ping = datetime.now()
|
||||
|
||||
if datetime.now() > self._last_state + self._state_delay:
|
||||
self._sender.send(GetStateRequest())
|
||||
self._last_state = datetime.now()
|
||||
|
||||
__log__.debug('Receiving items from the network...')
|
||||
self._sender.receive(update_state=self.updates)
|
||||
except TimeoutError:
|
||||
# No problem
|
||||
__log__.debug('Receiving items from the network timed out')
|
||||
except ConnectionResetError:
|
||||
if self._user_connected:
|
||||
__log__.error('Connection was reset while receiving '
|
||||
'items. Reconnecting')
|
||||
with self._reconnect_lock:
|
||||
while self._user_connected and not self._reconnect():
|
||||
sleep(0.1) # Retry forever, this is instant messaging
|
||||
|
||||
if self.is_connected():
|
||||
# Telegram seems to kick us every 1024 items received
|
||||
# from the network not considering things like bad salt.
|
||||
# We must execute some *high level* request (that's not
|
||||
# a ping) if we want to receive updates again.
|
||||
# TODO Test if getDifference works too (better alternative)
|
||||
self._sender.send(GetStateRequest())
|
||||
except:
|
||||
self._idling.clear()
|
||||
raise
|
||||
|
||||
self._idling.clear()
|
||||
__log__.info('Connection closed by the user, not reading anymore')
|
||||
|
||||
# By using this approach, another thread will be
|
||||
# created and started upon connection to constantly read
|
||||
# from the other end. Otherwise, manual calls to .receive()
|
||||
# must be performed. The MtProtoSender cannot be connected,
|
||||
# or an error will be thrown.
|
||||
#
|
||||
# This way, sending and receiving will be completely independent.
|
||||
def _recv_thread_impl(self):
|
||||
# This thread is "idle" (only listening for updates), but also
|
||||
# excepts everything unlike the manual idle because it should
|
||||
# not crash.
|
||||
while self._user_connected:
|
||||
try:
|
||||
self.idle(stop_signals=tuple())
|
||||
except Exception as error:
|
||||
__log__.exception('Unknown exception in the read thread! '
|
||||
'Disconnecting and leaving it to main thread')
|
||||
# Unknown exception, pass it to the main thread
|
||||
|
||||
try:
|
||||
import socks
|
||||
if isinstance(error, (
|
||||
socks.GeneralProxyError, socks.ProxyConnectionError
|
||||
)):
|
||||
# This is a known error, and it's not related to
|
||||
# Telegram but rather to the proxy. Disconnect and
|
||||
# hand it over to the main thread.
|
||||
self._background_error = error
|
||||
self.disconnect()
|
||||
break
|
||||
except ImportError:
|
||||
"Not using PySocks, so it can't be a proxy error"
|
||||
|
||||
self._recv_thread = None
|
||||
|
||||
# endregion
|
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1 @@
|
|||
from .tlobject import TLObject
|
||||
from .gzip_packed import GzipPacked
|
||||
from .tl_message import TLMessage
|
||||
from .message_container import MessageContainer
|
||||
from .tlobject import TLObject, TLRequest
|
||||
|
|
26
telethon/tl/core/__init__.py
Normal file
26
telethon/tl/core/__init__.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
"""
|
||||
This module holds core "special" types, which are more convenient ways
|
||||
to do stuff in a `telethon.network.mtprotosender.MTProtoSender` instance.
|
||||
|
||||
Only special cases are gzip-packed data, the response message (not a
|
||||
client message), the message container which references these messages
|
||||
and would otherwise conflict with the rest, and finally the RpcResult:
|
||||
|
||||
rpc_result#f35c6d01 req_msg_id:long result:bytes = RpcResult;
|
||||
|
||||
Three things to note with this definition:
|
||||
1. The constructor ID is actually ``42d36c2c``.
|
||||
2. Those bytes are not read like the rest of bytes (length + payload).
|
||||
They are actually the raw bytes of another object, which can't be
|
||||
read directly because it depends on per-request information (since
|
||||
some can return ``Vector<int>`` and ``Vector<long>``).
|
||||
3. Those bytes may be gzipped data, which needs to be treated early.
|
||||
"""
|
||||
from .tlmessage import TLMessage
|
||||
from .gzippacked import GzipPacked
|
||||
from .messagecontainer import MessageContainer
|
||||
from .rpcresult import RpcResult
|
||||
|
||||
core_objects = {x.CONSTRUCTOR_ID: x for x in (
|
||||
GzipPacked, MessageContainer, RpcResult
|
||||
)}
|
|
@ -1,7 +1,7 @@
|
|||
import gzip
|
||||
import struct
|
||||
|
||||
from . import TLObject
|
||||
from .. import TLObject, TLRequest
|
||||
|
||||
|
||||
class GzipPacked(TLObject):
|
||||
|
@ -21,7 +21,7 @@ class GzipPacked(TLObject):
|
|||
"""
|
||||
data = bytes(request)
|
||||
# TODO This threshold could be configurable
|
||||
if request.content_related and len(data) > 512:
|
||||
if isinstance(request, TLRequest) and len(data) > 512:
|
||||
gzipped = bytes(GzipPacked(data))
|
||||
return gzipped if len(gzipped) < len(data) else data
|
||||
else:
|
||||
|
@ -36,3 +36,7 @@ class GzipPacked(TLObject):
|
|||
def read(reader):
|
||||
assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID
|
||||
return gzip.decompress(reader.tgread_bytes())
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
return GzipPacked(gzip.decompress(reader.tgread_bytes()))
|
50
telethon/tl/core/messagecontainer.py
Normal file
50
telethon/tl/core/messagecontainer.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import logging
|
||||
import struct
|
||||
|
||||
from .tlmessage import TLMessage
|
||||
from ..tlobject import TLObject
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageContainer(TLObject):
|
||||
CONSTRUCTOR_ID = 0x73f1f8dc
|
||||
|
||||
def __init__(self, messages):
|
||||
self.messages = messages
|
||||
|
||||
def to_dict(self, recursive=True):
|
||||
return {
|
||||
'messages':
|
||||
([] if self.messages is None else [
|
||||
None if x is None else x.to_dict() for x in self.messages
|
||||
]) if recursive else self.messages,
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
return struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
|
||||
) + b''.join(bytes(m) for m in self.messages)
|
||||
|
||||
def __str__(self):
|
||||
return TLObject.pretty_format(self)
|
||||
|
||||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
# This assumes that .read_* calls are done in the order they appear
|
||||
messages = []
|
||||
for _ in range(reader.read_int()):
|
||||
msg_id = reader.read_long()
|
||||
seq_no = reader.read_int()
|
||||
length = reader.read_int()
|
||||
before = reader.tell_position()
|
||||
obj = reader.tgread_object()
|
||||
messages.append(TLMessage(msg_id, seq_no, obj))
|
||||
if reader.tell_position() != before + length:
|
||||
reader.set_position(before)
|
||||
__log__.warning('Data left after TLObject {}: {!r}'
|
||||
.format(obj, reader.read(length)))
|
||||
return MessageContainer(messages)
|
23
telethon/tl/core/rpcresult.py
Normal file
23
telethon/tl/core/rpcresult.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from .gzippacked import GzipPacked
|
||||
from ..types import RpcError
|
||||
|
||||
|
||||
class RpcResult:
|
||||
CONSTRUCTOR_ID = 0xf35c6d01
|
||||
|
||||
def __init__(self, req_msg_id, body, error):
|
||||
self.req_msg_id = req_msg_id
|
||||
self.body = body
|
||||
self.error = error
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
msg_id = reader.read_long()
|
||||
inner_code = reader.read_int(signed=False)
|
||||
if inner_code == RpcError.CONSTRUCTOR_ID:
|
||||
return RpcResult(msg_id, None, RpcError.from_reader(reader))
|
||||
if inner_code == GzipPacked.CONSTRUCTOR_ID:
|
||||
return RpcResult(msg_id, GzipPacked.from_reader(reader).data, None)
|
||||
|
||||
reader.seek(-4)
|
||||
return RpcResult(msg_id, reader.read(), None)
|
59
telethon/tl/core/tlmessage.py
Normal file
59
telethon/tl/core/tlmessage.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import asyncio
|
||||
import struct
|
||||
|
||||
from .gzippacked import GzipPacked
|
||||
from .. import TLObject
|
||||
from ..functions import InvokeAfterMsgRequest
|
||||
|
||||
|
||||
class TLMessage(TLObject):
|
||||
"""
|
||||
https://core.telegram.org/mtproto/service_messages#simple-container.
|
||||
|
||||
Messages are what's ultimately sent to Telegram:
|
||||
message msg_id:long seqno:int bytes:int body:bytes = Message;
|
||||
|
||||
Each message has its own unique identifier, and the body is simply
|
||||
the serialized request that should be executed on the server. Then
|
||||
Telegram will, at some point, respond with the result for this msg.
|
||||
|
||||
Thus it makes sense that requests and their result are bound to a
|
||||
sent `TLMessage`, and this result can be represented as a `Future`
|
||||
that will eventually be set with either a result, error or cancelled.
|
||||
"""
|
||||
def __init__(self, msg_id, seq_no, obj=None, after_id=0):
|
||||
super().__init__()
|
||||
self.msg_id = msg_id
|
||||
self.seq_no = seq_no
|
||||
self.obj = obj
|
||||
self.container_msg_id = None
|
||||
self.future = asyncio.Future()
|
||||
|
||||
# After which message ID this one should run. We do this so
|
||||
# InvokeAfterMsgRequest is transparent to the user and we can
|
||||
# easily invoke after while confirming the original request.
|
||||
self.after_id = after_id
|
||||
|
||||
def to_dict(self, recursive=True):
|
||||
return {
|
||||
'msg_id': self.msg_id,
|
||||
'seq_no': self.seq_no,
|
||||
'obj': self.obj,
|
||||
'container_msg_id': self.container_msg_id,
|
||||
'after_id': self.after_id
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
if self.after_id is None:
|
||||
body = GzipPacked.gzip_if_smaller(self.obj)
|
||||
else:
|
||||
body = GzipPacked.gzip_if_smaller(
|
||||
InvokeAfterMsgRequest(self.after_id, self.obj))
|
||||
|
||||
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
|
||||
|
||||
def __str__(self):
|
||||
return TLObject.pretty_format(self)
|
||||
|
||||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
|
@ -88,12 +88,13 @@ class Dialog:
|
|||
)
|
||||
self.is_channel = isinstance(self.entity, types.Channel)
|
||||
|
||||
def send_message(self, *args, **kwargs):
|
||||
async def send_message(self, *args, **kwargs):
|
||||
"""
|
||||
Sends a message to this dialog. This is just a wrapper around
|
||||
``client.send_message(dialog.input_entity, *args, **kwargs)``.
|
||||
"""
|
||||
return self._client.send_message(self.input_entity, *args, **kwargs)
|
||||
return await self._client.send_message(
|
||||
self.input_entity, *args, **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
|
@ -47,18 +47,18 @@ class Draft:
|
|||
return cls(client=client, peer=update.peer, draft=update.draft)
|
||||
|
||||
@property
|
||||
def entity(self):
|
||||
async def entity(self):
|
||||
"""
|
||||
The entity that belongs to this dialog (user, chat or channel).
|
||||
"""
|
||||
return self._client.get_entity(self._peer)
|
||||
return await self._client.get_entity(self._peer)
|
||||
|
||||
@property
|
||||
def input_entity(self):
|
||||
async def input_entity(self):
|
||||
"""
|
||||
Input version of the entity.
|
||||
"""
|
||||
return self._client.get_input_entity(self._peer)
|
||||
return await self._client.get_input_entity(self._peer)
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
|
@ -83,8 +83,9 @@ class Draft:
|
|||
"""
|
||||
return not self._text
|
||||
|
||||
def set_message(self, text=None, reply_to=0, parse_mode=Default,
|
||||
link_preview=None):
|
||||
async def set_message(
|
||||
self, text=None, reply_to=0, parse_mode=Default,
|
||||
link_preview=None):
|
||||
"""
|
||||
Changes the draft message on the Telegram servers. The changes are
|
||||
reflected in this object.
|
||||
|
@ -110,8 +111,10 @@ class Draft:
|
|||
if link_preview is None:
|
||||
link_preview = self.link_preview
|
||||
|
||||
raw_text, entities = self._client._parse_message_text(text, parse_mode)
|
||||
result = self._client(SaveDraftRequest(
|
||||
raw_text, entities =\
|
||||
await self._client._parse_message_text(text, parse_mode)
|
||||
|
||||
result = await self._client(SaveDraftRequest(
|
||||
peer=self._peer,
|
||||
message=raw_text,
|
||||
no_webpage=not link_preview,
|
||||
|
@ -128,22 +131,22 @@ class Draft:
|
|||
|
||||
return result
|
||||
|
||||
def send(self, clear=True, parse_mode=Default):
|
||||
async def send(self, clear=True, parse_mode=Default):
|
||||
"""
|
||||
Sends the contents of this draft to the dialog. This is just a
|
||||
wrapper around ``send_message(dialog.input_entity, *args, **kwargs)``.
|
||||
"""
|
||||
self._client.send_message(self._peer, self.text,
|
||||
reply_to=self.reply_to_msg_id,
|
||||
link_preview=self.link_preview,
|
||||
parse_mode=parse_mode,
|
||||
clear_draft=clear)
|
||||
await self._client.send_message(
|
||||
self._peer, self.text, reply_to=self.reply_to_msg_id,
|
||||
link_preview=self.link_preview, parse_mode=parse_mode,
|
||||
clear_draft=clear
|
||||
)
|
||||
|
||||
def delete(self):
|
||||
async def delete(self):
|
||||
"""
|
||||
Deletes this draft, and returns ``True`` on success.
|
||||
"""
|
||||
return self.set_message(text='')
|
||||
return await self.set_message(text='')
|
||||
|
||||
def to_dict(self):
|
||||
try:
|
||||
|
|
|
@ -134,14 +134,15 @@ class Message:
|
|||
if isinstance(self.original_message, types.MessageService):
|
||||
return self.original_message.action
|
||||
|
||||
def _reload_message(self):
|
||||
async def _reload_message(self):
|
||||
"""
|
||||
Re-fetches this message to reload the sender and chat entities,
|
||||
along with their input versions.
|
||||
"""
|
||||
try:
|
||||
chat = self.input_chat if self.is_channel else None
|
||||
msg = self._client.get_messages(chat, ids=self.original_message.id)
|
||||
chat = await self.input_chat if self.is_channel else None
|
||||
msg = await self._client.get_messages(
|
||||
chat, ids=self.original_message.id)
|
||||
except ValueError:
|
||||
return # We may not have the input chat/get message failed
|
||||
if not msg:
|
||||
|
@ -153,7 +154,7 @@ class Message:
|
|||
self._input_chat = msg._input_chat
|
||||
|
||||
@property
|
||||
def sender(self):
|
||||
async def sender(self):
|
||||
"""
|
||||
This (:tl:`User`) may make an API call the first time to get
|
||||
the most up to date version of the sender (mostly when the event
|
||||
|
@ -163,22 +164,24 @@ class Message:
|
|||
"""
|
||||
if self._sender is None:
|
||||
try:
|
||||
self._sender = self._client.get_entity(self.input_sender)
|
||||
self._sender =\
|
||||
await self._client.get_entity(await self.input_sender)
|
||||
except ValueError:
|
||||
self._reload_message()
|
||||
await self._reload_message()
|
||||
return self._sender
|
||||
|
||||
@property
|
||||
def chat(self):
|
||||
async def chat(self):
|
||||
if self._chat is None:
|
||||
try:
|
||||
self._chat = self._client.get_entity(self.input_chat)
|
||||
self._chat =\
|
||||
await self._client.get_entity(await self.input_chat)
|
||||
except ValueError:
|
||||
self._reload_message()
|
||||
await self._reload_message()
|
||||
return self._chat
|
||||
|
||||
@property
|
||||
def input_sender(self):
|
||||
async def input_sender(self):
|
||||
"""
|
||||
This (:tl:`InputPeer`) is the input version of the user who
|
||||
sent the message. Similarly to `input_chat`, this doesn't have
|
||||
|
@ -194,14 +197,14 @@ class Message:
|
|||
self._input_sender = get_input_peer(self._sender)
|
||||
else:
|
||||
try:
|
||||
self._input_sender = self._client.get_input_entity(
|
||||
self._input_sender = await self._client.get_input_entity(
|
||||
self.original_message.from_id)
|
||||
except ValueError:
|
||||
self._reload_message()
|
||||
await self._reload_message()
|
||||
return self._input_sender
|
||||
|
||||
@property
|
||||
def input_chat(self):
|
||||
async def input_chat(self):
|
||||
"""
|
||||
This (:tl:`InputPeer`) is the input version of the chat where the
|
||||
message was sent. Similarly to `input_sender`, this doesn't have
|
||||
|
@ -214,14 +217,14 @@ class Message:
|
|||
if self._input_chat is None:
|
||||
if self._chat is None:
|
||||
try:
|
||||
self._chat = self._client.get_input_entity(
|
||||
self._chat = await self._client.get_input_entity(
|
||||
self.original_message.to_id)
|
||||
except ValueError:
|
||||
# There's a chance that the chat is a recent new dialog.
|
||||
# The input chat cannot rely on ._reload_message() because
|
||||
# said method may need the input chat.
|
||||
target = self.chat_id
|
||||
for d in self._client.iter_dialogs(100):
|
||||
async for d in self._client.iter_dialogs(100):
|
||||
if d.id == target:
|
||||
self._chat = d.entity
|
||||
break
|
||||
|
@ -269,24 +272,26 @@ class Message:
|
|||
return bool(self.original_message.reply_to_msg_id)
|
||||
|
||||
@property
|
||||
def buttons(self):
|
||||
async def buttons(self):
|
||||
"""
|
||||
Returns a matrix (list of lists) containing all buttons of the message
|
||||
as `telethon.tl.custom.messagebutton.MessageButton` instances.
|
||||
"""
|
||||
if self._buttons is None and self.original_message.reply_markup:
|
||||
sender = await self.input_sender
|
||||
chat = await self.input_chat
|
||||
if isinstance(self.original_message.reply_markup, (
|
||||
types.ReplyInlineMarkup, types.ReplyKeyboardMarkup)):
|
||||
self._buttons = [[
|
||||
MessageButton(self._client, button, self.input_sender,
|
||||
self.input_chat, self.original_message.id)
|
||||
MessageButton(self._client, button, sender, chat,
|
||||
self.original_message.id)
|
||||
for button in row.buttons
|
||||
] for row in self.original_message.reply_markup.rows]
|
||||
self._buttons_flat = [x for row in self._buttons for x in row]
|
||||
return self._buttons
|
||||
|
||||
@property
|
||||
def button_count(self):
|
||||
async def button_count(self):
|
||||
"""
|
||||
Returns the total button count.
|
||||
"""
|
||||
|
@ -386,7 +391,7 @@ class Message:
|
|||
return self.original_message.out
|
||||
|
||||
@property
|
||||
def reply_message(self):
|
||||
async def reply_message(self):
|
||||
"""
|
||||
The `telethon.tl.custom.message.Message` that this message is replying
|
||||
to, or ``None``.
|
||||
|
@ -397,15 +402,15 @@ class Message:
|
|||
if self._reply_message is None:
|
||||
if not self.original_message.reply_to_msg_id:
|
||||
return None
|
||||
self._reply_message = self._client.get_messages(
|
||||
self.input_chat if self.is_channel else None,
|
||||
self._reply_message = await self._client.get_messages(
|
||||
await self.input_chat if self.is_channel else None,
|
||||
ids=self.original_message.reply_to_msg_id
|
||||
)
|
||||
|
||||
return self._reply_message
|
||||
|
||||
@property
|
||||
def fwd_from_entity(self):
|
||||
async def fwd_from_entity(self):
|
||||
"""
|
||||
If the :tl:`Message` is a forwarded message, returns the :tl:`User`
|
||||
or :tl:`Channel` who originally sent the message, or ``None``.
|
||||
|
@ -414,32 +419,33 @@ class Message:
|
|||
if getattr(self.original_message, 'fwd_from', None):
|
||||
fwd = self.original_message.fwd_from
|
||||
if fwd.from_id:
|
||||
self._fwd_from_entity = self._client.get_entity(
|
||||
fwd.from_id)
|
||||
self._fwd_from_entity =\
|
||||
await self._client.get_entity(fwd.from_id)
|
||||
elif fwd.channel_id:
|
||||
self._fwd_from_entity = self._client.get_entity(
|
||||
self._fwd_from_entity = await self._client.get_entity(
|
||||
get_peer_id(types.PeerChannel(fwd.channel_id)))
|
||||
return self._fwd_from_entity
|
||||
|
||||
def respond(self, *args, **kwargs):
|
||||
async def respond(self, *args, **kwargs):
|
||||
"""
|
||||
Responds to the message (not as a reply). Shorthand for
|
||||
`telethon.telegram_client.TelegramClient.send_message` with
|
||||
``entity`` already set.
|
||||
"""
|
||||
return self._client.send_message(self.input_chat, *args, **kwargs)
|
||||
return await self._client.send_message(
|
||||
await self.input_chat, *args, **kwargs)
|
||||
|
||||
def reply(self, *args, **kwargs):
|
||||
async def reply(self, *args, **kwargs):
|
||||
"""
|
||||
Replies to the message (as a reply). Shorthand for
|
||||
`telethon.telegram_client.TelegramClient.send_message` with
|
||||
both ``entity`` and ``reply_to`` already set.
|
||||
"""
|
||||
kwargs['reply_to'] = self.original_message.id
|
||||
return self._client.send_message(self.original_message.to_id,
|
||||
*args, **kwargs)
|
||||
return await self._client.send_message(
|
||||
await self.input_chat, *args, **kwargs)
|
||||
|
||||
def forward_to(self, *args, **kwargs):
|
||||
async def forward_to(self, *args, **kwargs):
|
||||
"""
|
||||
Forwards the message. Shorthand for
|
||||
`telethon.telegram_client.TelegramClient.forward_messages` with
|
||||
|
@ -450,10 +456,10 @@ class Message:
|
|||
`telethon.telegram_client.TelegramClient` instance directly.
|
||||
"""
|
||||
kwargs['messages'] = self.original_message.id
|
||||
kwargs['from_peer'] = self.input_chat
|
||||
return self._client.forward_messages(*args, **kwargs)
|
||||
kwargs['from_peer'] = await self.input_chat
|
||||
return await self._client.forward_messages(*args, **kwargs)
|
||||
|
||||
def edit(self, *args, **kwargs):
|
||||
async def edit(self, *args, **kwargs):
|
||||
"""
|
||||
Edits the message iff it's outgoing. Shorthand for
|
||||
`telethon.telegram_client.TelegramClient.edit_message` with
|
||||
|
@ -471,10 +477,10 @@ class Message:
|
|||
if self.original_message.to_id.user_id != me.user_id:
|
||||
return None
|
||||
|
||||
return self._client.edit_message(
|
||||
self.input_chat, self.original_message, *args, **kwargs)
|
||||
return await self._client.edit_message(
|
||||
await self.input_chat, self.original_message, *args, **kwargs)
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
async def delete(self, *args, **kwargs):
|
||||
"""
|
||||
Deletes the message. You're responsible for checking whether you
|
||||
have the permission to do so, or to except the error otherwise.
|
||||
|
@ -486,17 +492,17 @@ class Message:
|
|||
this `delete` method. Use a
|
||||
`telethon.telegram_client.TelegramClient` instance directly.
|
||||
"""
|
||||
return self._client.delete_messages(
|
||||
self.input_chat, [self.original_message], *args, **kwargs)
|
||||
return await self._client.delete_messages(
|
||||
await self.input_chat, [self.original_message], *args, **kwargs)
|
||||
|
||||
def download_media(self, *args, **kwargs):
|
||||
async def download_media(self, *args, **kwargs):
|
||||
"""
|
||||
Downloads the media contained in the message, if any.
|
||||
`telethon.telegram_client.TelegramClient.download_media` with
|
||||
the ``message`` already set.
|
||||
"""
|
||||
return self._client.download_media(self.original_message,
|
||||
*args, **kwargs)
|
||||
return await self._client.download_media(
|
||||
self.original_message, *args, **kwargs)
|
||||
|
||||
def get_entities_text(self, cls=None):
|
||||
"""
|
||||
|
@ -525,12 +531,10 @@ class Message:
|
|||
self.original_message.entities)
|
||||
return list(zip(self.original_message.entities, texts))
|
||||
|
||||
def click(self, i=None, j=None, *, text=None, filter=None):
|
||||
async def click(self, i=None, j=None, *, text=None, filter=None):
|
||||
"""
|
||||
Clicks the inline keyboard button of the message, if any.
|
||||
|
||||
If the message has a non-inline keyboard, clicking it will
|
||||
send the message, switch to inline, or open its URL.
|
||||
Calls `telethon.tl.custom.messagebutton.MessageButton.click`
|
||||
for the specified button.
|
||||
|
||||
Does nothing if the message has no buttons.
|
||||
|
||||
|
@ -571,32 +575,32 @@ class Message:
|
|||
if sum(int(x is not None) for x in (i, text, filter)) >= 2:
|
||||
raise ValueError('You can only set either of i, text or filter')
|
||||
|
||||
if not self.buttons:
|
||||
if not await self.buttons:
|
||||
return # Accessing the property sets self._buttons[_flat]
|
||||
|
||||
if text is not None:
|
||||
if callable(text):
|
||||
for button in self._buttons_flat:
|
||||
if text(button.text):
|
||||
return button.click()
|
||||
return await button.click()
|
||||
else:
|
||||
for button in self._buttons_flat:
|
||||
if button.text == text:
|
||||
return button.click()
|
||||
return await button.click()
|
||||
return
|
||||
|
||||
if filter is not None:
|
||||
for button in self._buttons_flat:
|
||||
if filter(button):
|
||||
return button.click()
|
||||
return await button.click()
|
||||
return
|
||||
|
||||
if i is None:
|
||||
i = 0
|
||||
if j is None:
|
||||
return self._buttons_flat[i].click()
|
||||
return await self._buttons_flat[i].click()
|
||||
else:
|
||||
return self._buttons[i][j].click()
|
||||
return await self._buttons[i][j].click()
|
||||
|
||||
|
||||
class _CustomMessage(Message, types.Message):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .. import types, functions
|
||||
from ...errors import BotTimeout
|
||||
import webbrowser
|
||||
|
||||
|
||||
|
@ -51,23 +52,37 @@ class MessageButton:
|
|||
if isinstance(self.button, types.KeyboardButtonUrl):
|
||||
return self.button.url
|
||||
|
||||
def click(self):
|
||||
async def click(self):
|
||||
"""
|
||||
Clicks the inline keyboard button of the message, if any.
|
||||
Emulates the behaviour of clicking this button.
|
||||
|
||||
If the message has a non-inline keyboard, clicking it will
|
||||
send the message, switch to inline, or open its URL.
|
||||
If it's a normal :tl:`KeyboardButton` with text, a message will be
|
||||
sent, and the sent `telethon.tl.custom.message.Message` returned.
|
||||
|
||||
If it's an inline :tl:`KeyboardButtonCallback` with text and data,
|
||||
it will be "clicked" and the :tl:`BotCallbackAnswer` returned.
|
||||
|
||||
If it's an inline :tl:`KeyboardButtonSwitchInline` button, the
|
||||
:tl:`StartBotRequest` will be invoked and the resulting updates
|
||||
returned.
|
||||
|
||||
If it's a :tl:`KeyboardButtonUrl`, the URL of the button will
|
||||
be passed to ``webbrowser.open`` and return ``True`` on success.
|
||||
"""
|
||||
if isinstance(self.button, types.KeyboardButton):
|
||||
return self._client.send_message(
|
||||
return await self._client.send_message(
|
||||
self._chat, self.button.text, reply_to=self._msg_id)
|
||||
elif isinstance(self.button, types.KeyboardButtonCallback):
|
||||
return self._client(functions.messages.GetBotCallbackAnswerRequest(
|
||||
req = functions.messages.GetBotCallbackAnswerRequest(
|
||||
peer=self._chat, msg_id=self._msg_id, data=self.button.data
|
||||
), retries=1)
|
||||
)
|
||||
try:
|
||||
return await self._client(req)
|
||||
except BotTimeout:
|
||||
return None
|
||||
elif isinstance(self.button, types.KeyboardButtonSwitchInline):
|
||||
return self._client(functions.messages.StartBotRequest(
|
||||
return await self._client(functions.messages.StartBotRequest(
|
||||
bot=self._from, peer=self._chat, start_param=self.button.query
|
||||
), retries=1)
|
||||
))
|
||||
elif isinstance(self.button, types.KeyboardButtonUrl):
|
||||
return webbrowser.open(self.button.url)
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
import struct
|
||||
|
||||
from . import TLObject
|
||||
|
||||
|
||||
class MessageContainer(TLObject):
|
||||
CONSTRUCTOR_ID = 0x73f1f8dc
|
||||
|
||||
def __init__(self, messages):
|
||||
super().__init__()
|
||||
self.content_related = False
|
||||
self.messages = messages
|
||||
|
||||
def to_dict(self, recursive=True):
|
||||
return {
|
||||
'content_related': self.content_related,
|
||||
'messages':
|
||||
([] if self.messages is None else [
|
||||
None if x is None else x.to_dict() for x in self.messages
|
||||
]) if recursive else self.messages,
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
return struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
|
||||
) + b''.join(bytes(m) for m in self.messages)
|
||||
|
||||
@staticmethod
|
||||
def iter_read(reader):
|
||||
reader.read_int(signed=False) # code
|
||||
size = reader.read_int()
|
||||
for _ in range(size):
|
||||
inner_msg_id = reader.read_long()
|
||||
inner_sequence = reader.read_int()
|
||||
inner_length = reader.read_int()
|
||||
yield inner_msg_id, inner_sequence, inner_length
|
||||
|
||||
def __str__(self):
|
||||
return TLObject.pretty_format(self)
|
||||
|
||||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
|
@ -1,44 +0,0 @@
|
|||
import struct
|
||||
|
||||
from . import TLObject, GzipPacked
|
||||
from ..tl.functions import InvokeAfterMsgRequest
|
||||
|
||||
|
||||
class TLMessage(TLObject):
|
||||
"""https://core.telegram.org/mtproto/service_messages#simple-container"""
|
||||
def __init__(self, session, request, after_id=None):
|
||||
super().__init__()
|
||||
del self.content_related
|
||||
self.msg_id = session.get_new_msg_id()
|
||||
self.seq_no = session.generate_sequence(request.content_related)
|
||||
self.request = request
|
||||
self.container_msg_id = None
|
||||
|
||||
# After which message ID this one should run. We do this so
|
||||
# InvokeAfterMsgRequest is transparent to the user and we can
|
||||
# easily invoke after while confirming the original request.
|
||||
self.after_id = after_id
|
||||
|
||||
def to_dict(self, recursive=True):
|
||||
return {
|
||||
'msg_id': self.msg_id,
|
||||
'seq_no': self.seq_no,
|
||||
'request': self.request,
|
||||
'container_msg_id': self.container_msg_id,
|
||||
'after_id': self.after_id
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
if self.after_id is None:
|
||||
body = GzipPacked.gzip_if_smaller(self.request)
|
||||
else:
|
||||
body = GzipPacked.gzip_if_smaller(
|
||||
InvokeAfterMsgRequest(self.after_id, self.request))
|
||||
|
||||
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
|
||||
|
||||
def __str__(self):
|
||||
return TLObject.pretty_format(self)
|
||||
|
||||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
|
@ -1,45 +1,13 @@
|
|||
import struct
|
||||
from datetime import datetime, date
|
||||
from threading import Event
|
||||
|
||||
|
||||
class TLObject:
|
||||
def __init__(self):
|
||||
self.rpc_error = None
|
||||
self.result = None
|
||||
|
||||
# These should be overrode
|
||||
self.content_related = False # Only requests/functions/queries are
|
||||
|
||||
# Internal parameter to tell pickler in which state Event object was
|
||||
self._event_is_set = False
|
||||
self._set_event()
|
||||
|
||||
def _set_event(self):
|
||||
self.confirm_received = Event()
|
||||
|
||||
# Set Event state to 'set' if needed
|
||||
if self._event_is_set:
|
||||
self.confirm_received.set()
|
||||
|
||||
def __getstate__(self):
|
||||
# Save state of the Event object
|
||||
self._event_is_set = self.confirm_received.is_set()
|
||||
|
||||
# Exclude Event object from dict and return new state
|
||||
new_dct = dict(self.__dict__)
|
||||
del new_dct["confirm_received"]
|
||||
return new_dct
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self._set_event()
|
||||
|
||||
# These should not be overrode
|
||||
@staticmethod
|
||||
def pretty_format(obj, indent=None):
|
||||
"""Pretty formats the given object as a string which is returned.
|
||||
If indent is None, a single line will be returned.
|
||||
"""
|
||||
Pretty formats the given object as a string which is returned.
|
||||
If indent is None, a single line will be returned.
|
||||
"""
|
||||
if indent is None:
|
||||
if isinstance(obj, TLObject):
|
||||
|
@ -163,10 +131,6 @@ class TLObject:
|
|||
|
||||
raise TypeError('Cannot interpret "{}" as a date.'.format(dt))
|
||||
|
||||
# These are nearly always the same for all subclasses
|
||||
def on_response(self, reader):
|
||||
self.result = reader.tgread_object()
|
||||
|
||||
def __eq__(self, o):
|
||||
return isinstance(o, type(self)) and self.to_dict() == o.to_dict()
|
||||
|
||||
|
@ -179,16 +143,24 @@ class TLObject:
|
|||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
||||
|
||||
# These should be overrode
|
||||
def resolve(self, client, utils):
|
||||
pass
|
||||
|
||||
def to_dict(self):
|
||||
return {}
|
||||
raise NotImplementedError
|
||||
|
||||
def __bytes__(self):
|
||||
return b''
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
return TLObject()
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TLRequest(TLObject):
|
||||
"""
|
||||
Represents a content-related `TLObject` (a request that can be sent).
|
||||
"""
|
||||
@staticmethod
|
||||
def read_result(reader):
|
||||
return reader.tgread_object()
|
||||
|
||||
async def resolve(self, client, utils):
|
||||
pass
|
||||
|
|
|
@ -17,17 +17,7 @@ class UpdateState:
|
|||
"""
|
||||
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
|
||||
|
||||
def __init__(self, workers=None):
|
||||
"""
|
||||
:param workers: This integer parameter has three possible cases:
|
||||
workers is None: Updates will *not* be stored on self.
|
||||
workers = 0: Another thread is responsible for calling self.poll()
|
||||
workers > 0: 'workers' background threads will be spawned, any
|
||||
any of them will invoke the self.handler.
|
||||
"""
|
||||
self._workers = workers
|
||||
self._worker_threads = []
|
||||
|
||||
def __init__(self):
|
||||
self.handler = None
|
||||
self._updates_lock = RLock()
|
||||
self._updates = Queue()
|
||||
|
@ -50,66 +40,6 @@ class UpdateState:
|
|||
except Empty:
|
||||
return None
|
||||
|
||||
def get_workers(self):
|
||||
return self._workers
|
||||
|
||||
def set_workers(self, n):
|
||||
"""Changes the number of workers running.
|
||||
If 'n is None', clears all pending updates from memory.
|
||||
"""
|
||||
if n is None:
|
||||
self.stop_workers()
|
||||
else:
|
||||
self._workers = n
|
||||
self.setup_workers()
|
||||
|
||||
workers = property(fget=get_workers, fset=set_workers)
|
||||
|
||||
def stop_workers(self):
|
||||
"""
|
||||
Waits for all the worker threads to stop.
|
||||
"""
|
||||
# Put dummy ``None`` objects so that they don't need to timeout.
|
||||
n = self._workers
|
||||
self._workers = None
|
||||
if n:
|
||||
with self._updates_lock:
|
||||
for _ in range(n):
|
||||
self._updates.put(None)
|
||||
|
||||
for t in self._worker_threads:
|
||||
t.join()
|
||||
|
||||
self._worker_threads.clear()
|
||||
self._workers = n
|
||||
|
||||
def setup_workers(self):
|
||||
if self._worker_threads or not self._workers:
|
||||
# There already are workers, or workers is None or 0. Do nothing.
|
||||
return
|
||||
|
||||
for i in range(self._workers):
|
||||
thread = Thread(
|
||||
target=UpdateState._worker_loop,
|
||||
name='UpdateWorker{}'.format(i),
|
||||
daemon=True,
|
||||
args=(self, i)
|
||||
)
|
||||
self._worker_threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
def _worker_loop(self, wid):
|
||||
while self._workers is not None:
|
||||
try:
|
||||
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
|
||||
if update and self.handler:
|
||||
self.handler(update)
|
||||
except StopIteration:
|
||||
break
|
||||
except:
|
||||
# We don't want to crash a worker thread due to any reason
|
||||
__log__.exception('Unhandled exception on worker %d', wid)
|
||||
|
||||
def get_update_state(self, entity_id):
|
||||
"""Gets the updates.State corresponding to the given entity or 0."""
|
||||
return self._state
|
||||
|
@ -118,35 +48,32 @@ class UpdateState:
|
|||
"""Processes an update object. This method is normally called by
|
||||
the library itself.
|
||||
"""
|
||||
if self._workers is None:
|
||||
return # No processing needs to be done if nobody's working
|
||||
if isinstance(update, tl.updates.State):
|
||||
__log__.debug('Saved new updates state')
|
||||
self._state = update
|
||||
return # Nothing else to be done
|
||||
|
||||
with self._updates_lock:
|
||||
if isinstance(update, tl.updates.State):
|
||||
__log__.debug('Saved new updates state')
|
||||
self._state = update
|
||||
return # Nothing else to be done
|
||||
if hasattr(update, 'pts'):
|
||||
self._state.pts = update.pts
|
||||
|
||||
if hasattr(update, 'pts'):
|
||||
self._state.pts = update.pts
|
||||
# After running the script for over an hour and receiving over
|
||||
# 1000 updates, the only duplicates received were users going
|
||||
# online or offline. We can trust the server until new reports.
|
||||
# This should only be used as read-only.
|
||||
if isinstance(update, tl.UpdateShort):
|
||||
update.update._entities = {}
|
||||
self._updates.put(update.update)
|
||||
|
||||
# After running the script for over an hour and receiving over
|
||||
# 1000 updates, the only duplicates received were users going
|
||||
# online or offline. We can trust the server until new reports.
|
||||
# This should only be used as read-only.
|
||||
if isinstance(update, tl.UpdateShort):
|
||||
update.update._entities = {}
|
||||
self._updates.put(update.update)
|
||||
# Expand "Updates" into "Update", and pass these to callbacks.
|
||||
# Since .users and .chats have already been processed, we
|
||||
# don't need to care about those either.
|
||||
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
|
||||
entities = {utils.get_peer_id(x): x for x in
|
||||
itertools.chain(update.users, update.chats)}
|
||||
for u in update.updates:
|
||||
u._entities = entities
|
||||
self._updates.put(u)
|
||||
# TODO Handle "tl.UpdatesTooLong"
|
||||
else:
|
||||
update._entities = {}
|
||||
self._updates.put(update)
|
||||
# Expand "Updates" into "Update", and pass these to callbacks.
|
||||
# Since .users and .chats have already been processed, we
|
||||
# don't need to care about those either.
|
||||
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
|
||||
entities = {utils.get_peer_id(x): x for x in
|
||||
itertools.chain(update.users, update.chats)}
|
||||
for u in update.updates:
|
||||
u._entities = entities
|
||||
self._updates.put(u)
|
||||
# TODO Handle "tl.UpdatesTooLong"
|
||||
else:
|
||||
update._entities = {}
|
||||
self._updates.put(update)
|
||||
|
|
|
@ -12,6 +12,7 @@ import types
|
|||
from collections import UserList
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from .extensions import markdown, html
|
||||
from .tl import TLObject
|
||||
from .tl.types import (
|
||||
Channel, ChannelForbidden, Chat, ChatEmpty, ChatForbidden, ChatFull,
|
||||
|
@ -402,6 +403,60 @@ def get_input_message(message):
|
|||
_raise_cast_fail(message, 'InputMedia')
|
||||
|
||||
|
||||
def get_message_id(message):
|
||||
"""Sanitizes the 'reply_to' parameter a user may send"""
|
||||
if message is None:
|
||||
return None
|
||||
|
||||
if isinstance(message, int):
|
||||
return message
|
||||
|
||||
if hasattr(message, 'original_message'):
|
||||
return message.original_message.id
|
||||
|
||||
try:
|
||||
if message.SUBCLASS_OF_ID == 0x790009e3:
|
||||
# hex(crc32(b'Message')) = 0x790009e3
|
||||
return message.id
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
raise TypeError('Invalid message type: {}'.format(type(message)))
|
||||
|
||||
|
||||
def sanitize_parse_mode(mode):
|
||||
"""
|
||||
Converts the given parse mode into an object with
|
||||
``parse`` and ``unparse`` callable properties.
|
||||
"""
|
||||
if not mode:
|
||||
return None
|
||||
|
||||
if callable(mode):
|
||||
class CustomMode:
|
||||
@staticmethod
|
||||
def unparse(text, entities):
|
||||
raise NotImplementedError
|
||||
|
||||
CustomMode.parse = mode
|
||||
return CustomMode
|
||||
elif (all(hasattr(mode, x) for x in ('parse', 'unparse'))
|
||||
and all(callable(x) for x in (mode.parse, mode.unparse))):
|
||||
return mode
|
||||
elif isinstance(mode, str):
|
||||
try:
|
||||
return {
|
||||
'md': markdown,
|
||||
'markdown': markdown,
|
||||
'htm': html,
|
||||
'html': html
|
||||
}[mode.lower()]
|
||||
except KeyError:
|
||||
raise ValueError('Unknown parse mode {}'.format(mode))
|
||||
else:
|
||||
raise TypeError('Invalid parse mode type {}'.format(mode))
|
||||
|
||||
|
||||
def get_input_location(location):
|
||||
"""Similar to :meth:`get_input_peer`, but for input messages."""
|
||||
try:
|
||||
|
|
|
@ -49,7 +49,7 @@ new_session_created#9ec20908 first_msg_id:long unique_id:long server_salt:long =
|
|||
//message msg_id:long seqno:int bytes:int body:bytes = Message;
|
||||
//msg_copy#e06046b2 orig_message:Message = MessageCopy;
|
||||
|
||||
gzip_packed#3072cfa1 packed_data:bytes = Object;
|
||||
//gzip_packed#3072cfa1 packed_data:bytes = Object;
|
||||
|
||||
msgs_ack#62d6b459 msg_ids:Vector<long> = MsgsAck;
|
||||
|
||||
|
|
|
@ -14,10 +14,15 @@ AUTO_GEN_NOTICE = \
|
|||
|
||||
|
||||
AUTO_CASTS = {
|
||||
'InputPeer': 'utils.get_input_peer(client.get_input_entity({}))',
|
||||
'InputChannel': 'utils.get_input_channel(client.get_input_entity({}))',
|
||||
'InputUser': 'utils.get_input_user(client.get_input_entity({}))',
|
||||
'InputDialogPeer': 'utils.get_input_dialog(client.get_input_entity({}))',
|
||||
'InputPeer':
|
||||
'utils.get_input_peer(await client.get_input_entity({}))',
|
||||
'InputChannel':
|
||||
'utils.get_input_channel(await client.get_input_entity({}))',
|
||||
'InputUser':
|
||||
'utils.get_input_user(await client.get_input_entity({}))',
|
||||
'InputDialogPeer':
|
||||
'utils.get_input_dialog(await client.get_input_entity({}))',
|
||||
|
||||
'InputMedia': 'utils.get_input_media({})',
|
||||
'InputPhoto': 'utils.get_input_photo({})',
|
||||
'InputMessage': 'utils.get_input_message({})'
|
||||
|
@ -27,7 +32,8 @@ BASE_TYPES = ('string', 'bytes', 'int', 'long', 'int128',
|
|||
'int256', 'double', 'Bool', 'true', 'date')
|
||||
|
||||
|
||||
def _write_modules(out_dir, depth, namespace_tlobjects, type_constructors):
|
||||
def _write_modules(
|
||||
out_dir, depth, kind, namespace_tlobjects, type_constructors):
|
||||
# namespace_tlobjects: {'namespace', [TLObject]}
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
for ns, tlobjects in namespace_tlobjects.items():
|
||||
|
@ -36,7 +42,7 @@ def _write_modules(out_dir, depth, namespace_tlobjects, type_constructors):
|
|||
SourceBuilder(f) as builder:
|
||||
builder.writeln(AUTO_GEN_NOTICE)
|
||||
|
||||
builder.writeln('from {}.tl.tlobject import TLObject', '.' * depth)
|
||||
builder.writeln('from {}.tl.tlobject import {}', '.' * depth, kind)
|
||||
builder.writeln('from typing import Optional, List, '
|
||||
'Union, TYPE_CHECKING')
|
||||
|
||||
|
@ -119,7 +125,7 @@ def _write_modules(out_dir, depth, namespace_tlobjects, type_constructors):
|
|||
|
||||
# Generate the class for every TLObject
|
||||
for t in tlobjects:
|
||||
_write_source_code(t, builder, type_constructors)
|
||||
_write_source_code(t, kind, builder, type_constructors)
|
||||
builder.current_indent = 0
|
||||
|
||||
# Write the type definitions generated earlier.
|
||||
|
@ -128,7 +134,7 @@ def _write_modules(out_dir, depth, namespace_tlobjects, type_constructors):
|
|||
builder.writeln(line)
|
||||
|
||||
|
||||
def _write_source_code(tlobject, builder, type_constructors):
|
||||
def _write_source_code(tlobject, kind, builder, type_constructors):
|
||||
"""
|
||||
Writes the source code corresponding to the given TLObject
|
||||
by making use of the ``builder`` `SourceBuilder`.
|
||||
|
@ -137,18 +143,18 @@ def _write_source_code(tlobject, builder, type_constructors):
|
|||
the ``Type: [Constructors]`` must be given for proper
|
||||
importing and documentation strings.
|
||||
"""
|
||||
_write_class_init(tlobject, type_constructors, builder)
|
||||
_write_class_init(tlobject, kind, type_constructors, builder)
|
||||
_write_resolve(tlobject, builder)
|
||||
_write_to_dict(tlobject, builder)
|
||||
_write_to_bytes(tlobject, builder)
|
||||
_write_from_reader(tlobject, builder)
|
||||
_write_on_response(tlobject, builder)
|
||||
_write_read_result(tlobject, builder)
|
||||
|
||||
|
||||
def _write_class_init(tlobject, type_constructors, builder):
|
||||
def _write_class_init(tlobject, kind, type_constructors, builder):
|
||||
builder.writeln()
|
||||
builder.writeln()
|
||||
builder.writeln('class {}(TLObject):', tlobject.class_name)
|
||||
builder.writeln('class {}({}):', tlobject.class_name, kind)
|
||||
|
||||
# Class-level variable to store its Telegram's constructor ID
|
||||
builder.writeln('CONSTRUCTOR_ID = {:#x}', tlobject.id)
|
||||
|
@ -160,46 +166,39 @@ def _write_class_init(tlobject, type_constructors, builder):
|
|||
args = [(a.name if not a.is_flag and not a.can_be_inferred
|
||||
else '{}=None'.format(a.name)) for a in tlobject.real_args]
|
||||
|
||||
# Write the __init__ function
|
||||
# Write the __init__ function if it has any argument
|
||||
if not tlobject.real_args:
|
||||
return
|
||||
|
||||
builder.writeln('def __init__({}):', ', '.join(['self'] + args))
|
||||
if tlobject.real_args:
|
||||
# Write the docstring, to know the type of the args
|
||||
builder.writeln('"""')
|
||||
for arg in tlobject.real_args:
|
||||
if not arg.flag_indicator:
|
||||
builder.writeln(':param {} {}:', arg.type_hint(), arg.name)
|
||||
builder.current_indent -= 1 # It will auto-indent (':')
|
||||
# Write the docstring, to know the type of the args
|
||||
builder.writeln('"""')
|
||||
for arg in tlobject.real_args:
|
||||
if not arg.flag_indicator:
|
||||
builder.writeln(':param {} {}:', arg.type_hint(), arg.name)
|
||||
builder.current_indent -= 1 # It will auto-indent (':')
|
||||
|
||||
# We also want to know what type this request returns
|
||||
# or to which type this constructor belongs to
|
||||
builder.writeln()
|
||||
if tlobject.is_function:
|
||||
builder.write(':returns {}: ', tlobject.result)
|
||||
else:
|
||||
builder.write('Constructor for {}: ', tlobject.result)
|
||||
|
||||
constructors = type_constructors[tlobject.result]
|
||||
if not constructors:
|
||||
builder.writeln('This type has no constructors.')
|
||||
elif len(constructors) == 1:
|
||||
builder.writeln('Instance of {}.',
|
||||
constructors[0].class_name)
|
||||
else:
|
||||
builder.writeln('Instance of either {}.', ', '.join(
|
||||
c.class_name for c in constructors))
|
||||
|
||||
builder.writeln('"""')
|
||||
|
||||
builder.writeln('super().__init__()')
|
||||
# Functions have a result object and are confirmed by default
|
||||
# We also want to know what type this request returns
|
||||
# or to which type this constructor belongs to
|
||||
builder.writeln()
|
||||
if tlobject.is_function:
|
||||
builder.writeln('self.result = None')
|
||||
builder.writeln('self.content_related = True')
|
||||
builder.write(':returns {}: ', tlobject.result)
|
||||
else:
|
||||
builder.write('Constructor for {}: ', tlobject.result)
|
||||
|
||||
constructors = type_constructors[tlobject.result]
|
||||
if not constructors:
|
||||
builder.writeln('This type has no constructors.')
|
||||
elif len(constructors) == 1:
|
||||
builder.writeln('Instance of {}.',
|
||||
constructors[0].class_name)
|
||||
else:
|
||||
builder.writeln('Instance of either {}.', ', '.join(
|
||||
c.class_name for c in constructors))
|
||||
|
||||
builder.writeln('"""')
|
||||
|
||||
# Set the arguments
|
||||
if tlobject.real_args:
|
||||
builder.writeln()
|
||||
|
||||
for arg in tlobject.real_args:
|
||||
if not arg.can_be_inferred:
|
||||
builder.writeln('self.{0} = {0} # type: {1}',
|
||||
|
@ -234,7 +233,7 @@ def _write_class_init(tlobject, type_constructors, builder):
|
|||
|
||||
def _write_resolve(tlobject, builder):
|
||||
if any(arg.type in AUTO_CASTS for arg in tlobject.real_args):
|
||||
builder.writeln('def resolve(self, client, utils):')
|
||||
builder.writeln('async def resolve(self, client, utils):')
|
||||
for arg in tlobject.real_args:
|
||||
ac = AUTO_CASTS.get(arg.type, None)
|
||||
if not ac:
|
||||
|
@ -333,7 +332,7 @@ def _write_from_reader(tlobject, builder):
|
|||
'{0}=_{0}'.format(a.name) for a in tlobject.real_args))
|
||||
|
||||
|
||||
def _write_on_response(tlobject, builder):
|
||||
def _write_read_result(tlobject, builder):
|
||||
# Only requests can have a different response that's not their
|
||||
# serialized body, that is, we'll be setting their .result.
|
||||
#
|
||||
|
@ -354,9 +353,10 @@ def _write_on_response(tlobject, builder):
|
|||
return
|
||||
|
||||
builder.end_block()
|
||||
builder.writeln('def on_response(self, reader):')
|
||||
builder.writeln('@staticmethod')
|
||||
builder.writeln('def read_result(reader):')
|
||||
builder.writeln('reader.read_int() # Vector ID')
|
||||
builder.writeln('self.result = [reader.read_{}() '
|
||||
builder.writeln('return [reader.read_{}() '
|
||||
'for _ in range(reader.read_int())]', m.group(1))
|
||||
|
||||
|
||||
|
@ -447,7 +447,7 @@ def _write_arg_to_bytes(builder, arg, args, name=None):
|
|||
builder.write("struct.pack('<d', {})", name)
|
||||
|
||||
elif 'string' == arg.type:
|
||||
builder.write('TLObject.serialize_bytes({})', name)
|
||||
builder.write('self.serialize_bytes({})', name)
|
||||
|
||||
elif 'Bool' == arg.type:
|
||||
# 0x997275b5 if boolean else 0xbc799737
|
||||
|
@ -457,10 +457,10 @@ def _write_arg_to_bytes(builder, arg, args, name=None):
|
|||
pass # These are actually NOT written! Only used for flags
|
||||
|
||||
elif 'bytes' == arg.type:
|
||||
builder.write('TLObject.serialize_bytes({})', name)
|
||||
builder.write('self.serialize_bytes({})', name)
|
||||
|
||||
elif 'date' == arg.type: # Custom format
|
||||
builder.write('TLObject.serialize_datetime({})', name)
|
||||
builder.write('self.serialize_datetime({})', name)
|
||||
|
||||
else:
|
||||
# Else it may be a custom type
|
||||
|
@ -645,9 +645,9 @@ def generate_tlobjects(tlobjects, layer, import_depth, output_dir):
|
|||
namespace_types[tlobject.namespace].append(tlobject)
|
||||
type_constructors[tlobject.result].append(tlobject)
|
||||
|
||||
_write_modules(get_file('functions'), import_depth,
|
||||
_write_modules(get_file('functions'), import_depth, 'TLRequest',
|
||||
namespace_functions, type_constructors)
|
||||
_write_modules(get_file('types'), import_depth,
|
||||
_write_modules(get_file('types'), import_depth, 'TLObject',
|
||||
namespace_types, type_constructors)
|
||||
|
||||
filename = os.path.join(get_file('all_tlobjects.py'))
|
||||
|
|
Loading…
Reference in New Issue
Block a user