From 2df5e07c7c1ec03b22f91deb99dc1b0a39118598 Mon Sep 17 00:00:00 2001 From: humbertogontijo Date: Thu, 9 Nov 2023 00:17:49 -0300 Subject: [PATCH] Feat adding async session support --- telethon/client/auth.py | 4 +- telethon/client/downloads.py | 9 ++-- telethon/client/telegrambaseclient.py | 67 +++++++++++++++++++++------ telethon/client/updates.py | 8 ++-- telethon/client/users.py | 20 ++++++-- telethon/network/mtprotosender.py | 2 +- 6 files changed, 82 insertions(+), 28 deletions(-) diff --git a/telethon/client/auth.py b/telethon/client/auth.py index 9ca5b458..85615814 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -540,7 +540,9 @@ class AuthMethods: self._authorized = False await self.disconnect() - self.session.delete() + delete = self.session.delete() + if inspect.isawaitable(delete): + await delete self.session = None return True diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 3c9fa2d1..0b2cb5e2 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -53,9 +53,12 @@ class _DirectDownloadIter(RequestIter): config = await self.client(functions.help.GetConfigRequest()) for option in config.dc_options: if option.ip_address == self.client.session.server_address: - self.client.session.set_dc( - option.id, option.ip_address, option.port) - self.client.session.save() + set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port) + if inspect.isawaitable(set_dc): + await set_dc + save = self.client.session.save() + if inspect.isawaitable(save): + await save break # TODO Figure out why the session may have the wrong DC ID diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index dd127654..bfc8105c 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -1,4 +1,5 @@ import abc +import inspect import re import asyncio import collections @@ -553,12 +554,19 @@ class TelegramBaseClient(abc.ABC): return self.session.auth_key = self._sender.auth_key - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save try: # See comment when saving entities to understand this hack - self_id = self.session.get_input_entity(0).access_hash + self_entity = self.session.get_input_entity(0) + if inspect.isawaitable(self_entity): + self_entity = await self_entity + self_id = self_entity.access_hash self_user = self.session.get_input_entity(self_id) + if inspect.isawaitable(self_user): + self_user = await self_user self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash) except ValueError: pass @@ -567,7 +575,10 @@ class TelegramBaseClient(abc.ABC): ss = SessionState(0, 0, False, 0, 0, 0, 0, None) cs = [] - for entity_id, state in self.session.get_update_states(): + update_states = self.session.get_update_states() + if inspect.isawaitable(update_states): + update_states = await update_states + for entity_id, state in update_states: if entity_id == 0: # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None) @@ -578,6 +589,8 @@ class TelegramBaseClient(abc.ABC): for state in cs: try: entity = self.session.get_input_entity(state.channel_id) + if inspect.isawaitable(entity): + entity = await entity except ValueError: self._log[__name__].warning( 'No access_hash in cache for channel %s, will not catch up', state.channel_id) @@ -681,23 +694,37 @@ class TelegramBaseClient(abc.ABC): else: connection._proxy = proxy - def _save_states_and_entities(self: 'TelegramClient'): + async def _save_states_and_entities(self: 'TelegramClient'): entities = self._mb_entity_cache.get_all_entities() # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # It doesn't matter if we put users in the list of chats. - self.session.process_entities(types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], [])) + process_entities = self.session.process_entities( + types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], []) + ) + if inspect.isawaitable(process_entities): + await process_entities # As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``. # This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved. if self._mb_entity_cache.self_id: - self.session.process_entities(types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])) + process_entities = self.session.process_entities( + types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []) + ) + if inspect.isawaitable(process_entities): + await process_entities ss, cs = self._message_box.session_state() - self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) + update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) + if inspect.isawaitable(update_state): + await update_state now = datetime.datetime.now() # any datetime works; channels don't need it for channel_id, pts in cs.items(): - self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)) + update_state = self.session.set_update_state( + channel_id, types.updates.State(pts, 0, now, 0, unread_count=0) + ) + if inspect.isawaitable(update_state): + await update_state async def _disconnect_coro(self: 'TelegramClient'): if self.session is None: @@ -729,9 +756,11 @@ class TelegramBaseClient(abc.ABC): await asyncio.wait(self._event_handler_tasks) self._event_handler_tasks.clear() - self._save_states_and_entities() + await self._save_states_and_entities() - self.session.close() + close = self.session.close() + if inspect.isawaitable(close): + await close async def _disconnect(self: 'TelegramClient'): """ @@ -752,22 +781,28 @@ class TelegramBaseClient(abc.ABC): self._log[__name__].info('Reconnecting to new data center %s', new_dc) dc = await self._get_dc(new_dc) - self.session.set_dc(dc.id, dc.ip_address, dc.port) + set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port) + if inspect.isawaitable(set_dc): + await set_dc # auth_key's are associated with a server, which has now changed # so it's not valid anymore. Set to None to force recreating it. self._sender.auth_key.key = None self.session.auth_key = None - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save await self._disconnect() return await self.connect() - def _auth_key_callback(self: 'TelegramClient', auth_key): + async def _auth_key_callback(self: 'TelegramClient', auth_key): """ Callback from the sender whenever it needed to generate a new authorization key. This means we are not authorized. """ self.session.auth_key = auth_key - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save # endregion @@ -892,6 +927,8 @@ class TelegramBaseClient(abc.ABC): if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) session = self.session.clone() + if inspect.isawaitable(session): + session = await session session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[cdn_redirect.dc_id] = session @@ -907,7 +944,7 @@ class TelegramBaseClient(abc.ABC): # We won't be calling GetConfigRequest because it's only called # when needed by ._get_dc, and also it's static so it's likely # set already. Avoid invoking non-CDN methods by not syncing updates. - client.connect(_sync_updates=False) + await client.connect(_sync_updates=False) return client # endregion diff --git a/telethon/client/updates.py b/telethon/client/updates.py index f06b28e8..6b283b55 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -289,7 +289,7 @@ class UpdateMethods: len(self._mb_entity_cache), self._entity_cache_limit ) - self._save_states_and_entities() + await self._save_states_and_entities() self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map) if len(self._mb_entity_cache) >= self._entity_cache_limit: warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit') @@ -514,9 +514,11 @@ class UpdateMethods: # inserted because this is a rather expensive operation # (default's sqlite3 takes ~0.1s to commit changes). Do # it every minute instead. No-op if there's nothing new. - self._save_states_and_entities() + await self._save_states_and_entities() - self.session.save() + save = self.session.save() + if inspect.isawaitable(save): + await save async def _dispatch_update(self: 'TelegramClient', update): # TODO only used for AlbumHack, and MessageBox is not really designed for this diff --git a/telethon/client/users.py b/telethon/client/users.py index eda3e040..8f1feae5 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -1,5 +1,6 @@ import asyncio import datetime +import inspect import itertools import time import typing @@ -75,7 +76,9 @@ class UserMethods: exceptions.append(e) results.append(None) continue - self.session.process_entities(result) + process_entities = self.session.process_entities(result) + if inspect.isawaitable(process_entities): + await process_entities exceptions.append(None) results.append(result) request_index += 1 @@ -85,7 +88,9 @@ class UserMethods: return results else: result = await future - self.session.process_entities(result) + process_entities = self.session.process_entities(result) + if inspect.isawaitable(process_entities): + await process_entities return result except (errors.ServerError, errors.RpcCallFailError, errors.RpcMcgetFailError, errors.InterdcCallErrorError, @@ -428,7 +433,10 @@ class UserMethods: # No InputPeer, cached peer, or known string. Fetch from disk cache try: - return self.session.get_input_entity(peer) + input_entity = self.session.get_input_entity(peer) + if inspect.isawaitable(input_entity): + input_entity = await input_entity + return input_entity except ValueError: pass @@ -567,8 +575,10 @@ class UserMethods: pass try: # Nobody with this username, maybe it's an exact name/title - return await self.get_entity( - self.session.get_input_entity(string)) + input_entity = self.session.get_input_entity(string) + if inspect.isawaitable(input_entity): + input_entity = await input_entity + return await self.get_entity(input_entity) except ValueError: pass diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 6c3e30c1..db4518c1 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -302,7 +302,7 @@ class MTProtoSender: # notify whenever we change it. This is crucial when we # switch to different data centers. if self._auth_key_callback: - self._auth_key_callback(self.auth_key) + await self._auth_key_callback(self.auth_key) self._log.debug('auth_key generation success!') return True