diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index f22d13e6..ff493f6c 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -39,6 +39,7 @@ from .update_state import UpdateState from .utils import get_appropriated_part_size +DEFAULT_DC_ID = 4 DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV6_IP = '[2001:67c:4e8:f002::a]' DEFAULT_PORT = 443 @@ -92,7 +93,7 @@ class TelegramBareClient: # Determine what session object we have if isinstance(session, str) or session is None: - session = Session.try_load_or_create_new(session) + session = Session(session) elif not isinstance(session, Session): raise ValueError( 'The given session must be a str or a Session instance.' @@ -101,9 +102,11 @@ class TelegramBareClient: # ':' in session.server_address is True if it's an IPv6 address if (not session.server_address or (':' in session.server_address) != use_ipv6): - session.port = DEFAULT_PORT - session.server_address = \ - DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP + session.set_dc( + DEFAULT_DC_ID, + DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, + DEFAULT_PORT + ) self.session = session self.api_id = int(api_id) @@ -151,6 +154,10 @@ class TelegramBareClient: # Save whether the user is authorized here (a.k.a. logged in) self._authorized = None # None = We don't know yet + # The first request must be in invokeWithLayer(initConnection(X)). + # See https://core.telegram.org/api/invoking#saving-client-info. + self._first_request = True + # Uploaded files cache so subsequent calls are instant self._upload_cache = {} @@ -261,7 +268,7 @@ class TelegramBareClient: self._sender.disconnect() # TODO Shall we clear the _exported_sessions, or may be reused? 
- pass + self._first_request = True # On reconnect it will be first again def _reconnect(self, new_dc=None): """If 'new_dc' is not set, only a call to .connect() will be made @@ -290,8 +297,7 @@ class TelegramBareClient: dc = self._get_dc(new_dc) __log__.info('Reconnecting to new data center %s', dc) - self.session.server_address = dc.ip_address - self.session.port = dc.port + self.session.set_dc(dc.id, dc.ip_address, dc.port) # auth_key's are associated with a server, which has now changed # so it's not valid anymore. Set to None to force recreating it. self.session.auth_key = None @@ -366,8 +372,7 @@ class TelegramBareClient: # Construct this session with the connection parameters # (system version, device model...) from the current one. session = Session(self.session) - session.server_address = dc.ip_address - session.port = dc.port + session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[dc_id] = session __log__.info('Creating exported new client') @@ -393,8 +398,7 @@ class TelegramBareClient: if not session: dc = self._get_dc(cdn_redirect.dc_id, cdn=True) session = Session(self.session) - session.server_address = dc.ip_address - session.port = dc.port + session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[cdn_redirect.dc_id] = session __log__.info('Creating new CDN client') @@ -495,10 +499,6 @@ class TelegramBareClient: invoke = __call__ def _invoke(self, sender, call_receive, update_state, *requests): - # We need to specify the new layer (by initializing a new - # connection) if it has changed from the latest known one. - init_connection = self.session.layer != LAYER - try: # Ensure that we start with no previous errors (i.e. resending) for x in requests: @@ -506,14 +506,11 @@ class TelegramBareClient: x.rpc_error = None if not self.session.auth_key: - # New key, we need to tell the server we're going to use - # the latest layer and initialize the connection doing so. 
__log__.info('Need to generate new auth key before invoking') self.session.auth_key, self.session.time_offset = \ authenticator.do_authentication(self._sender.connection) - init_connection = True - if init_connection: + if self._first_request: __log__.info('Initializing a new connection while invoking') if len(requests) == 1: requests = [self._wrap_init_connection(requests[0])] @@ -556,11 +553,8 @@ class TelegramBareClient: # User never called .connect(), so raise this error. raise - if init_connection: - # We initialized the connection successfully, even if - # a request had an RPC error we have invoked it fine. - self.session.layer = LAYER - self.session.save() + # Clear the flag if we got this far + self._first_request = False try: raise next(x.rpc_error for x in requests if x.rpc_error) diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index e0708bc9..8038f484 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -20,7 +20,6 @@ from .errors import ( from .network import ConnectionMode from .tl import TLObject from .tl.custom import Draft, Dialog -from .tl.entity_database import EntityDatabase from .tl.functions.account import ( GetPasswordRequest ) @@ -145,7 +144,7 @@ class TelegramClient(TelegramBareClient): :return auth.SentCode: Information about the result of the request. 
""" - phone = EntityDatabase.parse_phone(phone) or self._phone + phone = utils.parse_phone(phone) or self._phone if not self._phone_code_hash: result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) @@ -189,7 +188,7 @@ class TelegramClient(TelegramBareClient): if phone and not code: return self.send_code_request(phone) elif code: - phone = EntityDatabase.parse_phone(phone) or self._phone + phone = utils.parse_phone(phone) or self._phone phone_code_hash = phone_code_hash or self._phone_code_hash if not phone: raise ValueError( @@ -998,12 +997,12 @@ class TelegramClient(TelegramBareClient): # region Small utilities to make users' life easier - def get_entity(self, entity, force_fetch=False): + def get_entity(self, entity): """ Turns the given entity into a valid Telegram user or chat. :param entity: - The entity to be transformed. + The entity (or iterable of entities) to be transformed. If it's a string which can be converted to an integer or starts with '+' it will be resolved as if it were a phone number. @@ -1017,58 +1016,75 @@ class TelegramClient(TelegramBareClient): If the entity is neither, and it's not a TLObject, an error will be raised. - :param force_fetch: - If True, the entity cache is bypassed and the entity is fetched - again with an API call. Defaults to False to avoid unnecessary - calls, but since a cached version would be returned, the entity - may be out of date. - :return: + :return: User, Chat or Channel corresponding to the input entity. 
""" - if not force_fetch: - # Try to use cache unless we want to force a fetch - try: - return self.session.entities[entity] - except KeyError: - pass + if not isinstance(entity, str) and hasattr(entity, '__iter__'): + single = False + else: + single = True + entity = (entity,) - if isinstance(entity, int) or ( - isinstance(entity, TLObject) and - # crc32(b'InputPeer') and crc32(b'Peer') - type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)): - ie = self.get_input_entity(entity) - if isinstance(ie, InputPeerUser): - self(GetUsersRequest([ie])) - elif isinstance(ie, InputPeerChat): - self(GetChatsRequest([ie.chat_id])) - elif isinstance(ie, InputPeerChannel): - self(GetChannelsRequest([ie])) - try: - # session.process_entities has been called in the MtProtoSender - # with the result of these calls, so they should now be on the - # entities database. - return self.session.entities[ie] - except KeyError: - pass + # Group input entities by string (resolve username), + # input users (get users), input chat (get chats) and + # input channels (get channels) to get the most entities + # in the less amount of calls possible. + inputs = [ + x if isinstance(x, str) else self.get_input_entity(x) + for x in entity + ] + users = [x for x in inputs if isinstance(x, InputPeerUser)] + chats = [x.chat_id for x in inputs if isinstance(x, InputPeerChat)] + channels = [x for x in inputs if isinstance(x, InputPeerChannel)] + if users: + # GetUsersRequest has a limit of 200 per call + tmp = [] + while users: + curr, users = users[:200], users[200:] + tmp.extend(self(GetUsersRequest(curr))) + users = tmp + if chats: # TODO Handle chats slice? 
+ chats = self(GetChatsRequest(chats)).chats + if channels: + channels = self(GetChannelsRequest(channels)).chats - if isinstance(entity, str): - return self._get_entity_from_string(entity) + # Merge users, chats and channels into a single dictionary + id_entity = { + utils.get_peer_id(x, add_mark=True): x + for x in itertools.chain(users, chats, channels) + } - raise ValueError( - 'Cannot turn "{}" into any entity (user or chat)'.format(entity) - ) + # We could check saved usernames and put them into the users, + # chats and channels list from before. While this would reduce + # the number of ResolveUsername calls, it would fail to catch + # username changes. + result = [ + self._get_entity_from_string(x) if isinstance(x, str) + else id_entity[utils.get_peer_id(x, add_mark=True)] + for x in inputs + ] + return result[0] if single else result def _get_entity_from_string(self, string): - """Gets an entity from the given string, which may be a phone or - an username, and processes all the found entities on the session. """ - phone = EntityDatabase.parse_phone(string) + Gets a full entity from the given string, which may be a phone or + a username, and processes all the found entities on the session. + The string may also be a user link, or a channel/chat invite link. + + This method has the side effect of adding the found users to the + session database, so it can be queried later without API calls, + if this option is enabled on the session. + + Returns the found entity. 
+ """ + phone = utils.parse_phone(string) if phone: - entity = phone - self(GetContactsRequest(0)) + for user in self(GetContactsRequest(0)).users: + if user.phone == phone: + return user else: - entity, is_join_chat = EntityDatabase.parse_username(string) + string, is_join_chat = utils.parse_username(string) if is_join_chat: - invite = self(CheckChatInviteRequest(entity)) + invite = self(CheckChatInviteRequest(string)) if isinstance(invite, ChatInvite): # If it's an invite to a chat, the user must join before # for the link to be resolved and work, otherwise raise. @@ -1077,14 +1093,10 @@ class TelegramClient(TelegramBareClient): elif isinstance(invite, ChatInviteAlready): return invite.chat else: - self(ResolveUsernameRequest(entity)) - # MtProtoSender will call .process_entities on the requests made - try: - return self.session.entities[entity] - except KeyError: - raise ValueError( - 'Could not find user with username {}'.format(entity) - ) + result = self(ResolveUsernameRequest(string)) + for entity in itertools.chain(result.users, result.chats): + if entity.username.lower() == string: + return entity def get_input_entity(self, peer): """ @@ -1103,12 +1115,13 @@ class TelegramClient(TelegramBareClient): If in the end the access hash required for the peer was not found, a ValueError will be raised. - :return: + + :return: InputPeerUser, InputPeerChat or InputPeerChannel. """ try: # First try to get the entity from cache, otherwise figure it out - return self.session.entities.get_input_entity(peer) - except KeyError: + return self.session.get_input_entity(peer) + except ValueError: pass if isinstance(peer, str): @@ -1132,22 +1145,22 @@ class TelegramClient(TelegramBareClient): 'Cannot turn "{}" into an input entity.'.format(peer) ) - if self.session.save_entities: - # Not found, look in the latest dialogs. 
- # This is useful if for instance someone just sent a message but - # the updates didn't specify who, as this person or chat should - # be in the latest dialogs. - self(GetDialogsRequest( - offset_date=None, - offset_id=0, - offset_peer=InputPeerEmpty(), - limit=0, - exclude_pinned=True - )) - try: - return self.session.entities.get_input_entity(peer) - except KeyError: - pass + # Not found, look in the latest dialogs. + # This is useful if for instance someone just sent a message but + # the updates didn't specify who, as this person or chat should + # be in the latest dialogs. + dialogs = self(GetDialogsRequest( + offset_date=None, + offset_id=0, + offset_peer=InputPeerEmpty(), + limit=0, + exclude_pinned=True + )) + + target = utils.get_peer_id(peer, add_mark=True) + for entity in itertools.chain(dialogs.users, dialogs.chats): + if utils.get_peer_id(entity, add_mark=True) == target: + return utils.get_input_peer(entity) raise ValueError( 'Could not find the input entity corresponding to "{}".' diff --git a/telethon/tl/entity_database.py b/telethon/tl/entity_database.py deleted file mode 100644 index 9002ebd8..00000000 --- a/telethon/tl/entity_database.py +++ /dev/null @@ -1,252 +0,0 @@ -import re -from threading import Lock - -from ..tl import TLObject -from ..tl.types import ( - User, Chat, Channel, PeerUser, PeerChat, PeerChannel, - InputPeerUser, InputPeerChat, InputPeerChannel -) -from .. import utils # Keep this line the last to maybe fix #357 - - -USERNAME_RE = re.compile( - r'@|(?:https?://)?(?:telegram\.(?:me|dog)|t\.me)/(joinchat/)?' -) - - -class EntityDatabase: - def __init__(self, input_list=None, enabled=True, enabled_full=True): - """Creates a new entity database with an initial load of "Input" - entities, if any. - - If 'enabled', input entities will be saved. The whole entity - will be saved if both 'enabled' and 'enabled_full' are True. 
- """ - self.enabled = enabled - self.enabled_full = enabled_full - - self._lock = Lock() - self._entities = {} # marked_id: user|chat|channel - - if input_list: - # TODO For compatibility reasons some sessions were saved with - # 'access_hash': null in the JSON session file. Drop these, as - # it means we don't have access to such InputPeers. Issue #354. - self._input_entities = { - k: v for k, v in input_list if v is not None - } - else: - self._input_entities = {} # marked_id: hash - - # TODO Allow disabling some extra mappings - self._username_id = {} # username: marked_id - self._phone_id = {} # phone: marked_id - - def process(self, tlobject): - """Processes all the found entities on the given TLObject, - unless .enabled is False. - - Returns True if new input entities were added. - """ - if not self.enabled: - return False - - # Save all input entities we know of - if not isinstance(tlobject, TLObject) and hasattr(tlobject, '__iter__'): - # This may be a list of users already for instance - return self.expand(tlobject) - - entities = [] - if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'): - entities.extend(tlobject.chats) - if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'): - entities.extend(tlobject.users) - - return self.expand(entities) - - def expand(self, entities): - """Adds new input entities to the local database unconditionally. - Unknown types will be ignored. 
- """ - if not entities or not self.enabled: - return False - - new = [] # Array of entities (User, Chat, or Channel) - new_input = {} # Dictionary of {entity_marked_id: access_hash} - for e in entities: - if not isinstance(e, TLObject): - continue - - try: - p = utils.get_input_peer(e, allow_self=False) - marked_id = utils.get_peer_id(p, add_mark=True) - - has_hash = False - if isinstance(p, InputPeerChat): - # Chats don't have a hash - new_input[marked_id] = 0 - has_hash = True - elif p.access_hash: - # Some users and channels seem to be returned without - # an 'access_hash', meaning Telegram doesn't want you - # to access them. This is the reason behind ensuring - # that the 'access_hash' is non-zero. See issue #354. - new_input[marked_id] = p.access_hash - has_hash = True - - if self.enabled_full and has_hash: - if isinstance(e, (User, Chat, Channel)): - new.append(e) - except ValueError: - pass - - with self._lock: - before = len(self._input_entities) - self._input_entities.update(new_input) - for e in new: - self._add_full_entity(e) - return len(self._input_entities) != before - - def _add_full_entity(self, entity): - """Adds a "full" entity (User, Chat or Channel, not "Input*"), - despite the value of self.enabled and self.enabled_full. - - Not to be confused with UserFull, ChatFull, or ChannelFull, - "full" means simply not "Input*". 
- """ - marked_id = utils.get_peer_id( - utils.get_input_peer(entity, allow_self=False), add_mark=True - ) - try: - old_entity = self._entities[marked_id] - old_entity.__dict__.update(entity.__dict__) # Keep old references - - # Update must delete old username and phone - username = getattr(old_entity, 'username', None) - if username: - del self._username_id[username.lower()] - - phone = getattr(old_entity, 'phone', None) - if phone: - del self._phone_id[phone] - except KeyError: - # Add new entity - self._entities[marked_id] = entity - - # Always update username or phone if any - username = getattr(entity, 'username', None) - if username: - self._username_id[username.lower()] = marked_id - - phone = getattr(entity, 'phone', None) - if phone: - self._phone_id[phone] = marked_id - - def _parse_key(self, key): - """Parses the given string, integer or TLObject key into a - marked user ID ready for use on self._entities. - - If a callable key is given, the entity will be passed to the - function, and if it returns a true-like value, the marked ID - for such entity will be returned. - - Raises ValueError if it cannot be parsed. 
- """ - if isinstance(key, str): - phone = EntityDatabase.parse_phone(key) - try: - if phone: - return self._phone_id[phone] - else: - username, _ = EntityDatabase.parse_username(key) - return self._username_id[username.lower()] - except KeyError as e: - raise ValueError() from e - - if isinstance(key, int): - return key # normal IDs are assumed users - - if isinstance(key, TLObject): - return utils.get_peer_id(key, add_mark=True) - - if callable(key): - for k, v in self._entities.items(): - if key(v): - return k - - raise ValueError() - - def __getitem__(self, key): - """See the ._parse_key() docstring for possible values of the key""" - try: - return self._entities[self._parse_key(key)] - except (ValueError, KeyError) as e: - raise KeyError(key) from e - - def __delitem__(self, key): - try: - old = self._entities.pop(self._parse_key(key)) - # Try removing the username and phone (if pop didn't fail), - # since the entity may have no username or phone, just ignore - # errors. It should be there if we popped the entity correctly. - try: - del self._username_id[getattr(old, 'username', None)] - except KeyError: - pass - - try: - del self._phone_id[getattr(old, 'phone', None)] - except KeyError: - pass - - except (ValueError, KeyError) as e: - raise KeyError(key) from e - - @staticmethod - def parse_phone(phone): - """Parses the given phone, or returns None if it's invalid""" - if isinstance(phone, int): - return str(phone) - else: - phone = re.sub(r'[+()\s-]', '', str(phone)) - if phone.isdigit(): - return phone - - @staticmethod - def parse_username(username): - """Parses the given username or channel access hash, given - a string, username or URL. Returns a tuple consisting of - both the stripped username and whether it is a joinchat/ hash. 
- """ - username = username.strip() - m = USERNAME_RE.match(username) - if m: - return username[m.end():], bool(m.group(1)) - else: - return username, False - - def get_input_entity(self, peer): - try: - i = utils.get_peer_id(peer, add_mark=True) - h = self._input_entities[i] # we store the IDs marked - i, k = utils.resolve_id(i) # removes the mark and returns kind - - if k == PeerUser: - return InputPeerUser(i, h) - elif k == PeerChat: - return InputPeerChat(i) - elif k == PeerChannel: - return InputPeerChannel(i, h) - - except ValueError as e: - raise KeyError(peer) from e - raise KeyError(peer) - - def get_input_list(self): - return list(self._input_entities.items()) - - def clear(self, target=None): - if target is None: - self._entities.clear() - else: - del self[target] diff --git a/telethon/tl/session.py b/telethon/tl/session.py index e530cc83..030b4e13 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -1,14 +1,22 @@ import json import os import platform +import sqlite3 import struct import time -from base64 import b64encode, b64decode +from base64 import b64decode from os.path import isfile as file_exists from threading import Lock -from .entity_database import EntityDatabase -from .. import helpers +from .. import utils, helpers +from ..tl import TLObject +from ..tl.types import ( + PeerUser, PeerChat, PeerChannel, + InputPeerUser, InputPeerChat, InputPeerChannel +) + +EXTENSION = '.session' +CURRENT_VERSION = 1 # database version class Session: @@ -19,33 +27,34 @@ class Session: If you think the session has been compromised, close all the sessions through an official Telegram client to revoke the authorization. """ - def __init__(self, session_user_id): + def __init__(self, session_id): """session_user_id should either be a string or another Session. Note that if another session is given, only parameters like those required to init a connection will be copied. 
""" # These values will NOT be saved - if isinstance(session_user_id, Session): - self.session_user_id = None - - # For connection purposes - session = session_user_id - self.device_model = session.device_model - self.system_version = session.system_version - self.app_version = session.app_version - self.lang_code = session.lang_code - self.system_lang_code = session.system_lang_code - self.lang_pack = session.lang_pack - self.report_errors = session.report_errors - self.save_entities = session.save_entities - self.flood_sleep_threshold = session.flood_sleep_threshold + self.filename = ':memory:' + # For connection purposes + if isinstance(session_id, Session): + self.device_model = session_id.device_model + self.system_version = session_id.system_version + self.app_version = session_id.app_version + self.lang_code = session_id.lang_code + self.system_lang_code = session_id.system_lang_code + self.lang_pack = session_id.lang_pack + self.report_errors = session_id.report_errors + self.save_entities = session_id.save_entities + self.flood_sleep_threshold = session_id.flood_sleep_threshold else: # str / None - self.session_user_id = session_user_id + if session_id: + self.filename = session_id + if not self.filename.endswith(EXTENSION): + self.filename += EXTENSION system = platform.uname() - self.device_model = system.system if system.system else 'Unknown' - self.system_version = system.release if system.release else '1.0' + self.device_model = system.system or 'Unknown' + self.system_version = system.release or '1.0' self.app_version = '1.0' # '0' will provoke error self.lang_code = 'en' self.system_lang_code = self.lang_code @@ -54,49 +63,149 @@ class Session: self.save_entities = True self.flood_sleep_threshold = 60 - # Cross-thread safety - self._seq_no_lock = Lock() - self._msg_id_lock = Lock() - self._save_lock = Lock() - self.id = helpers.generate_random_long(signed=True) self._sequence = 0 self.time_offset = 0 self._last_msg_id = 0 # Long + self.salt = 0 # 
Long + + # Cross-thread safety + self._seq_no_lock = Lock() + self._msg_id_lock = Lock() + self._db_lock = Lock() # These values will be saved - self.server_address = None - self.port = None - self.auth_key = None - self.layer = 0 - self.salt = 0 # Signed long - self.entities = EntityDatabase() # Known and cached entities + self._dc_id = 0 + self._server_address = None + self._port = None + self._auth_key = None + + # Migrating from .json -> SQL + entities = self._check_migrate_json() + + self._conn = sqlite3.connect(self.filename, check_same_thread=False) + c = self._conn.cursor() + c.execute("select name from sqlite_master " + "where type='table' and name='version'") + if c.fetchone(): + # Tables already exist, check for the version + c.execute("select version from version") + version = c.fetchone()[0] + if version != CURRENT_VERSION: + self._upgrade_database(old=version) + self.save() + + # These values will be saved + c.execute('select * from sessions') + self._dc_id, self._server_address, self._port, key, = c.fetchone() + + from ..crypto import AuthKey + self._auth_key = AuthKey(data=key) + c.close() + else: + # Tables don't exist, create new ones + c.execute("create table version (version integer)") + c.execute( + """create table sessions ( + dc_id integer primary key, + server_address text, + port integer, + auth_key blob + ) without rowid""" + ) + c.execute( + """create table entities ( + id integer primary key, + hash integer not null, + username text, + phone integer, + name text + ) without rowid""" + ) + c.execute("insert into version values (1)") + # Migrating from JSON -> new table and may have entities + if entities: + c.executemany( + 'insert or replace into entities values (?,?,?,?,?)', + entities + ) + c.close() + self.save() + + def _check_migrate_json(self): + if file_exists(self.filename): + try: + with open(self.filename, encoding='utf-8') as f: + data = json.load(f) + self.delete() # Delete JSON file to create database + + self._port = 
data.get('port', self._port) + self._server_address = \ + data.get('server_address', self._server_address) + + from ..crypto import AuthKey + if data.get('auth_key_data', None) is not None: + key = b64decode(data['auth_key_data']) + self._auth_key = AuthKey(data=key) + + rows = [] + for p_id, p_hash in data.get('entities', []): + rows.append((p_id, p_hash, None, None, None)) + return rows + except (UnicodeDecodeError, json.decoder.JSONDecodeError): + return [] # No entities + + def _upgrade_database(self, old): + pass + + # Data from sessions should be kept as properties + # not to fetch the database every time we need it + def set_dc(self, dc_id, server_address, port): + self._dc_id = dc_id + self._server_address = server_address + self._port = port + self._update_session_table() + + @property + def server_address(self): + return self._server_address + + @property + def port(self): + return self._port + + @property + def auth_key(self): + return self._auth_key + + @auth_key.setter + def auth_key(self, value): + self._auth_key = value + self._update_session_table() + + def _update_session_table(self): + with self._db_lock: + c = self._conn.cursor() + c.execute('delete from sessions') + c.execute('insert into sessions values (?,?,?,?)', ( + self._dc_id, + self._server_address, + self._port, + self._auth_key.key if self._auth_key else b'' + )) + c.close() def save(self): """Saves the current session object as session_user_id.session""" - if not self.session_user_id or self._save_lock.locked(): - return - - with self._save_lock: - with open('{}.session'.format(self.session_user_id), 'w') as file: - out_dict = { - 'port': self.port, - 'salt': self.salt, - 'layer': self.layer, - 'server_address': self.server_address, - 'auth_key_data': - b64encode(self.auth_key.key).decode('ascii') - if self.auth_key else None - } - if self.save_entities: - out_dict['entities'] = self.entities.get_input_list() - - json.dump(out_dict, file) + with self._db_lock: + self._conn.commit() def 
delete(self): """Deletes the current session file""" + if self.filename == ':memory:': + return True try: - os.remove('{}.session'.format(self.session_user_id)) + os.remove(self.filename) return True except OSError: return False @@ -107,48 +216,7 @@ class Session: using this client and never logged out """ return [os.path.splitext(os.path.basename(f))[0] - for f in os.listdir('.') if f.endswith('.session')] - - @staticmethod - def try_load_or_create_new(session_user_id): - """Loads a saved session_user_id.session or creates a new one. - If session_user_id=None, later .save()'s will have no effect. - """ - if session_user_id is None: - return Session(None) - else: - path = '{}.session'.format(session_user_id) - result = Session(session_user_id) - if not file_exists(path): - return result - - try: - with open(path, 'r') as file: - data = json.load(file) - result.port = data.get('port', result.port) - result.salt = data.get('salt', result.salt) - # Keep while migrating from unsigned to signed salt - if result.salt > 0: - result.salt = struct.unpack( - 'q', struct.pack('Q', result.salt))[0] - - result.layer = data.get('layer', result.layer) - result.server_address = \ - data.get('server_address', result.server_address) - - # FIXME We need to import the AuthKey here or otherwise - # we get cyclic dependencies. 
- from ..crypto import AuthKey - if data.get('auth_key_data', None) is not None: - key = b64decode(data['auth_key_data']) - result.auth_key = AuthKey(data=key) - - result.entities = EntityDatabase(data.get('entities', [])) - - except (json.decoder.JSONDecodeError, UnicodeDecodeError): - pass - - return result + for f in os.listdir('.') if f.endswith(EXTENSION)] def generate_sequence(self, content_related): """Thread safe method to generates the next sequence number, @@ -188,9 +256,103 @@ class Session: correct = correct_msg_id >> 32 self.time_offset = correct - now - def process_entities(self, tlobject): - try: - if self.entities.process(tlobject): - self.save() # Save if any new entities got added - except: - pass + # Entity processing + + def process_entities(self, tlo): + """Processes all the found entities on the given TLObject, + unless .save_entities is False. + + Returns None; found entities are written to the session database. + """ + if not self.save_entities: + return + + if not isinstance(tlo, TLObject) and hasattr(tlo, '__iter__'): + # This may be a list of users already for instance + entities = tlo + else: + entities = [] + if hasattr(tlo, 'chats') and hasattr(tlo.chats, '__iter__'): + entities.extend(tlo.chats) + if hasattr(tlo, 'users') and hasattr(tlo.users, '__iter__'): + entities.extend(tlo.users) + if not entities: + return + + rows = [] # Rows to add (id, hash, username, phone, name) + for e in entities: + if not isinstance(e, TLObject): + continue + try: + p = utils.get_input_peer(e, allow_self=False) + marked_id = utils.get_peer_id(p, add_mark=True) + except ValueError: + continue + + p_hash = getattr(p, 'access_hash', 0) + if p_hash is None: + # Some users and channels seem to be returned without + # an 'access_hash', meaning Telegram doesn't want you + # to access them. This is the reason behind ensuring + # that the 'access_hash' is not None. See issue #354. 
+ continue + + username = getattr(e, 'username', None) or None + if username is not None: + username = username.lower() + phone = getattr(e, 'phone', None) + name = utils.get_display_name(e) or None + rows.append((marked_id, p_hash, username, phone, name)) + if not rows: + return + + with self._db_lock: + self._conn.executemany( + 'insert or replace into entities values (?,?,?,?,?)', rows + ) + self.save() + + def get_input_entity(self, key): + """Parses the given string, integer or TLObject key into a + marked entity ID, which is then used to fetch the hash + from the database. + + If a callable key is given, every row will be fetched, + and passed as a tuple to a function, that should return + a true-like value when the desired row is found. + + Raises ValueError if it cannot be found. + """ + if isinstance(key, TLObject): + key = utils.get_input_peer(key) + if type(key).SUBCLASS_OF_ID == 0xc91c90b6: # crc32(b'InputPeer') + return key + key = utils.get_peer_id(key, add_mark=True) + + c = self._conn.cursor() + if isinstance(key, str): + phone = utils.parse_phone(key) + if phone: + c.execute('select id, hash from entities where phone=?', + (phone,)) + else: + username, _ = utils.parse_username(key) + c.execute('select id, hash from entities where username=?', + (username,)) + + if isinstance(key, int): + c.execute('select id, hash from entities where id=?', (key,)) + + result = c.fetchone() + c.close() + if result: + i, h = result # unpack resulting tuple + i, k = utils.resolve_id(i) # removes the mark and returns kind + if k == PeerUser: + return InputPeerUser(i, h) + elif k == PeerChat: + return InputPeerChat(i) + elif k == PeerChannel: + return InputPeerChannel(i, h) + else: + raise ValueError('Could not find input entity with key ', key) diff --git a/telethon/utils.py b/telethon/utils.py index 5e92b13d..0662a99d 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -5,6 +5,8 @@ to convert between an entity like an User, Chat, etc. 
into its Input version) import math from mimetypes import add_type, guess_extension +import re + from .tl import TLObject from .tl.types import ( Channel, ChannelForbidden, Chat, ChatEmpty, ChatForbidden, ChatFull, @@ -24,6 +26,11 @@ from .tl.types import ( ) +USERNAME_RE = re.compile( + r'@|(?:https?://)?(?:telegram\.(?:me|dog)|t\.me)/(joinchat/)?' +) + + def get_display_name(entity): """Gets the input peer for the given "entity" (user, chat or channel) Returns None if it was not found""" @@ -305,6 +312,32 @@ def get_input_media(media, user_caption=None, is_photo=False): _raise_cast_fail(media, 'InputMedia') +def parse_phone(phone): + """Parses the given phone, or returns None if it's invalid""" + if isinstance(phone, int): + return str(phone) + else: + phone = re.sub(r'[+()\s-]', '', str(phone)) + if phone.isdigit(): + return phone + + +def parse_username(username): + """Parses the given username or channel access hash, given + a string, username or URL. Returns a tuple consisting of + both the stripped, lowercase username and whether it is + a joinchat/ hash (in which case is not lowercase'd). + """ + username = username.strip() + m = USERNAME_RE.match(username) + if m: + result = username[m.end():] + is_invite = bool(m.group(1)) + return result if is_invite else result.lower(), is_invite + else: + return username.lower(), False + + def get_peer_id(peer, add_mark=False): """Finds the ID of the given peer, and optionally converts it to the "bot api" format if 'add_mark' is set to True. diff --git a/telethon_tests/higher_level_test.py b/telethon_tests/higher_level_test.py index 7bd4b181..7433fac9 100644 --- a/telethon_tests/higher_level_test.py +++ b/telethon_tests/higher_level_test.py @@ -18,7 +18,7 @@ class HigherLevelTests(unittest.TestCase): @staticmethod def test_cdn_download(): client = TelegramClient(None, api_id, api_hash) - client.session.server_address = '149.154.167.40' + client.session.set_dc(0, '149.154.167.40', 80) assert client.connect() try: