diff --git a/telethon/extensions/binary_reader.py b/telethon/extensions/binary_reader.py index 43232b0b..2355c6a4 100644 --- a/telethon/extensions/binary_reader.py +++ b/telethon/extensions/binary_reader.py @@ -129,13 +129,10 @@ class BinaryReader: return False # If there was still no luck, give up + self.seek(-4) # Go back raise TypeNotFoundError(constructor_id) - # Create an empty instance of the class and - # fill it with the read attributes - result = clazz.empty() - result.on_response(self) - return result + return clazz.from_reader(self) def tgread_vector(self): """Reads a vector (a list) of Telegram objects""" diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 672c0a0f..e847873f 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -15,7 +15,7 @@ class TcpClient: if isinstance(timeout, timedelta): self.timeout = timeout.seconds - elif isinstance(timeout, int) or isinstance(timeout, float): + elif isinstance(timeout, (int, float)): self.timeout = float(timeout) else: raise ValueError('Invalid timeout type', type(timeout)) diff --git a/telethon/network/authenticator.py b/telethon/network/authenticator.py index f46f5430..766a4c00 100644 --- a/telethon/network/authenticator.py +++ b/telethon/network/authenticator.py @@ -124,7 +124,6 @@ async def _do_authentication(connection): raise AssertionError(server_dh_inner) if server_dh_inner.nonce != res_pq.nonce: - print(server_dh_inner.nonce, res_pq.nonce) raise SecurityError('Invalid nonce in encrypted answer') if server_dh_inner.server_nonce != res_pq.server_nonce: diff --git a/telethon/network/connection.py b/telethon/network/connection.py index 77a3c87b..270b9451 100644 --- a/telethon/network/connection.py +++ b/telethon/network/connection.py @@ -143,28 +143,25 @@ class Connection: # TODO We don't want another call to this method that could # potentially await on another self.read(n). Is this guaranteed # by asyncio? - packet_length_bytes = await self.read(4) - packet_length = int.from_bytes(packet_length_bytes, 'little') + packet_len_seq = await self.read(8) # 4 and 4 + packet_len, seq = struct.unpack('= 127: - length = int.from_bytes(self.read(3) + b'\0', 'little') + length = struct.unpack(' self.session.flood_sleep_threshold | 0: + raise + + self._logger.debug( + 'Sleep of %d seconds below threshold, sleeping' % e.seconds + ) + sleep(e.seconds) # Some really basic functionality @@ -609,10 +614,8 @@ class TelegramBareClient: cdn_decrypter = None try: - offset_index = 0 + offset = 0 while True: - offset = offset_index * part_size - try: if cdn_decrypter: result = await cdn_decrypter.get_file() @@ -633,7 +636,7 @@ class TelegramBareClient: client = await self._get_exported_client(e.new_dc) continue - offset_index += 1 + offset += part_size # If we have received no data (0 bytes), the file is over # So there is nothing left to download and write @@ -670,6 +673,9 @@ class TelegramBareClient: def add_update_handler(self, handler): """Adds an update handler (a function which takes a TLObject, an update, as its parameter) and listens for updates""" + if not self.updates.get_workers: + warnings.warn("There are no update workers running, so adding an update handler will have no effect.") + sync = not self.updates.handlers self.updates.handlers.append(handler) if sync: diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index af89dc50..61c730d8 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -15,6 +15,7 @@ from .errors import ( ) from .network import ConnectionMode from .tl import TLObject +from .tl.custom import Draft from .tl.entity_database import EntityDatabase from .tl.functions.account import ( GetPasswordRequest @@ -28,8 +29,8 @@ from .tl.functions.contacts import ( ) from .tl.functions.messages import ( GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest, - SendMessageRequest, GetChatsRequest -) + SendMessageRequest, GetChatsRequest, + GetAllDraftsRequest) from .tl.functions import channels from .tl.functions import messages @@ -46,7 +47,7 @@ from .tl.types import ( InputMediaUploadedDocument, InputMediaUploadedPhoto, InputPeerEmpty, Message, MessageMediaContact, MessageMediaDocument, MessageMediaPhoto, InputUserSelf, UserProfilePhoto, ChatPhoto, UpdateMessageID, - UpdateNewMessage, UpdateShortSentMessage, + UpdateNewChannelMessage, UpdateNewMessage, UpdateShortSentMessage, PeerUser, InputPeerUser, InputPeerChat, InputPeerChannel) from .tl.types.messages import DialogsSlice @@ -227,7 +228,7 @@ class TelegramClient(TelegramBareClient): if limit is None: limit = float('inf') - dialogs = {} # Use Dialog.top_message as identifier to avoid dupes + dialogs = {} # Use peer id as identifier to avoid dupes messages = {} # Used later for sorting TODO also return these? entities = {} while len(dialogs) < limit: @@ -242,7 +243,7 @@ class TelegramClient(TelegramBareClient): break for d in r.dialogs: - dialogs[d.top_message] = d + dialogs[utils.get_peer_id(d.peer, True)] = d for m in r.messages: messages[m.id] = m @@ -277,9 +278,20 @@ class TelegramClient(TelegramBareClient): [utils.find_user_or_chat(d.peer, entities, entities) for d in ds] ) - # endregion + async def get_drafts(self): # TODO: Ability to provide a `filter` + """ + Gets all open draft messages. - # region Message requests + Returns a list of custom `Draft` objects that are easy to work with: You can call + `draft.set_message('text')` to change the message, or delete it through `draft.delete()`. + + :return List[telethon.tl.custom.Draft]: A list of open drafts + """ + response = await self(GetAllDraftsRequest()) + self.session.process_entities(response) + self.session.generate_sequence(response.seq) + drafts = [Draft._from_update(self, u) for u in response.updates] + return drafts async def send_message(self, entity, @@ -322,7 +334,7 @@ class TelegramClient(TelegramBareClient): break for update in result.updates: - if isinstance(update, UpdateNewMessage): + if isinstance(update, (UpdateNewChannelMessage, UpdateNewMessage)): if update.message.id == msg_id: return update.message @@ -463,9 +475,13 @@ class TelegramClient(TelegramBareClient): async def send_file(self, entity, file, caption='', force_document=False, progress_callback=None, reply_to=None, + attributes=None, **kwargs): """Sends a file to the specified entity. The file may either be a path, a byte array, or a stream. + Note that if a byte array or a stream is given, a filename + or its type won't be inferred, and it will be sent as an + "unnamed application/octet-stream". An optional caption can also be specified for said file. @@ -482,6 +498,10 @@ class TelegramClient(TelegramBareClient): The "reply_to" parameter works exactly as the one on .send_message. + If "attributes" is set to be a list of DocumentAttribute's, these + will override the automatically inferred ones (so that you can + modify the file name of the file sent for instance). + If "is_voice_note" in kwargs, despite its value, and the file is sent as a document, it will be sent as a voice note. @@ -512,16 +532,28 @@ class TelegramClient(TelegramBareClient): # Determine mime-type and attributes # Take the first element by using [0] since it returns a tuple mime_type = guess_type(file)[0] - attributes = [ + attr_dict = { + DocumentAttributeFilename: DocumentAttributeFilename(os.path.basename(file)) # TODO If the input file is an audio, find out: # Performer and song title and add DocumentAttributeAudio - ] + } else: - attributes = [DocumentAttributeFilename('unnamed')] + attr_dict = { + DocumentAttributeFilename: + DocumentAttributeFilename('unnamed') + } if 'is_voice_note' in kwargs: - attributes.append(DocumentAttributeAudio(0, voice=True)) + attr_dict[DocumentAttributeAudio] = \ + DocumentAttributeAudio(0, voice=True) + + # Now override the attributes if any. As we have a dict of + # {cls: instance}, we can override any class with the list + # of attributes provided by the user easily. + if attributes: + for a in attributes: + attr_dict[type(a)] = a # Ensure we have a mime type, any; but it cannot be None # 'The "octet-stream" subtype is used to indicate that a body @@ -532,7 +564,7 @@ class TelegramClient(TelegramBareClient): media = InputMediaUploadedDocument( file=file_handle, mime_type=mime_type, - attributes=attributes, + attributes=list(attr_dict.values()), caption=caption ) @@ -852,19 +884,19 @@ class TelegramClient(TelegramBareClient): # crc32(b'InputPeer') and crc32(b'Peer') type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)): ie = await self.get_input_entity(entity) - result = None if isinstance(ie, InputPeerUser): - result = await self(GetUsersRequest([ie])) + await self(GetUsersRequest([ie])) elif isinstance(ie, InputPeerChat): - result = await self(GetChatsRequest([ie.chat_id])) + await self(GetChatsRequest([ie.chat_id])) elif isinstance(ie, InputPeerChannel): - result = await self(GetChannelsRequest([ie])) - if result: - self.session.process_entities(result) - try: - return self.session.entities[ie] - except KeyError: - pass + await self(GetChannelsRequest([ie])) + try: + # session.process_entities has been called in the MtProtoSender + # with the result of these calls, so they should now be on the + # entities database. + return self.session.entities[ie] + except KeyError: + pass if isinstance(entity, str): return await self._get_entity_from_string(entity) @@ -880,10 +912,11 @@ class TelegramClient(TelegramBareClient): phone = EntityDatabase.parse_phone(string) if phone: entity = phone - self.session.process_entities(await self(GetContactsRequest(0))) + await self(GetContactsRequest(0)) else: entity = string.strip('@').lower() - self.session.process_entities(await self(ResolveUsernameRequest(entity))) + await self(ResolveUsernameRequest(entity)) + # MtProtoSender will call .process_entities on the requests made try: return self.session.entities[entity] @@ -930,9 +963,17 @@ class TelegramClient(TelegramBareClient): ) if self.session.save_entities: - # Not found, look in the dialogs (this will save the users) - await self.get_dialogs(limit=None) - + # Not found, look in the latest dialogs. + # This is useful if for instance someone just sent a message but + # the updates didn't specify who, as this person or chat should + # be in the latest dialogs. + await self(GetDialogsRequest( + offset_date=None, + offset_id=0, + offset_peer=InputPeerEmpty(), + limit=0, + exclude_pinned=True + )) try: return self.session.entities.get_input_entity(peer) except KeyError: diff --git a/telethon/tl/custom/__init__.py b/telethon/tl/custom/__init__.py new file mode 100644 index 00000000..40914f16 --- /dev/null +++ b/telethon/tl/custom/__init__.py @@ -0,0 +1 @@ +from .draft import Draft diff --git a/telethon/tl/custom/draft.py b/telethon/tl/custom/draft.py new file mode 100644 index 00000000..c50baa78 --- /dev/null +++ b/telethon/tl/custom/draft.py @@ -0,0 +1,80 @@ +from ..functions.messages import SaveDraftRequest +from ..types import UpdateDraftMessage + + +class Draft: + """ + Custom class that encapsulates a draft on the Telegram servers, providing + an abstraction to change the message conveniently. The library will return + instances of this class when calling `client.get_drafts()`. + """ + def __init__(self, client, peer, draft): + self._client = client + self._peer = peer + + self.text = draft.message + self.date = draft.date + self.no_webpage = draft.no_webpage + self.reply_to_msg_id = draft.reply_to_msg_id + self.entities = draft.entities + + @classmethod + def _from_update(cls, client, update): + if not isinstance(update, UpdateDraftMessage): + raise ValueError( + 'You can only create a new `Draft` from a corresponding ' + '`UpdateDraftMessage` object.' + ) + + return cls(client=client, peer=update.peer, draft=update.draft) + + @property + def entity(self): + return self._client.get_entity(self._peer) + + @property + def input_entity(self): + return self._client.get_input_entity(self._peer) + + def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None): + """ + Changes the draft message on the Telegram servers. The changes are + reflected in this object. Changing only individual attributes like for + example the `reply_to_msg_id` should be done by providing the current + values of this object, like so: + + draft.set_message( + draft.text, + no_webpage=draft.no_webpage, + reply_to_msg_id=NEW_VALUE, + entities=draft.entities + ) + + :param str text: New text of the draft + :param bool no_webpage: Whether to attach a web page preview + :param int reply_to_msg_id: Message id to reply to + :param list entities: A list of formatting entities + :return bool: `True` on success + """ + result = self._client(SaveDraftRequest( + peer=self._peer, + message=text, + no_webpage=no_webpage, + reply_to_msg_id=reply_to_msg_id, + entities=entities + )) + + if result: + self.text = text + self.no_webpage = no_webpage + self.reply_to_msg_id = reply_to_msg_id + self.entities = entities + + return result + + def delete(self): + """ + Deletes this draft + :return bool: `True` on success + """ + return self.set_message(text='') diff --git a/telethon/tl/entity_database.py b/telethon/tl/entity_database.py index da105616..a932d0da 100644 --- a/telethon/tl/entity_database.py +++ b/telethon/tl/entity_database.py @@ -70,9 +70,7 @@ class EntityDatabase: getattr(p, 'access_hash', 0) # chats won't have hash if self.enabled_full: - if isinstance(e, User) \ - or isinstance(e, Chat) \ - or isinstance(e, Channel): + if isinstance(e, (User, Chat, Channel)): new.append(e) except ValueError: pass @@ -118,47 +116,64 @@ class EntityDatabase: if phone: self._username_id[phone] = marked_id - def __getitem__(self, key): - """Accepts a digit only string as phone number, - otherwise it's treated as an username. + def _parse_key(self, key): + """Parses the given string, integer or TLObject key into a + marked user ID ready for use on self._entities. - If an integer is given, it's treated as the ID of the desired User. - The ID given won't try to be guessed as the ID of a chat or channel, - as there may be an user with that ID, and it would be unreliable. + If a callable key is given, the entity will be passed to the + function, and if it returns a true-like value, the marked ID + for such entity will be returned. - If a Peer is given (PeerUser, PeerChat, PeerChannel), - its specific entity is retrieved as User, Chat or Channel. - Note that megagroups are channels with .megagroup = True. + Raises ValueError if it cannot be parsed. """ if isinstance(key, str): phone = EntityDatabase.parse_phone(key) - if phone: - return self._phone_id[phone] - else: - key = key.lstrip('@').lower() - return self._entities[self._username_id[key]] + try: + if phone: + return self._phone_id[phone] + else: + return self._username_id[key.lstrip('@').lower()] + except KeyError as e: + raise ValueError() from e if isinstance(key, int): - return self._entities[key] # normal IDs are assumed users + return key # normal IDs are assumed users if isinstance(key, TLObject): - sc = type(key).SUBCLASS_OF_ID - if sc == 0x2d45687: - # Subclass of "Peer" - return self._entities[utils.get_peer_id(key, add_mark=True)] - elif sc in {0x2da17977, 0xc5af5d94, 0x6d44b7db}: - # Subclass of "User", "Chat" or "Channel" - return key + return utils.get_peer_id(key, add_mark=True) - raise KeyError(key) + if callable(key): + for k, v in self._entities.items(): + if key(v): + return k + + raise ValueError() + + def __getitem__(self, key): + """See the ._parse_key() docstring for possible values of the key""" + try: + return self._entities[self._parse_key(key)] + except (ValueError, KeyError) as e: + raise KeyError(key) from e def __delitem__(self, key): - target = self[key] - del self._entities[key] - if getattr(target, 'username'): - del self._username_id[target.username] + try: + old = self._entities.pop(self._parse_key(key)) + # Try removing the username and phone (if pop didn't fail), + # since the entity may have no username or phone, just ignore + # errors. It should be there if we popped the entity correctly. + try: + del self._username_id[getattr(old, 'username', None)] + except KeyError: + pass - # TODO Allow search by name by tokenizing the input and return a list + try: + del self._phone_id[getattr(old, 'phone', None)] + except KeyError: + pass + + except (ValueError, KeyError) as e: + raise KeyError(key) from e @staticmethod def parse_phone(phone): @@ -172,8 +187,10 @@ class EntityDatabase: def get_input_entity(self, peer): try: - i, k = utils.get_peer_id(peer, add_mark=True, get_kind=True) - h = self._input_entities[i] + i = utils.get_peer_id(peer, add_mark=True) + h = self._input_entities[i] # we store the IDs marked + i, k = utils.resolve_id(i) # removes the mark and returns kind + if k == PeerUser: return InputPeerUser(i, h) elif k == PeerChat: diff --git a/telethon/tl/gzip_packed.py b/telethon/tl/gzip_packed.py index 05453d4b..a7d09188 100644 --- a/telethon/tl/gzip_packed.py +++ b/telethon/tl/gzip_packed.py @@ -34,5 +34,5 @@ class GzipPacked(TLObject): @staticmethod def read(reader): - reader.read_int(signed=False) # code + assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID return gzip.decompress(reader.tgread_bytes()) diff --git a/telethon/tl/session.py b/telethon/tl/session.py index b722144e..48c38211 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -36,6 +36,7 @@ class Session: self.lang_pack = session.lang_pack self.report_errors = session.report_errors self.save_entities = session.save_entities + self.flood_sleep_threshold = session.flood_sleep_threshold else: # str / None self.session_user_id = session_user_id @@ -49,6 +50,7 @@ class Session: self.lang_pack = '' self.report_errors = True self.save_entities = True + self.flood_sleep_threshold = 60 self.id = helpers.generate_random_long(signed=False) self._sequence = 0 diff --git a/telethon/tl/tlobject.py b/telethon/tl/tlobject.py index 1866ba68..8deb59e3 100644 --- a/telethon/tl/tlobject.py +++ b/telethon/tl/tlobject.py @@ -1,3 +1,4 @@ +from datetime import datetime from threading import Event @@ -19,18 +20,14 @@ class TLObject: """ if indent is None: if isinstance(obj, TLObject): - children = obj.to_dict(recursive=False) - if children: - return '{}: {}'.format( - type(obj).__name__, TLObject.pretty_format(children) - ) - else: - return type(obj).__name__ + return '{}({})'.format(type(obj).__name__, ', '.join( + '{}={}'.format(k, TLObject.pretty_format(v)) + for k, v in obj.to_dict(recursive=False).items() + )) if isinstance(obj, dict): return '{{{}}}'.format(', '.join( - '{}: {}'.format( - k, TLObject.pretty_format(v) - ) for k, v in obj.items() + '{}: {}'.format(k, TLObject.pretty_format(v)) + for k, v in obj.items() )) elif isinstance(obj, str) or isinstance(obj, bytes): return repr(obj) @@ -38,31 +35,36 @@ class TLObject: return '[{}]'.format( ', '.join(TLObject.pretty_format(x) for x in obj) ) + elif isinstance(obj, datetime): + return 'datetime.fromtimestamp({})'.format(obj.timestamp()) else: - return str(obj) + return repr(obj) else: result = [] - if isinstance(obj, TLObject): - result.append(type(obj).__name__) - children = obj.to_dict(recursive=False) - if children: - result.append(': ') - result.append(TLObject.pretty_format( - obj.to_dict(recursive=False), indent - )) + if isinstance(obj, TLObject) or isinstance(obj, dict): + if isinstance(obj, dict): + d = obj + start, end, sep = '{', '}', ': ' + else: + d = obj.to_dict(recursive=False) + start, end, sep = '(', ')', '=' + result.append(type(obj).__name__) - elif isinstance(obj, dict): - result.append('{\n') - indent += 1 - for k, v in obj.items(): + result.append(start) + if d: + result.append('\n') + indent += 1 + for k, v in d.items(): + result.append('\t' * indent) + result.append(k) + result.append(sep) + result.append(TLObject.pretty_format(v, indent)) + result.append(',\n') + result.pop() # last ',\n' + indent -= 1 + result.append('\n') result.append('\t' * indent) - result.append(k) - result.append(': ') - result.append(TLObject.pretty_format(v, indent)) - result.append(',\n') - indent -= 1 - result.append('\t' * indent) - result.append('}') + result.append(end) elif isinstance(obj, str) or isinstance(obj, bytes): result.append(repr(obj)) @@ -78,8 +80,13 @@ class TLObject: result.append('\t' * indent) result.append(']') + elif isinstance(obj, datetime): + result.append('datetime.fromtimestamp(') + result.append(repr(obj.timestamp())) + result.append(')') + else: - result.append(str(obj)) + result.append(repr(obj)) return ''.join(result) @@ -121,5 +128,6 @@ class TLObject: def to_bytes(self): return b'' - def on_response(self, reader): - pass + @staticmethod + def from_reader(reader): + return TLObject() diff --git a/telethon/update_state.py b/telethon/update_state.py index 995e3eb2..8dd2ffad 100644 --- a/telethon/update_state.py +++ b/telethon/update_state.py @@ -1,4 +1,5 @@ import logging +import pickle from collections import deque from datetime import datetime from threading import RLock, Event, Thread @@ -27,6 +28,7 @@ class UpdateState: self._updates_lock = RLock() self._updates_available = Event() self._updates = deque() + self._latest_updates = deque(maxlen=10) self._logger = logging.getLogger(__name__) @@ -141,6 +143,26 @@ class UpdateState: self._state.pts = pts + # TODO There must be a better way to handle updates rather than + # keeping a queue with the latest updates only, and handling + # the 'pts' correctly should be enough. However some updates + # like UpdateUserStatus (even inside UpdateShort) will be called + # repeatedly very often if invoking anything inside an update + # handler. TODO Figure out why. + """ + client = TelegramClient('anon', api_id, api_hash, update_workers=1) + client.connect() + def handle(u): + client.get_me() + client.add_update_handler(handle) + input('Enter to exit.') + """ + data = pickle.dumps(update.to_dict()) + if data in self._latest_updates: + return # Duplicated too + + self._latest_updates.append(data) + if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates') # Expand "Updates" into "Update", and pass these to callbacks. # Since .users and .chats have already been processed, we @@ -149,8 +171,7 @@ class UpdateState: self._updates.append(update.update) self._updates_available.set() - elif isinstance(update, tl.Updates) or \ - isinstance(update, tl.UpdatesCombined): + elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): self._updates.extend(update.updates) self._updates_available.set() diff --git a/telethon/utils.py b/telethon/utils.py index c4c3182c..d8bfb89f 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -20,8 +20,8 @@ from .tl.types import ( GeoPointEmpty, InputGeoPointEmpty, Photo, InputPhoto, PhotoEmpty, InputPhotoEmpty, FileLocation, ChatPhotoEmpty, UserProfilePhotoEmpty, FileLocationUnavailable, InputMediaUploadedDocument, - InputMediaUploadedPhoto, - DocumentAttributeFilename) + InputMediaUploadedPhoto, DocumentAttributeFilename, photos +) def get_display_name(entity): @@ -37,7 +37,7 @@ def get_display_name(entity): else: return '(No name)' - if isinstance(entity, Chat) or isinstance(entity, Channel): + if isinstance(entity, (Chat, Channel)): return entity.title return '(unknown)' @@ -50,8 +50,7 @@ def get_extension(media): """Gets the corresponding extension for any Telegram media""" # Photos are always compressed as .jpg by Telegram - if (isinstance(media, UserProfilePhoto) or isinstance(media, ChatPhoto) or - isinstance(media, MessageMediaPhoto)): + if isinstance(media, (UserProfilePhoto, ChatPhoto, MessageMediaPhoto)): return '.jpg' # Documents will come with a mime type @@ -87,12 +86,10 @@ def get_input_peer(entity, allow_self=True): else: return InputPeerUser(entity.id, entity.access_hash) - if any(isinstance(entity, c) for c in ( - Chat, ChatEmpty, ChatForbidden)): + if isinstance(entity, (Chat, ChatEmpty, ChatForbidden)): return InputPeerChat(entity.id) - if any(isinstance(entity, c) for c in ( - Channel, ChannelForbidden)): + if isinstance(entity, (Channel, ChannelForbidden)): return InputPeerChannel(entity.id, entity.access_hash) # Less common cases @@ -122,7 +119,7 @@ def get_input_channel(entity): if type(entity).SUBCLASS_OF_ID == 0x40f202fd: # crc32(b'InputChannel') return entity - if isinstance(entity, Channel) or isinstance(entity, ChannelForbidden): + if isinstance(entity, (Channel, ChannelForbidden)): return InputChannel(entity.id, entity.access_hash) if isinstance(entity, InputPeerChannel): @@ -188,6 +185,9 @@ def get_input_photo(photo): if type(photo).SUBCLASS_OF_ID == 0x846363e0: # crc32(b'InputPhoto') return photo + if isinstance(photo, photos.Photo): + photo = photo.photo + if isinstance(photo, Photo): return InputPhoto(id=photo.id, access_hash=photo.access_hash) @@ -263,7 +263,7 @@ def get_input_media(media, user_caption=None, is_photo=False): if isinstance(media, MessageMediaGame): return InputMediaGame(id=media.game.id) - if isinstance(media, ChatPhoto) or isinstance(media, UserProfilePhoto): + if isinstance(media, (ChatPhoto, UserProfilePhoto)): if isinstance(media.photo_big, FileLocationUnavailable): return get_input_media(media.photo_small, is_photo=True) else: @@ -288,10 +288,9 @@ def get_input_media(media, user_caption=None, is_photo=False): venue_id=media.venue_id ) - if any(isinstance(media, t) for t in ( + if isinstance(media, ( MessageMediaEmpty, MessageMediaUnsupported, - FileLocationUnavailable, ChatPhotoEmpty, - UserProfilePhotoEmpty)): + ChatPhotoEmpty, UserProfilePhotoEmpty, FileLocationUnavailable)): return InputMediaEmpty() if isinstance(media, Message): @@ -300,16 +299,14 @@ def get_input_media(media, user_caption=None, is_photo=False): _raise_cast_fail(media, 'InputMedia') -def get_peer_id(peer, add_mark=False, get_kind=False): +def get_peer_id(peer, add_mark=False): """Finds the ID of the given peer, and optionally converts it to the "bot api" format if 'add_mark' is set to True. - - If 'get_kind', the kind will be returned as a second value. """ # First we assert it's a Peer TLObject, or early return for integers if not isinstance(peer, TLObject): if isinstance(peer, int): - return (peer, PeerUser) if get_kind else peer + return peer else: _raise_cast_fail(peer, 'int') @@ -318,25 +315,20 @@ def get_peer_id(peer, add_mark=False, get_kind=False): peer = get_input_peer(peer, allow_self=False) # Set the right ID/kind, or raise if the TLObject is not recognised - i, k = None, None - if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser): - i, k = peer.user_id, PeerUser - elif isinstance(peer, PeerChat) or isinstance(peer, InputPeerChat): - i, k = peer.chat_id, PeerChat - elif isinstance(peer, PeerChannel) or isinstance(peer, InputPeerChannel): - i, k = peer.channel_id, PeerChannel - else: - _raise_cast_fail(peer, 'int') - - if add_mark: - if k == PeerChat: - i = -i - elif k == PeerChannel: + if isinstance(peer, (PeerUser, InputPeerUser)): + return peer.user_id + elif isinstance(peer, (PeerChat, InputPeerChat)): + return -peer.chat_id if add_mark else peer.chat_id + elif isinstance(peer, (PeerChannel, InputPeerChannel)): + i = peer.channel_id + if add_mark: # Concat -100 through math tricks, .to_supergroup() on Madeline # IDs will be strictly positive -> log works - i = -(i + pow(10, math.floor(math.log10(i) + 3))) + return -(i + pow(10, math.floor(math.log10(i) + 3))) + else: + return i - return (i, k) if get_kind else i # return kind only if get_kind + _raise_cast_fail(peer, 'int') def resolve_id(marked_id): @@ -375,11 +367,7 @@ def find_user_or_chat(peer, users, chats): def get_appropriated_part_size(file_size): """Gets the appropriated part size when uploading or downloading files, given an initial file size""" - if file_size <= 1048576: # 1MB - return 32 - if file_size <= 10485760: # 10MB - return 64 - if file_size <= 393216000: # 375MB + if file_size <= 104857600: # 100MB return 128 if file_size <= 786432000: # 750MB return 256 diff --git a/telethon_generator/tl_generator.py b/telethon_generator/tl_generator.py index e76dffaa..0e4f0013 100644 --- a/telethon_generator/tl_generator.py +++ b/telethon_generator/tl_generator.py @@ -143,7 +143,7 @@ class TLGenerator: builder.writeln( 'from {}.utils import get_input_peer, ' 'get_input_channel, get_input_user, ' - 'get_input_media'.format('.' * depth) + 'get_input_media, get_input_photo'.format('.' * depth) ) # Import 'os' for those needing access to 'os.urandom()' @@ -335,32 +335,28 @@ class TLGenerator: builder.writeln('))') builder.end_block() - # Write the empty() function, which returns an "empty" - # instance, in which all attributes are set to None + # Write the static from_reader(reader) function builder.writeln('@staticmethod') - builder.writeln('def empty():') + builder.writeln('def from_reader(reader):') + for arg in tlobject.args: + TLGenerator.write_read_code( + builder, arg, tlobject.args, name='_' + arg.name + ) + builder.writeln('return {}({})'.format( - tlobject.class_name(), ', '.join('None' for _ in range(len(args))) + tlobject.class_name(), ', '.join( + '{0}=_{0}'.format(a.name) for a in tlobject.sorted_args() + if not a.flag_indicator and not a.generic_definition + ) )) builder.end_block() - # Write the on_response(self, reader) function - builder.writeln('def on_response(self, reader):') - # Do not read constructor's ID, since - # that's already been read somewhere else + # Only requests can have a different response that's not their + # serialized body, that is, we'll be setting their .result. if tlobject.is_function: + builder.writeln('def on_response(self, reader):') TLGenerator.write_request_result_code(builder, tlobject) - else: - if tlobject.args: - for arg in tlobject.args: - TLGenerator.write_onresponse_code( - builder, arg, tlobject.args - ) - else: - # If there were no arguments, we still need an - # on_response method, and hence "pass" if empty - builder.writeln('pass') - builder.end_block() + builder.end_block() # Write the __str__(self) and stringify(self) functions builder.writeln('def __str__(self):') @@ -406,6 +402,8 @@ class TLGenerator: TLGenerator.write_get_input(builder, arg, 'get_input_user') elif arg.type == 'InputMedia' and tlobject.is_function: TLGenerator.write_get_input(builder, arg, 'get_input_media') + elif arg.type == 'InputPhoto' and tlobject.is_function: + TLGenerator.write_get_input(builder, arg, 'get_input_photo') else: builder.writeln('self.{0} = {0}'.format(arg.name)) @@ -549,9 +547,10 @@ class TLGenerator: return True # Something was written @staticmethod - def write_onresponse_code(builder, arg, args, name=None): + def write_read_code(builder, arg, args, name): """ - Writes the receive code for the given argument + Writes the read code for the given argument, setting the + arg.name variable to its read value. :param builder: The source code builder :param arg: The argument to write @@ -565,12 +564,17 @@ class TLGenerator: if arg.generic_definition: return # Do nothing, this only specifies a later type - if name is None: - name = 'self.{}'.format(arg.name) - # The argument may be a flag, only write that flag was given! was_flag = False if arg.is_flag: + # Treat 'true' flags as a special case, since they're true if + # they're set, and nothing else needs to actually be read. + if 'true' == arg.type: + builder.writeln( + '{} = bool(flags & {})'.format(name, 1 << arg.flag_index) + ) + return + was_flag = True builder.writeln('if flags & {}:'.format( 1 << arg.flag_index @@ -585,11 +589,10 @@ class TLGenerator: builder.writeln("reader.read_int()") builder.writeln('{} = []'.format(name)) - builder.writeln('_len = reader.read_int()') - builder.writeln('for _ in range(_len):') + builder.writeln('for _ in range(reader.read_int()):') # Temporary disable .is_vector, not to enter this if again arg.is_vector = False - TLGenerator.write_onresponse_code(builder, arg, args, name='_x') + TLGenerator.write_read_code(builder, arg, args, name='_x') builder.writeln('{}.append(_x)'.format(name)) arg.is_vector = True @@ -642,7 +645,10 @@ class TLGenerator: builder.end_block() if was_flag: - builder.end_block() + builder.current_indent -= 1 + builder.writeln('else:') + builder.writeln('{} = None'.format(name)) + builder.current_indent -= 1 # Restore .is_flag arg.is_flag = True diff --git a/telethon_tests/crypto_test.py b/telethon_tests/crypto_test.py index cec18084..e11704a4 100644 --- a/telethon_tests/crypto_test.py +++ b/telethon_tests/crypto_test.py @@ -107,17 +107,17 @@ class CryptoTests(unittest.TestCase): @staticmethod def test_generate_key_data_from_nonce(): - server_nonce = b'I am the server nonce.' - new_nonce = b'I am a new calculated nonce.' + server_nonce = int.from_bytes(b'The 16-bit nonce', byteorder='little') + new_nonce = int.from_bytes(b'The new, calculated 32-bit nonce', byteorder='little') key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce) - expected_key = b'?\xc4\xbd\xdf\rWU\x8a\xf5\x0f+V\xdc\x96up\x1d\xeeG\x00\x81|\x1eg\x8a\x8f{\xf0y\x80\xda\xde' - expected_iv = b'Q\x9dpZ\xb7\xdd\xcb\x82_\xfa\xf4\x90\xecn\x10\x9cD\xd2\x01\x8d\x83\xa0\xa4^\xb8\x91,\x7fI am' + expected_key = b'/\xaa\x7f\xa1\xfcs\xef\xa0\x99zh\x03M\xa4\x8e\xb4\xab\x0eE]b\x95|\xfe\xc0\xf8\x1f\xd4\xa0\xd4\xec\x91' + expected_iv = b'\xf7\xae\xe3\xc8+=\xc2\xb8\xd1\xe1\x1b\x0e\x10\x07\x9fn\x9e\xdc\x960\x05\xf9\xea\xee\x8b\xa1h The ' assert key == expected_key, 'Key ("{}") does not equal expected ("{}")'.format( key, expected_key) - assert iv == expected_iv, 'Key ("{}") does not equal expected ("{}")'.format( - key, expected_iv) + assert iv == expected_iv, 'IV ("{}") does not equal expected ("{}")'.format( + iv, expected_iv) @staticmethod def test_fingerprint_from_key():