From d80898ecc5e78073c581244c867abc715c62bef5 Mon Sep 17 00:00:00 2001 From: Humberto Gontijo <96875138+humbertogontijo@users.noreply.github.com> Date: Mon, 28 Jul 2025 12:03:31 -0300 Subject: [PATCH] Add experimental support for async sessions (#4667) There no plans for this to ever be non-experimental in v1. --- telethon/client/auth.py | 2 +- telethon/client/downloads.py | 9 ++-- telethon/client/telegrambaseclient.py | 69 ++++++++++++++++----------- telethon/client/updates.py | 19 ++++---- telethon/client/users.py | 11 +++-- telethon/network/mtprotosender.py | 2 +- telethon/utils.py | 9 ++++ 7 files changed, 75 insertions(+), 46 deletions(-) diff --git a/telethon/client/auth.py b/telethon/client/auth.py index 4ac1c805..0c9eddb0 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -552,7 +552,7 @@ class AuthMethods: self._authorized = False await self.disconnect() - self.session.delete() + await utils.maybe_async(self.session.delete()) self.session = None return True diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 23040f64..a3e7dd41 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -63,9 +63,12 @@ class _DirectDownloadIter(RequestIter): config = await self.client(functions.help.GetConfigRequest()) for option in config.dc_options: if option.ip_address == self.client.session.server_address: - self.client.session.set_dc( - option.id, option.ip_address, option.port) - self.client.session.save() + await utils.maybe_async( + self.client.session.set_dc( + option.id, option.ip_address, option.port + ) + ) + await utils.maybe_async(self.client.session.save()) break # TODO Figure out why the session may have the wrong DC ID diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 77768b78..370f228d 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -1,4 +1,5 @@ import abc +import inspect import re import asyncio import collections @@ -9,7 +10,7 @@ import typing import datetime import pathlib -from .. import version, helpers, __name__ as __base_name__ +from .. import utils, version, helpers, __name__ as __base_name__ from ..crypto import rsa from ..extensions import markdown from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy @@ -305,16 +306,6 @@ class TelegramBaseClient(abc.ABC): 'The given session must be a str or a Session instance.' ) - # ':' in session.server_address is True if it's an IPv6 address - if (not session.server_address or - (':' in session.server_address) != use_ipv6): - session.set_dc( - DEFAULT_DC_ID, - DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, - DEFAULT_PORT - ) - session.save() - self.flood_sleep_threshold = flood_sleep_threshold # TODO Use AsyncClassWrapper(session) @@ -546,6 +537,18 @@ class TelegramBaseClient(abc.ABC): elif self._loop != helpers.get_running_loop(): raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)') + # ':' in session.server_address is True if it's an IPv6 address + if (not self.session.server_address or + (':' in self.session.server_address) != self._use_ipv6): + await utils.maybe_async( + self.session.set_dc( + DEFAULT_DC_ID, + DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, + DEFAULT_PORT + ) + ) + await utils.maybe_async(self.session.save()) + if not await self._sender.connect(self._connection( self.session.server_address, self.session.port, @@ -558,12 +561,13 @@ class TelegramBaseClient(abc.ABC): return self.session.auth_key = self._sender.auth_key - self.session.save() + await utils.maybe_async(self.session.save()) try: # See comment when saving entities to understand this hack - self_id = self.session.get_input_entity(0).access_hash - self_user = self.session.get_input_entity(self_id) + self_entity = await utils.maybe_async(self.session.get_input_entity(0)) + self_id = self_entity.access_hash + self_user = await utils.maybe_async(self.session.get_input_entity(self_id)) self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash) except ValueError: pass @@ -572,7 +576,8 @@ class TelegramBaseClient(abc.ABC): ss = SessionState(0, 0, False, 0, 0, 0, 0, None) cs = [] - for entity_id, state in self.session.get_update_states(): + update_states = await utils.maybe_async(self.session.get_update_states()) + for entity_id, state in update_states: if entity_id == 0: # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None) @@ -582,7 +587,7 @@ class TelegramBaseClient(abc.ABC): self._message_box.load(ss, cs) for state in cs: try: - entity = self.session.get_input_entity(state.channel_id) + entity = await utils.maybe_async(self.session.get_input_entity(state.channel_id)) except ValueError: self._log[__name__].warning( 'No access_hash in cache for channel %s, will not catch up', state.channel_id) @@ -686,19 +691,27 @@ class TelegramBaseClient(abc.ABC): else: connection._proxy = proxy - def _save_states_and_entities(self: 'TelegramClient'): + async def _save_states_and_entities(self: 'TelegramClient'): # As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``. # This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved. # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # It doesn't matter if we put users in the list of chats. if self._mb_entity_cache.self_id: - self.session.process_entities(types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])) + await utils.maybe_async( + self.session.process_entities( + types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []) + ) + ) ss, cs = self._message_box.session_state() - self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) + await utils.maybe_async(self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))) now = datetime.datetime.now() # any datetime works; channels don't need it for channel_id, pts in cs.items(): - self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)) + await utils.maybe_async( + self.session.set_update_state( + channel_id, types.updates.State(pts, 0, now, 0, unread_count=0) + ) + ) async def _disconnect_coro(self: 'TelegramClient'): if self.session is None: @@ -730,9 +743,9 @@ class TelegramBaseClient(abc.ABC): await asyncio.wait(self._event_handler_tasks) self._event_handler_tasks.clear() - self._save_states_and_entities() + await self._save_states_and_entities() - self.session.close() + await utils.maybe_async(self.session.close()) async def _disconnect(self: 'TelegramClient'): """ @@ -753,22 +766,22 @@ class TelegramBaseClient(abc.ABC): self._log[__name__].info('Reconnecting to new data center %s', new_dc) dc = await self._get_dc(new_dc) - self.session.set_dc(dc.id, dc.ip_address, dc.port) + await utils.maybe_async(self.session.set_dc(dc.id, dc.ip_address, dc.port)) # auth_key's are associated with a server, which has now changed # so it's not valid anymore. Set to None to force recreating it. self._sender.auth_key.key = None self.session.auth_key = None - self.session.save() + await utils.maybe_async(self.session.save()) await self._disconnect() return await self.connect() - def _auth_key_callback(self: 'TelegramClient', auth_key): + async def _auth_key_callback(self: 'TelegramClient', auth_key): """ Callback from the sender whenever it needed to generate a new authorization key. This means we are not authorized. """ self.session.auth_key = auth_key - self.session.save() + await utils.maybe_async(self.session.save()) # endregion @@ -894,8 +907,8 @@ class TelegramBaseClient(abc.ABC): session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) - session = self.session.clone() - session.set_dc(dc.id, dc.ip_address, dc.port) + session = await utils.maybe_async(self.session.clone()) + await utils.maybe_async(session.set_dc(dc.id, dc.ip_address, dc.port)) self._exported_sessions[cdn_redirect.dc_id] = session self._log[__name__].info('Creating new CDN client') diff --git a/telethon/client/updates.py b/telethon/client/updates.py index b5a9aa96..27410ce8 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -289,7 +289,7 @@ class UpdateMethods: len(self._mb_entity_cache), self._entity_cache_limit ) - self._save_states_and_entities() + await self._save_states_and_entities() self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map) if len(self._mb_entity_cache) >= self._entity_cache_limit: warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit') @@ -343,7 +343,8 @@ class UpdateMethods: if updates: self._log[__name__].info('Got difference for account updates') - updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) + _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats)) + updates_to_dispatch.extend(_preprocess_updates) continue get_diff = self._message_box.get_channel_difference(self._mb_entity_cache) @@ -441,7 +442,8 @@ class UpdateMethods: if updates: self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id) - updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) + _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats)) + updates_to_dispatch.extend(_preprocess_updates) continue deadline = self._message_box.check_deadlines() @@ -462,7 +464,8 @@ class UpdateMethods: except GapError: continue # get(_channel)_difference will start returning requests - updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats)) + _preprocess_updates = await utils.maybe_async(self._preprocess_updates(processed, users, chats)) + updates_to_dispatch.extend(_preprocess_updates) except asyncio.CancelledError: pass except Exception as e: @@ -470,9 +473,9 @@ class UpdateMethods: self._updates_error = e await self.disconnect() - def _preprocess_updates(self, updates, users, chats): + async def _preprocess_updates(self, updates, users, chats): self._mb_entity_cache.extend(users, chats) - self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)) + await utils.maybe_async(self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats))) entities = {utils.get_peer_id(x): x for x in itertools.chain(users, chats)} for u in updates: @@ -515,9 +518,9 @@ class UpdateMethods: # inserted because this is a rather expensive operation # (default's sqlite3 takes ~0.1s to commit changes). Do # it every minute instead. No-op if there's nothing new. - self._save_states_and_entities() + await self._save_states_and_entities() - self.session.save() + await utils.maybe_async(self.session.save()) async def _dispatch_update(self: 'TelegramClient', update): # TODO only used for AlbumHack, and MessageBox is not really designed for this diff --git a/telethon/client/users.py b/telethon/client/users.py index acb8f55c..994856bb 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -80,7 +80,7 @@ class UserMethods: exceptions.append(e) results.append(None) continue - self.session.process_entities(result) + await utils.maybe_async(self.session.process_entities(result)) exceptions.append(None) results.append(result) request_index += 1 @@ -90,7 +90,7 @@ class UserMethods: return results else: result = await future - self.session.process_entities(result) + await utils.maybe_async(self.session.process_entities(result)) return result except (errors.ServerError, errors.RpcCallFailError, errors.RpcMcgetFailError, errors.InterdcCallErrorError, @@ -435,7 +435,8 @@ class UserMethods: # No InputPeer, cached peer, or known string. Fetch from disk cache try: - return self.session.get_input_entity(peer) + input_entity = await utils.maybe_async(self.session.get_input_entity(peer)) + return input_entity except ValueError: pass @@ -574,8 +575,8 @@ class UserMethods: pass try: # Nobody with this username, maybe it's an exact name/title - return await self.get_entity( - self.session.get_input_entity(string)) + input_entity = await utils.maybe_async(self.session.get_input_entity(string)) + return await self.get_entity(input_entity) except ValueError: pass diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index d53f9ce8..59cfab6c 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -302,7 +302,7 @@ class MTProtoSender: # notify whenever we change it. This is crucial when we # switch to different data centers. if self._auth_key_callback: - self._auth_key_callback(self.auth_key) + await self._auth_key_callback(self.auth_key) self._log.debug('auth_key generation success!') return True diff --git a/telethon/utils.py b/telethon/utils.py index 234dd71e..3588e044 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -14,6 +14,7 @@ import os import pathlib import re import struct +import warnings from collections import namedtuple from mimetypes import guess_extension from types import GeneratorType @@ -1557,3 +1558,11 @@ def _photo_size_byte_count(size): return max(size.sizes) else: return None + + +async def maybe_async(coro): + result = coro + if inspect.isawaitable(result): + warnings.warn('Using async sessions support is an experimental feature') + result = await result + return result