Add async sessions while keeping retro compatibility

This commit is contained in:
humbertogontijo 2025-07-25 13:45:34 -03:00
parent 01af2fcca3
commit 4742414247
6 changed files with 109 additions and 42 deletions

View File

@ -552,7 +552,9 @@ class AuthMethods:
self._authorized = False self._authorized = False
await self.disconnect() await self.disconnect()
self.session.delete() delete = self.session.delete()
if inspect.isawaitable(delete):
await delete
self.session = None self.session = None
return True return True

View File

@ -63,9 +63,12 @@ class _DirectDownloadIter(RequestIter):
config = await self.client(functions.help.GetConfigRequest()) config = await self.client(functions.help.GetConfigRequest())
for option in config.dc_options: for option in config.dc_options:
if option.ip_address == self.client.session.server_address: if option.ip_address == self.client.session.server_address:
self.client.session.set_dc( set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port)
option.id, option.ip_address, option.port) if inspect.isawaitable(set_dc):
self.client.session.save() await set_dc
save = self.client.session.save()
if inspect.isawaitable(save):
await save
break break
# TODO Figure out why the session may have the wrong DC ID # TODO Figure out why the session may have the wrong DC ID

View File

@ -1,4 +1,5 @@
import abc import abc
import inspect
import re import re
import asyncio import asyncio
import collections import collections
@ -305,16 +306,6 @@ class TelegramBaseClient(abc.ABC):
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'
) )
# ':' in session.server_address is True if it's an IPv6 address
if (not session.server_address or
(':' in session.server_address) != use_ipv6):
session.set_dc(
DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT
)
session.save()
self.flood_sleep_threshold = flood_sleep_threshold self.flood_sleep_threshold = flood_sleep_threshold
# TODO Use AsyncClassWrapper(session) # TODO Use AsyncClassWrapper(session)
@ -546,6 +537,20 @@ class TelegramBaseClient(abc.ABC):
elif self._loop != helpers.get_running_loop(): elif self._loop != helpers.get_running_loop():
raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)') raise RuntimeError('The asyncio event loop must not change after connection (see the FAQ for details)')
# ':' in session.server_address is True if it's an IPv6 address
if (not self.session.server_address or
(':' in self.session.server_address) != self._use_ipv6):
set_dc = self.session.set_dc(
DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT
)
if inspect.isawaitable(set_dc):
await set_dc
save = self.session.save()
if inspect.isawaitable(save):
await save
if not await self._sender.connect(self._connection( if not await self._sender.connect(self._connection(
self.session.server_address, self.session.server_address,
self.session.port, self.session.port,
@ -558,12 +563,19 @@ class TelegramBaseClient(abc.ABC):
return return
self.session.auth_key = self._sender.auth_key self.session.auth_key = self._sender.auth_key
self.session.save() save = self.session.save()
if inspect.isawaitable(save):
await save
try: try:
# See comment when saving entities to understand this hack # See comment when saving entities to understand this hack
self_id = self.session.get_input_entity(0).access_hash self_entity = self.session.get_input_entity(0)
if inspect.isawaitable(self_entity):
self_entity = await self_entity
self_id = self_entity.access_hash
self_user = self.session.get_input_entity(self_id) self_user = self.session.get_input_entity(self_id)
if inspect.isawaitable(self_user):
self_user = await self_user
self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash) self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash)
except ValueError: except ValueError:
pass pass
@ -572,7 +584,10 @@ class TelegramBaseClient(abc.ABC):
ss = SessionState(0, 0, False, 0, 0, 0, 0, None) ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
cs = [] cs = []
for entity_id, state in self.session.get_update_states(): update_states = self.session.get_update_states()
if inspect.isawaitable(update_states):
update_states = await update_states
for entity_id, state in update_states:
if entity_id == 0: if entity_id == 0:
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None) ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None)
@ -583,6 +598,8 @@ class TelegramBaseClient(abc.ABC):
for state in cs: for state in cs:
try: try:
entity = self.session.get_input_entity(state.channel_id) entity = self.session.get_input_entity(state.channel_id)
if inspect.isawaitable(entity):
entity = await entity
except ValueError: except ValueError:
self._log[__name__].warning( self._log[__name__].warning(
'No access_hash in cache for channel %s, will not catch up', state.channel_id) 'No access_hash in cache for channel %s, will not catch up', state.channel_id)
@ -686,19 +703,29 @@ class TelegramBaseClient(abc.ABC):
else: else:
connection._proxy = proxy connection._proxy = proxy
def _save_states_and_entities(self: 'TelegramClient'): async def _save_states_and_entities(self: 'TelegramClient'):
# As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``. # As a hack to not need to change the session files, save ourselves with ``id=0`` and ``access_hash`` of our ``id``.
# This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved. # This way it is possible to determine our own ID by querying for 0. However, whether we're a bot is not saved.
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
# It doesn't matter if we put users in the list of chats. # It doesn't matter if we put users in the list of chats.
if self._mb_entity_cache.self_id: if self._mb_entity_cache.self_id:
self.session.process_entities(types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])) process_entities = self.session.process_entities(
types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])
)
if inspect.isawaitable(process_entities):
await process_entities
ss, cs = self._message_box.session_state() ss, cs = self._message_box.session_state()
self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))
if inspect.isawaitable(update_state):
await update_state
now = datetime.datetime.now() # any datetime works; channels don't need it now = datetime.datetime.now() # any datetime works; channels don't need it
for channel_id, pts in cs.items(): for channel_id, pts in cs.items():
self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)) update_state = self.session.set_update_state(
channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)
)
if inspect.isawaitable(update_state):
await update_state
async def _disconnect_coro(self: 'TelegramClient'): async def _disconnect_coro(self: 'TelegramClient'):
if self.session is None: if self.session is None:
@ -730,9 +757,11 @@ class TelegramBaseClient(abc.ABC):
await asyncio.wait(self._event_handler_tasks) await asyncio.wait(self._event_handler_tasks)
self._event_handler_tasks.clear() self._event_handler_tasks.clear()
self._save_states_and_entities() await self._save_states_and_entities()
self.session.close() close = self.session.close()
if inspect.isawaitable(close):
await close
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
""" """
@ -753,22 +782,28 @@ class TelegramBaseClient(abc.ABC):
self._log[__name__].info('Reconnecting to new data center %s', new_dc) self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await self._get_dc(new_dc) dc = await self._get_dc(new_dc)
self.session.set_dc(dc.id, dc.ip_address, dc.port) set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port)
if inspect.isawaitable(set_dc):
await set_dc
# auth_key's are associated with a server, which has now changed # auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None self._sender.auth_key.key = None
self.session.auth_key = None self.session.auth_key = None
self.session.save() save = self.session.save()
if inspect.isawaitable(save):
await save
await self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
def _auth_key_callback(self: 'TelegramClient', auth_key): async def _auth_key_callback(self: 'TelegramClient', auth_key):
""" """
Callback from the sender whenever it needed to generate a Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized. new authorization key. This means we are not authorized.
""" """
self.session.auth_key = auth_key self.session.auth_key = auth_key
self.session.save() save = self.session.save()
if inspect.isawaitable(save):
await save
# endregion # endregion
@ -895,7 +930,11 @@ class TelegramBaseClient(abc.ABC):
if not session: if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone() session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port) if inspect.isawaitable(session):
session = await session
set_dc = session.set_dc(dc.id, dc.ip_address, dc.port)
if inspect.isawaitable(set_dc):
await set_dc
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
self._log[__name__].info('Creating new CDN client') self._log[__name__].info('Creating new CDN client')

View File

@ -289,7 +289,7 @@ class UpdateMethods:
len(self._mb_entity_cache), len(self._mb_entity_cache),
self._entity_cache_limit self._entity_cache_limit
) )
self._save_states_and_entities() await self._save_states_and_entities()
self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map) self._mb_entity_cache.retain(lambda id: id == self._mb_entity_cache.self_id or id in self._message_box.map)
if len(self._mb_entity_cache) >= self._entity_cache_limit: if len(self._mb_entity_cache) >= self._entity_cache_limit:
warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit') warnings.warn('in-memory entities exceed entity_cache_limit after flushing; consider setting a larger limit')
@ -343,7 +343,10 @@ class UpdateMethods:
if updates: if updates:
self._log[__name__].info('Got difference for account updates') self._log[__name__].info('Got difference for account updates')
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) _preprocess_updates = self._preprocess_updates(updates, users, chats)
if inspect.isawaitable(_preprocess_updates):
_preprocess_updates = await _preprocess_updates
updates_to_dispatch.extend(_preprocess_updates)
continue continue
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache) get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
@ -441,7 +444,10 @@ class UpdateMethods:
if updates: if updates:
self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id) self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id)
updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) _preprocess_updates = self._preprocess_updates(updates, users, chats)
if inspect.isawaitable(_preprocess_updates):
_preprocess_updates = await _preprocess_updates
updates_to_dispatch.extend(_preprocess_updates)
continue continue
deadline = self._message_box.check_deadlines() deadline = self._message_box.check_deadlines()
@ -462,7 +468,10 @@ class UpdateMethods:
except GapError: except GapError:
continue # get(_channel)_difference will start returning requests continue # get(_channel)_difference will start returning requests
updates_to_dispatch.extend(self._preprocess_updates(processed, users, chats)) _preprocess_updates = self._preprocess_updates(processed, users, chats)
if inspect.isawaitable(_preprocess_updates):
_preprocess_updates = await _preprocess_updates
updates_to_dispatch.extend(_preprocess_updates)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
@ -470,9 +479,11 @@ class UpdateMethods:
self._updates_error = e self._updates_error = e
await self.disconnect() await self.disconnect()
def _preprocess_updates(self, updates, users, chats): async def _preprocess_updates(self, updates, users, chats):
self._mb_entity_cache.extend(users, chats) self._mb_entity_cache.extend(users, chats)
self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)) process_entities = self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats))
if inspect.isawaitable(process_entities):
await process_entities
entities = {utils.get_peer_id(x): x entities = {utils.get_peer_id(x): x
for x in itertools.chain(users, chats)} for x in itertools.chain(users, chats)}
for u in updates: for u in updates:
@ -515,9 +526,11 @@ class UpdateMethods:
# inserted because this is a rather expensive operation # inserted because this is a rather expensive operation
# (default's sqlite3 takes ~0.1s to commit changes). Do # (default's sqlite3 takes ~0.1s to commit changes). Do
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
self._save_states_and_entities() await self._save_states_and_entities()
self.session.save() save = self.session.save()
if inspect.isawaitable(save):
await save
async def _dispatch_update(self: 'TelegramClient', update): async def _dispatch_update(self: 'TelegramClient', update):
# TODO only used for AlbumHack, and MessageBox is not really designed for this # TODO only used for AlbumHack, and MessageBox is not really designed for this

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import datetime import datetime
import inspect
import itertools import itertools
import time import time
import typing import typing
@ -80,7 +81,9 @@ class UserMethods:
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
self.session.process_entities(result) process_entities = self.session.process_entities(result)
if inspect.isawaitable(process_entities):
await process_entities
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
request_index += 1 request_index += 1
@ -90,7 +93,9 @@ class UserMethods:
return results return results
else: else:
result = await future result = await future
self.session.process_entities(result) process_entities = self.session.process_entities(result)
if inspect.isawaitable(process_entities):
await process_entities
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
errors.RpcMcgetFailError, errors.InterdcCallErrorError, errors.RpcMcgetFailError, errors.InterdcCallErrorError,
@ -435,7 +440,10 @@ class UserMethods:
# No InputPeer, cached peer, or known string. Fetch from disk cache # No InputPeer, cached peer, or known string. Fetch from disk cache
try: try:
return self.session.get_input_entity(peer) input_entity = self.session.get_input_entity(peer)
if inspect.isawaitable(input_entity):
input_entity = await input_entity
return input_entity
except ValueError: except ValueError:
pass pass
@ -574,8 +582,10 @@ class UserMethods:
pass pass
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
return await self.get_entity( input_entity = self.session.get_input_entity(string)
self.session.get_input_entity(string)) if inspect.isawaitable(input_entity):
input_entity = await input_entity
return await self.get_entity(input_entity)
except ValueError: except ValueError:
pass pass

View File

@ -302,7 +302,7 @@ class MTProtoSender:
# notify whenever we change it. This is crucial when we # notify whenever we change it. This is crucial when we
# switch to different data centers. # switch to different data centers.
if self._auth_key_callback: if self._auth_key_callback:
self._auth_key_callback(self.auth_key) await self._auth_key_callback(self.auth_key)
self._log.debug('auth_key generation success!') self._log.debug('auth_key generation success!')
return True return True