Support getting any entity by just their positive ID

This commit is contained in:
Lonami Exo 2018-03-08 10:05:40 +01:00
parent d3d190f36e
commit ce0dee63b1
4 changed files with 74 additions and 35 deletions

View File

@ -1,9 +1,8 @@
from enum import Enum from enum import Enum
from .. import utils
from .abstract import Session from .abstract import Session
from .. import utils
from ..tl import TLObject from ..tl import TLObject
from ..tl.types import ( from ..tl.types import (
PeerUser, PeerChat, PeerChannel, PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel, InputPeerUser, InputPeerChat, InputPeerChannel,
@ -148,10 +147,19 @@ class MemorySession(Session):
except StopIteration: except StopIteration:
pass pass
def get_entity_rows_by_id(self, id): def get_entity_rows_by_id(self, id, exact=True):
try: try:
if exact:
return next((id, hash) for found_id, hash, _, _, _ return next((id, hash) for found_id, hash, _, _, _
in self._entities if found_id == id) in self._entities if found_id == id)
else:
ids = (
utils.get_peer_id(PeerUser(id)),
utils.get_peer_id(PeerChat(id)),
utils.get_peer_id(PeerChannel(id))
)
return next((id, hash) for found_id, hash, _, _, _
in self._entities if found_id in ids)
except StopIteration: except StopIteration:
pass pass
@ -167,6 +175,9 @@ class MemorySession(Session):
# Not a TLObject or can't be cast into InputPeer # Not a TLObject or can't be cast into InputPeer
if isinstance(key, TLObject): if isinstance(key, TLObject):
key = utils.get_peer_id(key) key = utils.get_peer_id(key)
exact = True
else:
exact = False
result = None result = None
if isinstance(key, str): if isinstance(key, str):
@ -178,8 +189,8 @@ class MemorySession(Session):
if username: if username:
result = self.get_entity_rows_by_username(username) result = self.get_entity_rows_by_username(username)
if isinstance(key, int): elif isinstance(key, int):
result = self.get_entity_rows_by_id(key) result = self.get_entity_rows_by_id(key, exact)
if not result and isinstance(key, str): if not result and isinstance(key, str):
result = self.get_entity_rows_by_name(key) result = self.get_entity_rows_by_name(key)

View File

@ -4,12 +4,13 @@ try:
import sqlalchemy as sql import sqlalchemy as sql
except ImportError: except ImportError:
sql = None sql = None
pass
from ..crypto import AuthKey
from ..tl.types import InputPhoto, InputDocument
from .memory import MemorySession, _SentFileType from .memory import MemorySession, _SentFileType
from .. import utils
from ..crypto import AuthKey
from ..tl.types import (
InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel
)
LATEST_VERSION = 1 LATEST_VERSION = 1
@ -201,8 +202,18 @@ class AlchemySession(MemorySession):
self.Entity.name == key).one_or_none() self.Entity.name == key).one_or_none()
return row.id, row.hash if row else None return row.id, row.hash if row else None
def get_entity_rows_by_id(self, key): def get_entity_rows_by_id(self, key, exact=True):
row = self._db_query(self.Entity, self.Entity.id == key).one_or_none() if exact:
query = self._db_query(self.Entity, self.Entity.id == key)
else:
ids = (
utils.get_peer_id(PeerUser(key)),
utils.get_peer_id(PeerChat(key)),
utils.get_peer_id(PeerChannel(key))
)
query = self._db_query(self.Entity, self.Entity.id in ids)
row = query.one_or_none()
return row.id, row.hash if row else None return row.id, row.hash if row else None
def get_file(self, md5_digest, file_size, cls): def get_file(self, md5_digest, file_size, cls):

View File

@ -6,9 +6,10 @@ from os.path import isfile as file_exists
from threading import Lock, RLock from threading import Lock, RLock
from .memory import MemorySession, _SentFileType from .memory import MemorySession, _SentFileType
from .. import utils
from ..crypto import AuthKey from ..crypto import AuthKey
from ..tl.types import ( from ..tl.types import (
InputPhoto, InputDocument InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel
) )
EXTENSION = '.session' EXTENSION = '.session'
@ -282,9 +283,19 @@ class SQLiteSession(MemorySession):
return self._fetchone_entity( return self._fetchone_entity(
'select id, hash from entities where name=?', (name,)) 'select id, hash from entities where name=?', (name,))
def get_entity_rows_by_id(self, id): def get_entity_rows_by_id(self, id, exact=True):
if exact:
return self._fetchone_entity( return self._fetchone_entity(
'select id, hash from entities where id=?', (id,)) 'select id, hash from entities where id=?', (id,))
else:
ids = (
utils.get_peer_id(PeerUser(id)),
utils.get_peer_id(PeerChat(id)),
utils.get_peer_id(PeerChannel(id))
)
return self._fetchone_entity(
'select id, hash from entities where id in (?,?,?)', ids
)
# File processing # File processing

View File

@ -2272,23 +2272,21 @@ class TelegramClient(TelegramBareClient):
return InputPeerSelf() return InputPeerSelf()
return utils.get_input_peer(self._get_entity_from_string(peer)) return utils.get_input_peer(self._get_entity_from_string(peer))
if isinstance(peer, int): if not isinstance(peer, int):
peer, kind = utils.resolve_id(peer)
peer = kind(peer)
try: try:
is_peer = peer.SUBCLASS_OF_ID == 0x2d45687 # crc32(b'Peer') if peer.SUBCLASS_OF_ID != 0x2d45687: # crc32(b'Peer')
if not is_peer:
return utils.get_input_peer(peer) return utils.get_input_peer(peer)
except (AttributeError, TypeError): except (AttributeError, TypeError):
is_peer = False peer = None
if not is_peer: if not peer:
raise TypeError( raise TypeError(
'Cannot turn "{}" into an input entity.'.format(peer) 'Cannot turn "{}" into an input entity.'.format(peer)
) )
# Not found, look in the dialogs with the hope to find it. # Add the mark to the peers if the user passed a Peer (not an int)
# Look in the dialogs with the hope to find it.
mark = not isinstance(peer, int)
target_id = utils.get_peer_id(peer) target_id = utils.get_peer_id(peer)
req = GetDialogsRequest( req = GetDialogsRequest(
offset_date=None, offset_date=None,
@ -2299,12 +2297,20 @@ class TelegramClient(TelegramBareClient):
while True: while True:
result = self(req) result = self(req)
entities = {} entities = {}
if mark:
for x in itertools.chain(result.users, result.chats): for x in itertools.chain(result.users, result.chats):
x_id = utils.get_peer_id(x) x_id = utils.get_peer_id(x)
if x_id == target_id: if x_id == target_id:
return utils.get_input_peer(x) return utils.get_input_peer(x)
else: else:
entities[x_id] = x entities[x_id] = x
else:
for x in itertools.chain(result.users, result.chats):
if x.id == target_id:
return utils.get_input_peer(x)
else:
entities[utils.get_peer_id(x)] = x
if len(result.dialogs) < req.limit: if len(result.dialogs) < req.limit:
break break