From aef96f1b6898a5a4b48b3a6943eb574ab5df1052 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Dec 2017 00:50:09 +0100 Subject: [PATCH] Remove custom EntityDatabase and use sqlite3 instead There are still a few things to change, like cleaning up the code and actually caching the entities as a whole (currently, although the username/phone/name can be used to fetch their input version which is an improvement, their full version needs to be re-fetched. Maybe it's a good thing though?) --- telethon/telegram_client.py | 65 ++++----- telethon/tl/entity_database.py | 252 --------------------------------- telethon/tl/session.py | 137 ++++++++++++++++-- telethon/utils.py | 30 ++++ 4 files changed, 181 insertions(+), 303 deletions(-) delete mode 100644 telethon/tl/entity_database.py diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index 32ade1a9..5d09ee2c 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -19,7 +19,6 @@ from .errors import ( from .network import ConnectionMode from .tl import TLObject from .tl.custom import Draft, Dialog -from .tl.entity_database import EntityDatabase from .tl.functions.account import ( GetPasswordRequest ) @@ -144,7 +143,7 @@ class TelegramClient(TelegramBareClient): :return auth.SentCode: Information about the result of the request. """ - phone = EntityDatabase.parse_phone(phone) or self._phone + phone = utils.parse_phone(phone) or self._phone if not self._phone_code_hash: result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) @@ -188,7 +187,7 @@ class TelegramClient(TelegramBareClient): if phone and not code: return self.send_code_request(phone) elif code: - phone = EntityDatabase.parse_phone(phone) or self._phone + phone = utils.parse_phone(phone) or self._phone phone_code_hash = phone_code_hash or self._phone_code_hash if not phone: raise ValueError( @@ -1009,12 +1008,8 @@ class TelegramClient(TelegramBareClient): may be out of date. :return: """ - if not force_fetch: - # Try to use cache unless we want to force a fetch - try: - return self.session.entities[entity] - except KeyError: - pass + # TODO Actually cache {id: entities} again + # >>> if not force_fetch: reuse cached if isinstance(entity, int) or ( isinstance(entity, TLObject) and @@ -1022,36 +1017,38 @@ class TelegramClient(TelegramBareClient): type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)): ie = self.get_input_entity(entity) if isinstance(ie, InputPeerUser): - self(GetUsersRequest([ie])) + return self(GetUsersRequest([ie]))[0] elif isinstance(ie, InputPeerChat): - self(GetChatsRequest([ie.chat_id])) + return self(GetChatsRequest([ie.chat_id])).chats[0] elif isinstance(ie, InputPeerChannel): - self(GetChannelsRequest([ie])) - try: - # session.process_entities has been called in the MtProtoSender - # with the result of these calls, so they should now be on the - # entities database. - return self.session.entities[ie] - except KeyError: - pass + return self(GetChannelsRequest([ie])).chats[0] if isinstance(entity, str): - return self._get_entity_from_string(entity) + # TODO This probably can be done better... + invite = self._load_entity_from_string(entity) + if invite: + return invite + return self.get_entity(self.session.get_input_entity(entity)) raise ValueError( 'Cannot turn "{}" into any entity (user or chat)'.format(entity) ) - def _get_entity_from_string(self, string): - """Gets an entity from the given string, which may be a phone or - an username, and processes all the found entities on the session. + def _load_entity_from_string(self, string): """ - phone = EntityDatabase.parse_phone(string) + Loads an entity from the given string, which may be a phone or + an username, and processes all the found entities on the session. + + This method will effectively add the found users to the session + database, so it can be queried later. + + May return a channel or chat if the string was an invite. + """ + phone = utils.parse_phone(string) if phone: - entity = phone self(GetContactsRequest(0)) else: - entity, is_join_chat = EntityDatabase.parse_username(string) + entity, is_join_chat = utils.parse_username(string) if is_join_chat: invite = self(CheckChatInviteRequest(entity)) if isinstance(invite, ChatInvite): @@ -1063,13 +1060,6 @@ class TelegramClient(TelegramBareClient): return invite.chat else: self(ResolveUsernameRequest(entity)) - # MtProtoSender will call .process_entities on the requests made - try: - return self.session.entities[entity] - except KeyError: - raise ValueError( - 'Could not find user with username {}'.format(entity) - ) def get_input_entity(self, peer): """ @@ -1092,12 +1082,15 @@ class TelegramClient(TelegramBareClient): """ try: # First try to get the entity from cache, otherwise figure it out - return self.session.entities.get_input_entity(peer) + return self.session.get_input_entity(peer) except KeyError: pass if isinstance(peer, str): - return utils.get_input_peer(self._get_entity_from_string(peer)) + invite = self._load_entity_from_string(peer) + if invite: + return utils.get_input_peer(invite) + return self.session.get_input_entity(peer) is_peer = False if isinstance(peer, int): @@ -1130,7 +1123,7 @@ class TelegramClient(TelegramBareClient): exclude_pinned=True )) try: - return self.session.entities.get_input_entity(peer) + return self.session.get_input_entity(peer) except KeyError: pass diff --git a/telethon/tl/entity_database.py b/telethon/tl/entity_database.py deleted file mode 100644 index 9002ebd8..00000000 --- a/telethon/tl/entity_database.py +++ /dev/null @@ -1,252 +0,0 @@ -import re -from threading import Lock - -from ..tl import TLObject -from ..tl.types import ( - User, Chat, Channel, PeerUser, PeerChat, PeerChannel, - InputPeerUser, InputPeerChat, InputPeerChannel -) -from .. import utils # Keep this line the last to maybe fix #357 - - -USERNAME_RE = re.compile( - r'@|(?:https?://)?(?:telegram\.(?:me|dog)|t\.me)/(joinchat/)?' -) - - -class EntityDatabase: - def __init__(self, input_list=None, enabled=True, enabled_full=True): - """Creates a new entity database with an initial load of "Input" - entities, if any. - - If 'enabled', input entities will be saved. The whole entity - will be saved if both 'enabled' and 'enabled_full' are True. - """ - self.enabled = enabled - self.enabled_full = enabled_full - - self._lock = Lock() - self._entities = {} # marked_id: user|chat|channel - - if input_list: - # TODO For compatibility reasons some sessions were saved with - # 'access_hash': null in the JSON session file. Drop these, as - # it means we don't have access to such InputPeers. Issue #354. - self._input_entities = { - k: v for k, v in input_list if v is not None - } - else: - self._input_entities = {} # marked_id: hash - - # TODO Allow disabling some extra mappings - self._username_id = {} # username: marked_id - self._phone_id = {} # phone: marked_id - - def process(self, tlobject): - """Processes all the found entities on the given TLObject, - unless .enabled is False. - - Returns True if new input entities were added. - """ - if not self.enabled: - return False - - # Save all input entities we know of - if not isinstance(tlobject, TLObject) and hasattr(tlobject, '__iter__'): - # This may be a list of users already for instance - return self.expand(tlobject) - - entities = [] - if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'): - entities.extend(tlobject.chats) - if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'): - entities.extend(tlobject.users) - - return self.expand(entities) - - def expand(self, entities): - """Adds new input entities to the local database unconditionally. - Unknown types will be ignored. - """ - if not entities or not self.enabled: - return False - - new = [] # Array of entities (User, Chat, or Channel) - new_input = {} # Dictionary of {entity_marked_id: access_hash} - for e in entities: - if not isinstance(e, TLObject): - continue - - try: - p = utils.get_input_peer(e, allow_self=False) - marked_id = utils.get_peer_id(p, add_mark=True) - - has_hash = False - if isinstance(p, InputPeerChat): - # Chats don't have a hash - new_input[marked_id] = 0 - has_hash = True - elif p.access_hash: - # Some users and channels seem to be returned without - # an 'access_hash', meaning Telegram doesn't want you - # to access them. This is the reason behind ensuring - # that the 'access_hash' is non-zero. See issue #354. - new_input[marked_id] = p.access_hash - has_hash = True - - if self.enabled_full and has_hash: - if isinstance(e, (User, Chat, Channel)): - new.append(e) - except ValueError: - pass - - with self._lock: - before = len(self._input_entities) - self._input_entities.update(new_input) - for e in new: - self._add_full_entity(e) - return len(self._input_entities) != before - - def _add_full_entity(self, entity): - """Adds a "full" entity (User, Chat or Channel, not "Input*"), - despite the value of self.enabled and self.enabled_full. - - Not to be confused with UserFull, ChatFull, or ChannelFull, - "full" means simply not "Input*". - """ - marked_id = utils.get_peer_id( - utils.get_input_peer(entity, allow_self=False), add_mark=True - ) - try: - old_entity = self._entities[marked_id] - old_entity.__dict__.update(entity.__dict__) # Keep old references - - # Update must delete old username and phone - username = getattr(old_entity, 'username', None) - if username: - del self._username_id[username.lower()] - - phone = getattr(old_entity, 'phone', None) - if phone: - del self._phone_id[phone] - except KeyError: - # Add new entity - self._entities[marked_id] = entity - - # Always update username or phone if any - username = getattr(entity, 'username', None) - if username: - self._username_id[username.lower()] = marked_id - - phone = getattr(entity, 'phone', None) - if phone: - self._phone_id[phone] = marked_id - - def _parse_key(self, key): - """Parses the given string, integer or TLObject key into a - marked user ID ready for use on self._entities. - - If a callable key is given, the entity will be passed to the - function, and if it returns a true-like value, the marked ID - for such entity will be returned. - - Raises ValueError if it cannot be parsed. - """ - if isinstance(key, str): - phone = EntityDatabase.parse_phone(key) - try: - if phone: - return self._phone_id[phone] - else: - username, _ = EntityDatabase.parse_username(key) - return self._username_id[username.lower()] - except KeyError as e: - raise ValueError() from e - - if isinstance(key, int): - return key # normal IDs are assumed users - - if isinstance(key, TLObject): - return utils.get_peer_id(key, add_mark=True) - - if callable(key): - for k, v in self._entities.items(): - if key(v): - return k - - raise ValueError() - - def __getitem__(self, key): - """See the ._parse_key() docstring for possible values of the key""" - try: - return self._entities[self._parse_key(key)] - except (ValueError, KeyError) as e: - raise KeyError(key) from e - - def __delitem__(self, key): - try: - old = self._entities.pop(self._parse_key(key)) - # Try removing the username and phone (if pop didn't fail), - # since the entity may have no username or phone, just ignore - # errors. It should be there if we popped the entity correctly. - try: - del self._username_id[getattr(old, 'username', None)] - except KeyError: - pass - - try: - del self._phone_id[getattr(old, 'phone', None)] - except KeyError: - pass - - except (ValueError, KeyError) as e: - raise KeyError(key) from e - - @staticmethod - def parse_phone(phone): - """Parses the given phone, or returns None if it's invalid""" - if isinstance(phone, int): - return str(phone) - else: - phone = re.sub(r'[+()\s-]', '', str(phone)) - if phone.isdigit(): - return phone - - @staticmethod - def parse_username(username): - """Parses the given username or channel access hash, given - a string, username or URL. Returns a tuple consisting of - both the stripped username and whether it is a joinchat/ hash. - """ - username = username.strip() - m = USERNAME_RE.match(username) - if m: - return username[m.end():], bool(m.group(1)) - else: - return username, False - - def get_input_entity(self, peer): - try: - i = utils.get_peer_id(peer, add_mark=True) - h = self._input_entities[i] # we store the IDs marked - i, k = utils.resolve_id(i) # removes the mark and returns kind - - if k == PeerUser: - return InputPeerUser(i, h) - elif k == PeerChat: - return InputPeerChat(i) - elif k == PeerChannel: - return InputPeerChannel(i, h) - - except ValueError as e: - raise KeyError(peer) from e - raise KeyError(peer) - - def get_input_list(self): - return list(self._input_entities.items()) - - def clear(self, target=None): - if target is None: - self._entities.clear() - else: - del self[target] diff --git a/telethon/tl/session.py b/telethon/tl/session.py index ff4631f8..12bc3937 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -8,8 +8,12 @@ from base64 import b64decode from os.path import isfile as file_exists from threading import Lock -from .entity_database import EntityDatabase -from .. import helpers +from .. import utils, helpers +from ..tl import TLObject +from ..tl.types import ( + PeerUser, PeerChat, PeerChannel, + InputPeerUser, InputPeerChat, InputPeerChannel +) EXTENSION = '.session' CURRENT_VERSION = 1 # database version @@ -75,10 +79,9 @@ class Session: self._auth_key = None self._layer = 0 self._salt = 0 # Signed long - self.entities = EntityDatabase() # Known and cached entities # Migrating from .json -> SQL - self._check_migrate_json() + entities = self._check_migrate_json() self._conn = sqlite3.connect(self.filename, check_same_thread=False) c = self._conn.cursor() @@ -114,14 +117,20 @@ class Session: ) c.execute( """create table entities ( - id integer, - hash integer, + id integer primary key, + hash integer not null, username text, phone integer, name text )""" ) c.execute("insert into version values (1)") + # Migrating from JSON -> new table and may have entities + if entities: + c.executemany( + 'insert or replace into entities values (?,?,?,?,?)', + entities + ) c.close() self.save() @@ -130,6 +139,8 @@ class Session: try: with open(self.filename, encoding='utf-8') as f: data = json.load(f) + self.delete() # Delete JSON file to create database + self._port = data.get('port', self._port) self._salt = data.get('salt', self._salt) # Keep while migrating from unsigned to signed salt @@ -146,10 +157,12 @@ class Session: key = b64decode(data['auth_key_data']) self._auth_key = AuthKey(data=key) - self.entities = EntityDatabase(data.get('entities', [])) - self.delete() # Delete JSON file to create database + rows = [] + for p_id, p_hash in data.get('entities', []): + rows.append((p_id, p_hash, None, None, None)) + return rows except (UnicodeDecodeError, json.decoder.JSONDecodeError): - pass + return [] # No entities def _upgrade_database(self, old): pass @@ -275,9 +288,103 @@ class Session: correct = correct_msg_id >> 32 self.time_offset = correct - now - def process_entities(self, tlobject): - try: - if self.entities.process(tlobject): - self.save() # Save if any new entities got added - except: - pass + # Entity processing + + def process_entities(self, tlo): + """Processes all the found entities on the given TLObject, + unless .enabled is False. + + Returns True if new input entities were added. + """ + if not self.save_entities: + return + + if not isinstance(tlo, TLObject) and hasattr(tlo, '__iter__'): + # This may be a list of users already for instance + entities = tlo + else: + entities = [] + if hasattr(tlo, 'chats') and hasattr(tlo.chats, '__iter__'): + entities.extend(tlo.chats) + if hasattr(tlo, 'users') and hasattr(tlo.users, '__iter__'): + entities.extend(tlo.users) + if not entities: + return + + rows = [] # Rows to add (id, hash, username, phone, name) + for e in entities: + if not isinstance(e, TLObject): + continue + try: + p = utils.get_input_peer(e, allow_self=False) + marked_id = utils.get_peer_id(p, add_mark=True) + + p_hash = None + if isinstance(p, InputPeerChat): + p_hash = 0 + elif p.access_hash: + # Some users and channels seem to be returned without + # an 'access_hash', meaning Telegram doesn't want you + # to access them. This is the reason behind ensuring + # that the 'access_hash' is non-zero. See issue #354. + p_hash = p.access_hash + + if p_hash is not None: + username = getattr(e, 'username', None) + phone = getattr(e, 'phone', None) + name = utils.get_display_name(e) or None + rows.append((marked_id, p_hash, username, phone, name)) + except ValueError: + pass + if not rows: + return + + with self._db_lock: + self._conn.executemany( + 'insert or replace into entities values (?,?,?,?,?)', rows + ) + self.save() + + def get_input_entity(self, key): + """Parses the given string, integer or TLObject key into a + marked entity ID, which is then used to fetch the hash + from the database. + + If a callable key is given, every row will be fetched, + and passed as a tuple to a function, that should return + a true-like value when the desired row is found. + + Raises ValueError if it cannot be found. + """ + c = self._conn.cursor() + if isinstance(key, str): + phone = utils.parse_phone(key) + if phone: + c.execute('select id, hash from entities where phone=?', + (phone,)) + else: + username, _ = utils.parse_username(key) + c.execute('select id, hash from entities where username=?', + (username,)) + + if isinstance(key, TLObject): + # crc32(b'InputPeer') and crc32(b'Peer') + if type(key).SUBCLASS_OF_ID == 0xc91c90b6: + return key + key = utils.get_peer_id(key, add_mark=True) + + if isinstance(key, int): + c.execute('select id, hash from entities where id=?', (key,)) + + result = c.fetchone() + if result: + i, h = result # unpack resulting tuple + i, k = utils.resolve_id(i) # removes the mark and returns kind + if k == PeerUser: + return InputPeerUser(i, h) + elif k == PeerChat: + return InputPeerChat(i) + elif k == PeerChannel: + return InputPeerChannel(i, h) + else: + raise ValueError('Could not find input entity with key ', key) diff --git a/telethon/utils.py b/telethon/utils.py index 5e92b13d..04970632 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -5,6 +5,8 @@ to convert between an entity like an User, Chat, etc. into its Input version) import math from mimetypes import add_type, guess_extension +import re + from .tl import TLObject from .tl.types import ( Channel, ChannelForbidden, Chat, ChatEmpty, ChatForbidden, ChatFull, @@ -24,6 +26,11 @@ from .tl.types import ( ) +USERNAME_RE = re.compile( + r'@|(?:https?://)?(?:telegram\.(?:me|dog)|t\.me)/(joinchat/)?' +) + + def get_display_name(entity): """Gets the input peer for the given "entity" (user, chat or channel) Returns None if it was not found""" @@ -305,6 +312,29 @@ def get_input_media(media, user_caption=None, is_photo=False): _raise_cast_fail(media, 'InputMedia') +def parse_phone(phone): + """Parses the given phone, or returns None if it's invalid""" + if isinstance(phone, int): + return str(phone) + else: + phone = re.sub(r'[+()\s-]', '', str(phone)) + if phone.isdigit(): + return phone + + +def parse_username(username): + """Parses the given username or channel access hash, given + a string, username or URL. Returns a tuple consisting of + both the stripped username and whether it is a joinchat/ hash. + """ + username = username.strip() + m = USERNAME_RE.match(username) + if m: + return username[m.end():], bool(m.group(1)) + else: + return username, False + + def get_peer_id(peer, add_mark=False): """Finds the ID of the given peer, and optionally converts it to the "bot api" format if 'add_mark' is set to True.