diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index d41a8b8e..6174ecc1 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -493,10 +493,22 @@ class TelegramBareClient: # "A container may only be accepted or # rejected by the other party as a whole." return None - elif len(requests) == 1: - return requests[0].result - else: - return [x.result for x in requests] + + # Save all input entities we know of + entities = [] + results = [] + for x in requests: + y = x.result + results.append(y) + if hasattr(y, 'chats') and hasattr(y.chats, '__iter__'): + entities.extend(y.chats) + if hasattr(y, 'users') and hasattr(y.users, '__iter__'): + entities.extend(y.users) + + if self.session.add_entities(entities): + self.session.save() # Save if any new entities got added + + return results[0] if len(results) == 1 else results except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: diff --git a/telethon/tl/session.py b/telethon/tl/session.py index b8d89598..6e6b339e 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -2,11 +2,15 @@ import json import os import platform import time -from threading import Lock from base64 import b64encode, b64decode from os.path import isfile as file_exists +from threading import Lock -from .. import helpers as utils +from .. import helpers, utils +from ..tl.types import ( + InputPeerUser, InputPeerChat, InputPeerChannel, + PeerUser, PeerChat, PeerChannel +) class Session: @@ -51,7 +55,7 @@ class Session: # Cross-thread safety self._lock = Lock() - self.id = utils.generate_random_long(signed=False) + self.id = helpers.generate_random_long(signed=False) self._sequence = 0 self.time_offset = 0 self._last_msg_id = 0 # Long @@ -62,6 +66,8 @@ class Session: self.auth_key = None self.layer = 0 self.salt = 0 # Unsigned long + self._input_entities = {} # {marked_id: hash} + self._entities_lock = Lock() def save(self): """Saves the current session object as session_user_id.session""" @@ -74,7 +80,8 @@ class Session: 'server_address': self.server_address, 'auth_key_data': b64encode(self.auth_key.key).decode('ascii') - if self.auth_key else None + if self.auth_key else None, + 'entities': list(self._input_entities.items()) }, file) def delete(self): @@ -122,6 +129,9 @@ class Session: key = b64decode(data['auth_key_data']) result.auth_key = AuthKey(data=key) + for e_mid, e_hash in data.get('entities', []): + result._input_entities[e_mid] = e_hash + except (json.decoder.JSONDecodeError, UnicodeDecodeError): pass @@ -164,3 +174,43 @@ class Session: now = int(time.time()) correct = correct_msg_id >> 32 self.time_offset = correct - now + + def add_entities(self, entities): + """Adds new input entities to the local database of them. + Unknown types will be ignored. + """ + if not entities: + return False + + new = {} + for e in entities: + try: + p = utils.get_input_peer(e) + new[utils.get_peer_id(p, add_mark=True)] = \ + getattr(p, 'access_hash', 0) # chats won't have hash + except ValueError: + pass + + with self._entities_lock: + before = len(self._input_entities) + self._input_entities.update(new) + return len(self._input_entities) != before + + def get_input_entity(self, peer): + """Gets an input entity known its Peer or a marked ID, + or raises KeyError if not found/invalid. + """ + if not isinstance(peer, int): + peer = utils.get_peer_id(peer, add_mark=True) + + entity_hash = self._input_entities[peer] + entity_id, peer_class = utils.resolve_id(peer) + + if peer_class == PeerUser: + return InputPeerUser(entity_id, entity_hash) + if peer_class == PeerChat: + return InputPeerChat(entity_id) + if peer_class == PeerChannel: + return InputPeerChannel(entity_id, entity_hash) + + raise KeyError() diff --git a/telethon/utils.py b/telethon/utils.py index 273dc962..f6e6373a 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -2,6 +2,7 @@ Utilities for working with the Telegram API itself (such as handy methods to convert between an entity like an User, Chat, etc. into its Input version) """ +import math from mimetypes import add_type, guess_extension from .tl import TLObject @@ -299,6 +300,41 @@ def get_input_media(media, user_caption=None, is_photo=False): _raise_cast_fail(media, 'InputMedia') +def get_peer_id(peer, add_mark=False): + """Finds the ID of the given peer, and optionally converts it to + the "bot api" format if 'add_mark' is set to True. + """ + if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser): + return peer.user_id + else: + if isinstance(peer, PeerChat) or isinstance(peer, InputPeerChat): + if not add_mark: + return peer.chat_id + + # Chats are marked by turning them into negative numbers + return -peer.chat_id + elif isinstance(peer, PeerChannel) or isinstance(peer, InputPeerChannel): + if not add_mark: + return peer.channel_id + + # Prepend -100 through math tricks (.to_supergroup() on Madeline) + i = peer.channel_id # IDs will be strictly positive -> log works + return -(i + pow(10, math.floor(math.log10(i) + 3))) + + _raise_cast_fail(peer, 'int') + + +def resolve_id(marked_id): + """Given a marked ID, returns the original ID and its Peer type""" + if marked_id >= 0: + return marked_id, PeerUser + + if str(marked_id).startswith('-100'): + return int(str(marked_id)[4:]), PeerChannel + + return -marked_id, PeerChat + + def find_user_or_chat(peer, users, chats): """Finds the corresponding user or chat given a peer. Returns None if it was not found"""