From ce0dee63b1609c1a848f837e47977aaac6945038 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 8 Mar 2018 10:05:40 +0100 Subject: [PATCH] Support getting any entity by just their positive ID --- telethon/sessions/memory.py | 25 ++++++++++++++------ telethon/sessions/sqlalchemy.py | 23 +++++++++++++----- telethon/sessions/sqlite.py | 19 +++++++++++---- telethon/telegram_client.py | 42 +++++++++++++++++++-------------- 4 files changed, 74 insertions(+), 35 deletions(-) diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 4d7e6778..43ddde4b 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -1,9 +1,8 @@ from enum import Enum -from .. import utils from .abstract import Session +from .. import utils from ..tl import TLObject - from ..tl.types import ( PeerUser, PeerChat, PeerChannel, InputPeerUser, InputPeerChat, InputPeerChannel, @@ -148,10 +147,19 @@ class MemorySession(Session): except StopIteration: pass - def get_entity_rows_by_id(self, id): + def get_entity_rows_by_id(self, id, exact=True): try: - return next((id, hash) for found_id, hash, _, _, _ - in self._entities if found_id == id) + if exact: + return next((id, hash) for found_id, hash, _, _, _ + in self._entities if found_id == id) + else: + ids = ( + utils.get_peer_id(PeerUser(id)), + utils.get_peer_id(PeerChat(id)), + utils.get_peer_id(PeerChannel(id)) + ) + return next((id, hash) for found_id, hash, _, _, _ + in self._entities if found_id in ids) except StopIteration: pass @@ -167,6 +175,9 @@ class MemorySession(Session): # Not a TLObject or can't be cast into InputPeer if isinstance(key, TLObject): key = utils.get_peer_id(key) + exact = True + else: + exact = False result = None if isinstance(key, str): @@ -178,8 +189,8 @@ class MemorySession(Session): if username: result = self.get_entity_rows_by_username(username) - if isinstance(key, int): - result = self.get_entity_rows_by_id(key) + elif isinstance(key, int): + result = self.get_entity_rows_by_id(key, exact) if not result and isinstance(key, str): result = self.get_entity_rows_by_name(key) diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py index dc9040a1..ceaa0847 100644 --- a/telethon/sessions/sqlalchemy.py +++ b/telethon/sessions/sqlalchemy.py @@ -4,12 +4,13 @@ try: import sqlalchemy as sql except ImportError: sql = None - pass - -from ..crypto import AuthKey -from ..tl.types import InputPhoto, InputDocument from .memory import MemorySession, _SentFileType +from .. import utils +from ..crypto import AuthKey +from ..tl.types import ( + InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel +) LATEST_VERSION = 1 @@ -201,8 +202,18 @@ class AlchemySession(MemorySession): self.Entity.name == key).one_or_none() return row.id, row.hash if row else None - def get_entity_rows_by_id(self, key): - row = self._db_query(self.Entity, self.Entity.id == key).one_or_none() + def get_entity_rows_by_id(self, key, exact=True): + if exact: + query = self._db_query(self.Entity, self.Entity.id == key) + else: + ids = ( + utils.get_peer_id(PeerUser(key)), + utils.get_peer_id(PeerChat(key)), + utils.get_peer_id(PeerChannel(key)) + ) + query = self._db_query(self.Entity, self.Entity.id in ids) + + row = query.one_or_none() return row.id, row.hash if row else None def get_file(self, md5_digest, file_size, cls): diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index c764cd21..e9a4a723 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -6,9 +6,10 @@ from os.path import isfile as file_exists from threading import Lock, RLock from .memory import MemorySession, _SentFileType +from .. import utils from ..crypto import AuthKey from ..tl.types import ( - InputPhoto, InputDocument + InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel ) EXTENSION = '.session' @@ -282,9 +283,19 @@ class SQLiteSession(MemorySession): return self._fetchone_entity( 'select id, hash from entities where name=?', (name,)) - def get_entity_rows_by_id(self, id): - return self._fetchone_entity( - 'select id, hash from entities where id=?', (id,)) + def get_entity_rows_by_id(self, id, exact=True): + if exact: + return self._fetchone_entity( + 'select id, hash from entities where id=?', (id,)) + else: + ids = ( + utils.get_peer_id(PeerUser(id)), + utils.get_peer_id(PeerChat(id)), + utils.get_peer_id(PeerChannel(id)) + ) + return self._fetchone_entity( + 'select id, hash from entities where id in (?,?,?)', ids + ) # File processing diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index e1914eb5..87f57945 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -2272,23 +2272,21 @@ class TelegramClient(TelegramBareClient): return InputPeerSelf() return utils.get_input_peer(self._get_entity_from_string(peer)) - if isinstance(peer, int): - peer, kind = utils.resolve_id(peer) - peer = kind(peer) + if not isinstance(peer, int): + try: + if peer.SUBCLASS_OF_ID != 0x2d45687: # crc32(b'Peer') + return utils.get_input_peer(peer) + except (AttributeError, TypeError): + peer = None - try: - is_peer = peer.SUBCLASS_OF_ID == 0x2d45687 # crc32(b'Peer') - if not is_peer: - return utils.get_input_peer(peer) - except (AttributeError, TypeError): - is_peer = False - - if not is_peer: + if not peer: raise TypeError( 'Cannot turn "{}" into an input entity.'.format(peer) ) - # Not found, look in the dialogs with the hope to find it. + # Add the mark to the peers if the user passed a Peer (not an int) + # Look in the dialogs with the hope to find it. + mark = not isinstance(peer, int) target_id = utils.get_peer_id(peer) req = GetDialogsRequest( offset_date=None, @@ -2299,12 +2297,20 @@ class TelegramClient(TelegramBareClient): while True: result = self(req) entities = {} - for x in itertools.chain(result.users, result.chats): - x_id = utils.get_peer_id(x) - if x_id == target_id: - return utils.get_input_peer(x) - else: - entities[x_id] = x + if mark: + for x in itertools.chain(result.users, result.chats): + x_id = utils.get_peer_id(x) + if x_id == target_id: + return utils.get_input_peer(x) + else: + entities[x_id] = x + else: + for x in itertools.chain(result.users, result.chats): + if x.id == target_id: + return utils.get_input_peer(x) + else: + entities[utils.get_peer_id(x)] = x + if len(result.dialogs) < req.limit: break