diff --git a/telethon/client/auth.py b/telethon/client/auth.py index 4ac1c805..8428a192 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -552,7 +552,9 @@ class AuthMethods: self._authorized = False await self.disconnect() - self.session.delete() + delete = self.session.delete() + if inspect.isawaitable(delete): + await delete self.session = None return True diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 23040f64..809e5ca3 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -63,9 +63,12 @@ class _DirectDownloadIter(RequestIter): config = await self.client(functions.help.GetConfigRequest()) for option in config.dc_options: if option.ip_address == self.client.session.server_address: - self.client.session.set_dc( - option.id, option.ip_address, option.port) - self.client.session.save() + set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port) + if inspect.isawaitable(set_dc): + await set_dc + save = self.client.session.save() + if inspect.isawaitable(save): + await save break # TODO Figure out why the session may have the wrong DC ID diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 77768b78..14d40b7c 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -1,4 +1,5 @@ import abc +import inspect import re import asyncio import collections @@ -305,16 +306,6 @@ class TelegramBaseClient(abc.ABC): 'The given session must be a str or a Session instance.' ) - # ':' in session.server_address is True if it's an IPv6 address - if (not session.server_address or - (':' in session.server_address) != use_ipv6): - session.set_dc( - DEFAULT_DC_ID, - DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, - DEFAULT_PORT - ) - session.save() - self.flood_sleep_threshold = flood_sleep_threshold # TODO Use AsyncClassWrapper(session) @@ -546,6 +537,20 @@ class TelegramBaseClient(abc.ABC): elif self._loop != helpers.get_running_loop(): raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)') + # ':' in session.server_address is True if it's an IPv6 address + if (not self.session.server_address or + (':' in self.session.server_address) != self._use_ipv6): + set_dc = self.session.set_dc( + DEFAULT_DC_ID, + DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, + DEFAULT_PORT + ) + if inspect.isawaitable(set_dc): + await set_dc + save = self.session.save() + if inspect.isawaitable(save): + await save + if not await self._sender.connect(self._connection( self.session.server_address, self.session.port, @@ -558,12 +563,19 @@ class TelegramBaseClient(abc.ABC): return self.session.auth_key = self._sender.auth_key - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save try: # See comment when saving entities to understand this hack - self_id = self.session.get_input_entity(0).access_hash + self_entity = self.session.get_input_entity(0) + if inspect.isawaitable(self_entity): + self_entity = await self_entity + self_id = self_entity.access_hash self_user = self.session.get_input_entity(self_id) + if inspect.isawaitable(self_user): + self_user = await self_user self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash) except ValueError: pass @@ -572,7 +584,10 @@ class TelegramBaseClient(abc.ABC): ss = SessionState(0, 0, False, 0, 0, 0, 0, None) cs = [] - for entity_id, state in self.session.get_update_states(): + update_states = self.session.get_update_states() + if inspect.isawaitable(update_states): + update_states = await update_states + for entity_id, state in update_states: if entity_id == 0: # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None) @@ -583,6 +598,8 @@ class TelegramBaseClient(abc.ABC): for state in cs: try: entity = self.session.get_input_entity(state.channel_id) + if inspect.isawaitable(entity): + entity = await entity except ValueError: self._log[__name__].warning( 'No access_hash in cache for channel %s, will not catch up', state.channel_id) @@ -686,19 +703,29 @@ class TelegramBaseClient(abc.ABC): else: connection._proxy = proxy - def _save_states_and_entities(self: 'TelegramClient'): + async def _save_states_and_entities(self: 'TelegramClient'): # As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``. # This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved. # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # It doesn't matter if we put users in the list of chats. if self._mb_entity_cache.self_id: - self.session.process_entities(types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])) + process_entities = self.session.process_entities( + types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []) + ) + if inspect.isawaitable(process_entities): + await process_entities ss, cs = self._message_box.session_state() - self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) + update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) + if inspect.isawaitable(update_state): + await update_state now = datetime.datetime.now() # any datetime works; channels don't need it for channel_id, pts in cs.items(): - self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)) + update_state = self.session.set_update_state( + channel_id, types.updates.State(pts, 0, now, 0, unread_count=0) + ) + if inspect.isawaitable(update_state): + await update_state async def _disconnect_coro(self: 'TelegramClient'): if self.session is None: @@ -730,9 +757,11 @@ class TelegramBaseClient(abc.ABC): await asyncio.wait(self._event_handler_tasks) self._event_handler_tasks.clear() - self._save_states_and_entities() + await self._save_states_and_entities() - self.session.close() + close = self.session.close() + if inspect.isawaitable(close): + await close async def _disconnect(self: 'TelegramClient'): """ @@ -753,22 +782,28 @@ class TelegramBaseClient(abc.ABC): self._log[__name__].info('Reconnecting to new data center %s', new_dc) dc = await self._get_dc(new_dc) - self.session.set_dc(dc.id, dc.ip_address, dc.port) + set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port) + if inspect.isawaitable(set_dc): + await set_dc # auth_key's are associated with a server, which has now changed # so it's not valid anymore. Set to None to force recreating it. self._sender.auth_key.key = None self.session.auth_key = None - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save await self._disconnect() return await self.connect() - def _auth_key_callback(self: 'TelegramClient', auth_key): + async def _auth_key_callback(self: 'TelegramClient', auth_key): """ Callback from the sender whenever it needed to generate a new authorization key. This means we are not authorized. """ self.session.auth_key = auth_key - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save # endregion @@ -895,7 +930,11 @@ class TelegramBaseClient(abc.ABC): if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) session = self.session.clone() - session.set_dc(dc.id, dc.ip_address, dc.port) + if inspect.isawaitable(session): + session = await session + set_dc = session.set_dc(dc.id, dc.ip_address, dc.port) + if inspect.isawaitable(set_dc): + await set_dc self._exported_sessions[cdn_redirect.dc_id] = session self._log[__name__].info('Creating new CDN client') diff --git a/telethon/client/updates.py b/telethon/client/updates.py index b5a9aa96..9aa57e14 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -289,7 +289,7 @@ class UpdateMethods: len(self._mb_entity_cache), self._entity_cache_limit ) - self._save_states_and_entities() + await self._save_states_and_entities() self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map) if len(self._mb_entity_cache) >= self._entity_cache_limit: warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit') @@ -343,7 +343,10 @@ class UpdateMethods: if updates: self._log[__name__].info('Got difference for account updates') - updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) + _preprocess_updates = self._preprocess_updates(updates, users, chats) + if inspect.isawaitable(_preprocess_updates): + _preprocess_updates = await _preprocess_updates + updates_to_dispatch.extend(_preprocess_updates) continue get_diff = self._message_box.get_channel_difference(self._mb_entity_cache) @@ -441,7 +444,10 @@ class UpdateMethods: if updates: self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id) - updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) + _preprocess_updates = self._preprocess_updates(updates, users, chats) + if inspect.isawaitable(_preprocess_updates): + _preprocess_updates = await _preprocess_updates + updates_to_dispatch.extend(_preprocess_updates) continue deadline = self._message_box.check_deadlines() @@ -462,7 +468,10 @@ class UpdateMethods: except GapError: continue # get(_channel)_difference will start returning requests - updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats)) + _preprocess_updates = self._preprocess_updates(processed, users, chats) + if inspect.isawaitable(_preprocess_updates): + _preprocess_updates = await _preprocess_updates + updates_to_dispatch.extend(_preprocess_updates) except asyncio.CancelledError: pass except Exception as e: @@ -470,9 +479,11 @@ class UpdateMethods: self._updates_error = e await self.disconnect() - def _preprocess_updates(self, updates, users, chats): + async def _preprocess_updates(self, updates, users, chats): self._mb_entity_cache.extend(users, chats) - self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)) + process_entities = self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)) + if inspect.isawaitable(process_entities): + await process_entities entities = {utils.get_peer_id(x): x for x in itertools.chain(users, chats)} for u in updates: @@ -515,9 +526,11 @@ class UpdateMethods: # inserted because this is a rather expensive operation # (default's sqlite3 takes ~0.1s to commit changes). Do # it every minute instead. No-op if there's nothing new. - self._save_states_and_entities() + await self._save_states_and_entities() - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save async def _dispatch_update(self: 'TelegramClient', update): # TODO only used for AlbumHack, and MessageBox is not really designed for this diff --git a/telethon/client/users.py b/telethon/client/users.py index acb8f55c..88b344c6 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -1,5 +1,6 @@ import asyncio import datetime +import inspect import itertools import time import typing @@ -80,7 +81,9 @@ class UserMethods: exceptions.append(e) results.append(None) continue - self.session.process_entities(result) + process_entities = self.session.process_entities(result) + if inspect.isawaitable(process_entities): + await process_entities exceptions.append(None) results.append(result) request_index += 1 @@ -90,7 +93,9 @@ class UserMethods: return results else: result = await future - self.session.process_entities(result) + process_entities = self.session.process_entities(result) + if inspect.isawaitable(process_entities): + await process_entities return result except (errors.ServerError, errors.RpcCallFailError, errors.RpcMcgetFailError, errors.InterdcCallErrorError, @@ -435,7 +440,10 @@ class UserMethods: # No InputPeer, cached peer, or known string. Fetch from disk cache try: - return self.session.get_input_entity(peer) + input_entity = self.session.get_input_entity(peer) + if inspect.isawaitable(input_entity): + input_entity = await input_entity + return input_entity except ValueError: pass @@ -574,8 +582,10 @@ class UserMethods: pass try: # Nobody with this username, maybe it's an exact name/title - return await self.get_entity( - self.session.get_input_entity(string)) + input_entity = self.session.get_input_entity(string) + if inspect.isawaitable(input_entity): + input_entity = await input_entity + return await self.get_entity(input_entity) except ValueError: pass diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index d53f9ce8..59cfab6c 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -302,7 +302,7 @@ class MTProtoSender: # notify whenever we change it. This is crucial when we # switch to different data centers. if self._auth_key_callback: - self._auth_key_callback(self.auth_key) + await self._auth_key_callback(self.auth_key) self._log.debug('auth_key generation success!') return True