Add experimental support for async sessions (#4667)

There no plans for this to ever be non-experimental in v1.
This commit is contained in:
Humberto Gontijo 2025-07-28 12:03:31 -03:00 committed by GitHub
parent 45a546a675
commit d80898ecc5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 75 additions and 46 deletions

View File

@ -552,7 +552,7 @@ class AuthMethods:
self._authorized = False self._authorized = False
await self.disconnect() await self.disconnect()
self.session.delete() await utils.maybe_async(self.session.delete())
self.session = None self.session = None
return True return True

View File

@ -63,9 +63,12 @@ class _DirectDownloadIter(RequestIter):
config = await self.client(functions.help.GetConfigRequest()) config = await self.client(functions.help.GetConfigRequest())
for option in config.dc_options: for option in config.dc_options:
if option.ip_address == self.client.session.server_address: if option.ip_address == self.client.session.server_address:
self.client.session.set_dc( await utils.maybe_async(
option.id, option.ip_address, option.port) self.client.session.set_dc(
self.client.session.save() option.id, option.ip_address, option.port
)
)
await utils.maybe_async(self.client.session.save())
break break
# TODO Figure out why the session may have the wrong DC ID # TODO Figure out why the session may have the wrong DC ID

View File

@ -1,4 +1,5 @@
import abc import abc
import inspect
import re import re
import asyncio import asyncio
import collections import collections
@ -9,7 +10,7 @@ import typing
import datetime import datetime
import pathlib import pathlib
from .. import version, helpers, __name__ as __base_name__ from .. import utils, version, helpers, __name__ as __base_name__
from ..crypto import rsa from ..crypto import rsa
from ..extensions import markdown from ..extensions import markdown
from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
@ -305,16 +306,6 @@ class TelegramBaseClient(abc.ABC):
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'
) )
# ':' in session.server_address is True if it's an IPv6 address
if (not session.server_address or
(':' in session.server_address) != use_ipv6):
session.set_dc(
DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT
)
session.save()
self.flood_sleep_threshold = flood_sleep_threshold self.flood_sleep_threshold = flood_sleep_threshold
# TODO Use AsyncClassWrapper(session) # TODO Use AsyncClassWrapper(session)
@ -546,6 +537,18 @@ class TelegramBaseClient(abc.ABC):
elif self._loop != helpers.get_running_loop(): elif self._loop != helpers.get_running_loop():
raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)') raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)')
# ':' in session.server_address is True if it's an IPv6 address
if (not self.session.server_address or
(':' in self.session.server_address) != self._use_ipv6):
await utils.maybe_async(
self.session.set_dc(
DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT
)
)
await utils.maybe_async(self.session.save())
if not await self._sender.connect(self._connection( if not await self._sender.connect(self._connection(
self.session.server_address, self.session.server_address,
self.session.port, self.session.port,
@ -558,12 +561,13 @@ class TelegramBaseClient(abc.ABC):
return return
self.session.auth_key = self._sender.auth_key self.session.auth_key = self._sender.auth_key
self.session.save() await utils.maybe_async(self.session.save())
try: try:
# See comment when saving entities to understand this hack # See comment when saving entities to understand this hack
self_id = self.session.get_input_entity(0).access_hash self_entity = await utils.maybe_async(self.session.get_input_entity(0))
self_user = self.session.get_input_entity(self_id) self_id = self_entity.access_hash
self_user = await utils.maybe_async(self.session.get_input_entity(self_id))
self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash) self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash)
except ValueError: except ValueError:
pass pass
@ -572,7 +576,8 @@ class TelegramBaseClient(abc.ABC):
ss = SessionState(0, 0, False, 0, 0, 0, 0, None) ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
cs = [] cs = []
for entity_id, state in self.session.get_update_states(): update_states = await utils.maybe_async(self.session.get_update_states())
for entity_id, state in update_states:
if entity_id == 0: if entity_id == 0:
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None) ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None)
@ -582,7 +587,7 @@ class TelegramBaseClient(abc.ABC):
self._message_box.load(ss, cs) self._message_box.load(ss, cs)
for state in cs: for state in cs:
try: try:
entity = self.session.get_input_entity(state.channel_id) entity = await utils.maybe_async(self.session.get_input_entity(state.channel_id))
except ValueError: except ValueError:
self._log[__name__].warning( self._log[__name__].warning(
'No access_hash in cache for channel %s, will not catch up', state.channel_id) 'No access_hash in cache for channel %s, will not catch up', state.channel_id)
@ -686,19 +691,27 @@ class TelegramBaseClient(abc.ABC):
else: else:
connection._proxy = proxy connection._proxy = proxy
def _save_states_and_entities(self: 'TelegramClient'): async def _save_states_and_entities(self: 'TelegramClient'):
# As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``. # As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``.
# This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved. # This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved.
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
# It doesn't matter if we put users in the list of chats. # It doesn't matter if we put users in the list of chats.
if self._mb_entity_cache.self_id: if self._mb_entity_cache.self_id:
self.session.process_entities(types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])) await utils.maybe_async(
self.session.process_entities(
types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])
)
)
ss, cs = self._message_box.session_state() ss, cs = self._message_box.session_state()
self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) await utils.maybe_async(self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)))
now = datetime.datetime.now() # any datetime works; channels don't need it now = datetime.datetime.now() # any datetime works; channels don't need it
for channel_id, pts in cs.items(): for channel_id, pts in cs.items():
self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)) await utils.maybe_async(
self.session.set_update_state(
channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)
)
)
async def _disconnect_coro(self: 'TelegramClient'): async def _disconnect_coro(self: 'TelegramClient'):
if self.session is None: if self.session is None:
@ -730,9 +743,9 @@ class TelegramBaseClient(abc.ABC):
await asyncio.wait(self._event_handler_tasks) await asyncio.wait(self._event_handler_tasks)
self._event_handler_tasks.clear() self._event_handler_tasks.clear()
self._save_states_and_entities() await self._save_states_and_entities()
self.session.close() await utils.maybe_async(self.session.close())
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
""" """
@ -753,22 +766,22 @@ class TelegramBaseClient(abc.ABC):
self._log[__name__].info('Reconnecting to new data center %s', new_dc) self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await self._get_dc(new_dc) dc = await self._get_dc(new_dc)
self.session.set_dc(dc.id, dc.ip_address, dc.port) await utils.maybe_async(self.session.set_dc(dc.id, dc.ip_address, dc.port))
# auth_key's are associated with a server, which has now changed # auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None self._sender.auth_key.key = None
self.session.auth_key = None self.session.auth_key = None
self.session.save() await utils.maybe_async(self.session.save())
await self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
def _auth_key_callback(self: 'TelegramClient', auth_key): async def _auth_key_callback(self: 'TelegramClient', auth_key):
""" """
Callback from the sender whenever it needed to generate a Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized. new authorization key. This means we are not authorized.
""" """
self.session.auth_key = auth_key self.session.auth_key = auth_key
self.session.save() await utils.maybe_async(self.session.save())
# endregion # endregion
@ -894,8 +907,8 @@ class TelegramBaseClient(abc.ABC):
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone() session = await utils.maybe_async(self.session.clone())
session.set_dc(dc.id, dc.ip_address, dc.port) await utils.maybe_async(session.set_dc(dc.id, dc.ip_address, dc.port))
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
self._log[__name__].info('Creating new CDN client') self._log[__name__].info('Creating new CDN client')

View File

@ -289,7 +289,7 @@ class UpdateMethods:
len(self._mb_entity_cache), len(self._mb_entity_cache),
self._entity_cache_limit self._entity_cache_limit
) )
self._save_states_and_entities() await self._save_states_and_entities()
self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map) self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map)
if len(self._mb_entity_cache) >= self._entity_cache_limit: if len(self._mb_entity_cache) >= self._entity_cache_limit:
warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit') warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit')
@ -343,7 +343,8 @@ class UpdateMethods:
if updates: if updates:
self._log[__name__].info('Got difference for account updates') self._log[__name__].info('Got difference for account updates')
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats))
updates_to_dispatch.extend(_preprocess_updates)
continue continue
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache) get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
@ -441,7 +442,8 @@ class UpdateMethods:
if updates: if updates:
self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id) self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id)
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats))
updates_to_dispatch.extend(_preprocess_updates)
continue continue
deadline = self._message_box.check_deadlines() deadline = self._message_box.check_deadlines()
@ -462,7 +464,8 @@ class UpdateMethods:
except GapError: except GapError:
continue # get(_channel)_difference will start returning requests continue # get(_channel)_difference will start returning requests
updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats)) _preprocess_updates = await utils.maybe_async(self._preprocess_updates(processed, users, chats))
updates_to_dispatch.extend(_preprocess_updates)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
@ -470,9 +473,9 @@ class UpdateMethods:
self._updates_error = e self._updates_error = e
await self.disconnect() await self.disconnect()
def _preprocess_updates(self, updates, users, chats): async def _preprocess_updates(self, updates, users, chats):
self._mb_entity_cache.extend(users, chats) self._mb_entity_cache.extend(users, chats)
self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)) await utils.maybe_async(self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)))
entities = {utils.get_peer_id(x): x entities = {utils.get_peer_id(x): x
for x in itertools.chain(users, chats)} for x in itertools.chain(users, chats)}
for u in updates: for u in updates:
@ -515,9 +518,9 @@ class UpdateMethods:
# inserted because this is a rather expensive operation # inserted because this is a rather expensive operation
# (default's sqlite3 takes ~0.1s to commit changes). Do # (default's sqlite3 takes ~0.1s to commit changes). Do
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
self._save_states_and_entities() await self._save_states_and_entities()
self.session.save() await utils.maybe_async(self.session.save())
async def _dispatch_update(self: 'TelegramClient', update): async def _dispatch_update(self: 'TelegramClient', update):
# TODO only used for AlbumHack, and MessageBox is not really designed for this # TODO only used for AlbumHack, and MessageBox is not really designed for this

View File

@ -80,7 +80,7 @@ class UserMethods:
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
self.session.process_entities(result) await utils.maybe_async(self.session.process_entities(result))
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
request_index += 1 request_index += 1
@ -90,7 +90,7 @@ class UserMethods:
return results return results
else: else:
result = await future result = await future
self.session.process_entities(result) await utils.maybe_async(self.session.process_entities(result))
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
errors.RpcMcgetFailError, errors.InterdcCallErrorError, errors.RpcMcgetFailError, errors.InterdcCallErrorError,
@ -435,7 +435,8 @@ class UserMethods:
# No InputPeer, cached peer, or known string. Fetch from disk cache # No InputPeer, cached peer, or known string. Fetch from disk cache
try: try:
return self.session.get_input_entity(peer) input_entity = await utils.maybe_async(self.session.get_input_entity(peer))
return input_entity
except ValueError: except ValueError:
pass pass
@ -574,8 +575,8 @@ class UserMethods:
pass pass
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
return await self.get_entity( input_entity = await utils.maybe_async(self.session.get_input_entity(string))
self.session.get_input_entity(string)) return await self.get_entity(input_entity)
except ValueError: except ValueError:
pass pass

View File

@ -302,7 +302,7 @@ class MTProtoSender:
# notify whenever we change it. This is crucial when we # notify whenever we change it. This is crucial when we
# switch to different data centers. # switch to different data centers.
if self._auth_key_callback: if self._auth_key_callback:
self._auth_key_callback(self.auth_key) await self._auth_key_callback(self.auth_key)
self._log.debug('auth_key generation success!') self._log.debug('auth_key generation success!')
return True return True

View File

@ -14,6 +14,7 @@ import os
import pathlib import pathlib
import re import re
import struct import struct
import warnings
from collections import namedtuple from collections import namedtuple
from mimetypes import guess_extension from mimetypes import guess_extension
from types import GeneratorType from types import GeneratorType
@ -1557,3 +1558,11 @@ def _photo_size_byte_count(size):
return max(size.sizes) return max(size.sizes)
else: else:
return None return None
async def maybe_async(coro):
result = coro
if inspect.isawaitable(result):
warnings.warn('Using async sessions support is an experimental feature')
result = await result
return result