Refactor maybe_async

This commit is contained in:
humbertogontijo 2025-07-26 09:31:59 -03:00
parent 4742414247
commit 13913099db
6 changed files with 51 additions and 92 deletions

View File

@ -552,9 +552,7 @@ class AuthMethods:
self._authorized = False self._authorized = False
await self.disconnect() await self.disconnect()
delete = self.session.delete() await utils.maybe_async(self.session.delete())
if inspect.isawaitable(delete):
await delete
self.session = None self.session = None
return True return True

View File

@ -63,12 +63,12 @@ class _DirectDownloadIter(RequestIter):
config = await self.client(functions.help.GetConfigRequest()) config = await self.client(functions.help.GetConfigRequest())
for option in config.dc_options: for option in config.dc_options:
if option.ip_address == self.client.session.server_address: if option.ip_address == self.client.session.server_address:
set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port) await utils.maybe_async(
if inspect.isawaitable(set_dc): self.client.session.set_dc(
await set_dc option.id, option.ip_address, option.port
save = self.client.session.save() )
if inspect.isawaitable(save): )
await save await utils.maybe_async(self.client.session.save())
break break
# TODO Figure out why the session may have the wrong DC ID # TODO Figure out why the session may have the wrong DC ID

View File

@ -1,5 +1,4 @@
import abc import abc
import inspect
import re import re
import asyncio import asyncio
import collections import collections
@ -10,7 +9,7 @@ import typing
import datetime import datetime
import pathlib import pathlib
from .. import version, helpers, __name__ as __base_name__ from .. import utils, version, helpers, __name__ as __base_name__
from ..crypto import rsa from ..crypto import rsa
from ..extensions import markdown from ..extensions import markdown
from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
@ -540,16 +539,14 @@ class TelegramBaseClient(abc.ABC):
# ':' in session.server_address is True if it's an IPv6 address # ':' in session.server_address is True if it's an IPv6 address
if (not self.session.server_address or if (not self.session.server_address or
(':' in self.session.server_address) != self._use_ipv6): (':' in self.session.server_address) != self._use_ipv6):
set_dc = self.session.set_dc( await utils.maybe_async(
self.session.set_dc(
DEFAULT_DC_ID, DEFAULT_DC_ID,
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
DEFAULT_PORT DEFAULT_PORT
) )
if inspect.isawaitable(set_dc): )
await set_dc await utils.maybe_async(self.session.save())
save = self.session.save()
if inspect.isawaitable(save):
await save
if not await self._sender.connect(self._connection( if not await self._sender.connect(self._connection(
self.session.server_address, self.session.server_address,
@ -563,19 +560,13 @@ class TelegramBaseClient(abc.ABC):
return return
self.session.auth_key = self._sender.auth_key self.session.auth_key = self._sender.auth_key
save = self.session.save() await utils.maybe_async(self.session.save())
if inspect.isawaitable(save):
await save
try: try:
# See comment when saving entities to understand this hack # See comment when saving entities to understand this hack
self_entity = self.session.get_input_entity(0) self_entity = await utils.maybe_async(self.session.get_input_entity(0))
if inspect.isawaitable(self_entity):
self_entity = await self_entity
self_id = self_entity.access_hash self_id = self_entity.access_hash
self_user = self.session.get_input_entity(self_id) self_user = await utils.maybe_async(self.session.get_input_entity(self_id))
if inspect.isawaitable(self_user):
self_user = await self_user
self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash) self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash)
except ValueError: except ValueError:
pass pass
@ -584,9 +575,7 @@ class TelegramBaseClient(abc.ABC):
ss = SessionState(0, 0, False, 0, 0, 0, 0, None) ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
cs = [] cs = []
update_states = self.session.get_update_states() update_states = await utils.maybe_async(self.session.get_update_states())
if inspect.isawaitable(update_states):
update_states = await update_states
for entity_id, state in update_states: for entity_id, state in update_states:
if entity_id == 0: if entity_id == 0:
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
@ -597,9 +586,7 @@ class TelegramBaseClient(abc.ABC):
self._message_box.load(ss, cs) self._message_box.load(ss, cs)
for state in cs: for state in cs:
try: try:
entity = self.session.get_input_entity(state.channel_id) entity = await utils.maybe_async(self.session.get_input_entity(state.channel_id))
if inspect.isawaitable(entity):
entity = await entity
except ValueError: except ValueError:
self._log[__name__].warning( self._log[__name__].warning(
'No access_hash in cache for channel %s, will not catch up', state.channel_id) 'No access_hash in cache for channel %s, will not catch up', state.channel_id)
@ -709,23 +696,21 @@ class TelegramBaseClient(abc.ABC):
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
# It doesn't matter if we put users in the list of chats. # It doesn't matter if we put users in the list of chats.
if self._mb_entity_cache.self_id: if self._mb_entity_cache.self_id:
process_entities = self.session.process_entities( await utils.maybe_async(
self.session.process_entities(
types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []) types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])
) )
if inspect.isawaitable(process_entities): )
await process_entities
ss, cs = self._message_box.session_state() ss, cs = self._message_box.session_state()
update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) await utils.maybe_async(self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)))
if inspect.isawaitable(update_state):
await update_state
now = datetime.datetime.now() # any datetime works; channels don't need it now = datetime.datetime.now() # any datetime works; channels don't need it
for channel_id, pts in cs.items(): for channel_id, pts in cs.items():
update_state = self.session.set_update_state( await utils.maybe_async(
self.session.set_update_state(
channel_id, types.updates.State(pts, 0, now, 0, unread_count=0) channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)
) )
if inspect.isawaitable(update_state): )
await update_state
async def _disconnect_coro(self: 'TelegramClient'): async def _disconnect_coro(self: 'TelegramClient'):
if self.session is None: if self.session is None:
@ -759,9 +744,7 @@ class TelegramBaseClient(abc.ABC):
await self._save_states_and_entities() await self._save_states_and_entities()
close = self.session.close() await utils.maybe_async(self.session.close())
if inspect.isawaitable(close):
await close
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
""" """
@ -782,16 +765,12 @@ class TelegramBaseClient(abc.ABC):
self._log[__name__].info('Reconnecting to new data center %s', new_dc) self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await self._get_dc(new_dc) dc = await self._get_dc(new_dc)
set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port) await utils.maybe_async(self.session.set_dc(dc.id, dc.ip_address, dc.port))
if inspect.isawaitable(set_dc):
await set_dc
# auth_key's are associated with a server, which has now changed # auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None self._sender.auth_key.key = None
self.session.auth_key = None self.session.auth_key = None
save = self.session.save() await utils.maybe_async(self.session.save())
if inspect.isawaitable(save):
await save
await self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
@ -801,9 +780,7 @@ class TelegramBaseClient(abc.ABC):
new authorization key. This means we are not authorized. new authorization key. This means we are not authorized.
""" """
self.session.auth_key = auth_key self.session.auth_key = auth_key
save = self.session.save() await utils.maybe_async(self.session.save())
if inspect.isawaitable(save):
await save
# endregion # endregion
@ -929,12 +906,8 @@ class TelegramBaseClient(abc.ABC):
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone() session = await utils.maybe_async(self.session.clone())
if inspect.isawaitable(session): await utils.maybe_async(session.set_dc(dc.id, dc.ip_address, dc.port))
session = await session
set_dc = session.set_dc(dc.id, dc.ip_address, dc.port)
if inspect.isawaitable(set_dc):
await set_dc
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
self._log[__name__].info('Creating new CDN client') self._log[__name__].info('Creating new CDN client')

View File

@ -343,9 +343,7 @@ class UpdateMethods:
if updates: if updates:
self._log[__name__].info('Got difference for account updates') self._log[__name__].info('Got difference for account updates')
_preprocess_updates = self._preprocess_updates(updates, users, chats) _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats))
if inspect.isawaitable(_preprocess_updates):
_preprocess_updates = await _preprocess_updates
updates_to_dispatch.extend(_preprocess_updates) updates_to_dispatch.extend(_preprocess_updates)
continue continue
@ -444,9 +442,7 @@ class UpdateMethods:
if updates: if updates:
self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id) self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id)
_preprocess_updates = self._preprocess_updates(updates, users, chats) _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats))
if inspect.isawaitable(_preprocess_updates):
_preprocess_updates = await _preprocess_updates
updates_to_dispatch.extend(_preprocess_updates) updates_to_dispatch.extend(_preprocess_updates)
continue continue
@ -468,9 +464,7 @@ class UpdateMethods:
except GapError: except GapError:
continue # get(_channel)_difference will start returning requests continue # get(_channel)_difference will start returning requests
_preprocess_updates = self._preprocess_updates(processed, users, chats) _preprocess_updates = await utils.maybe_async(self._preprocess_updates(processed, users, chats))
if inspect.isawaitable(_preprocess_updates):
_preprocess_updates = await _preprocess_updates
updates_to_dispatch.extend(_preprocess_updates) updates_to_dispatch.extend(_preprocess_updates)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
@ -481,9 +475,7 @@ class UpdateMethods:
async def _preprocess_updates(self, updates, users, chats): async def _preprocess_updates(self, updates, users, chats):
self._mb_entity_cache.extend(users, chats) self._mb_entity_cache.extend(users, chats)
process_entities = self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)) await utils.maybe_async(self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)))
if inspect.isawaitable(process_entities):
await process_entities
entities = {utils.get_peer_id(x): x entities = {utils.get_peer_id(x): x
for x in itertools.chain(users, chats)} for x in itertools.chain(users, chats)}
for u in updates: for u in updates:
@ -528,9 +520,7 @@ class UpdateMethods:
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
await self._save_states_and_entities() await self._save_states_and_entities()
save = self.session.save() await utils.maybe_async(self.session.save())
if inspect.isawaitable(save):
await save
async def _dispatch_update(self: 'TelegramClient', update): async def _dispatch_update(self: 'TelegramClient', update):
# TODO only used for AlbumHack, and MessageBox is not really designed for this # TODO only used for AlbumHack, and MessageBox is not really designed for this

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import datetime import datetime
import inspect
import itertools import itertools
import time import time
import typing import typing
@ -81,9 +80,7 @@ class UserMethods:
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
process_entities = self.session.process_entities(result) await utils.maybe_async(self.session.process_entities(result))
if inspect.isawaitable(process_entities):
await process_entities
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
request_index += 1 request_index += 1
@ -93,9 +90,7 @@ class UserMethods:
return results return results
else: else:
result = await future result = await future
process_entities = self.session.process_entities(result) await utils.maybe_async(self.session.process_entities(result))
if inspect.isawaitable(process_entities):
await process_entities
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
errors.RpcMcgetFailError, errors.InterdcCallErrorError, errors.RpcMcgetFailError, errors.InterdcCallErrorError,
@ -440,9 +435,7 @@ class UserMethods:
# No InputPeer, cached peer, or known string. Fetch from disk cache # No InputPeer, cached peer, or known string. Fetch from disk cache
try: try:
input_entity = self.session.get_input_entity(peer) input_entity = await utils.maybe_async(self.session.get_input_entity(peer))
if inspect.isawaitable(input_entity):
input_entity = await input_entity
return input_entity return input_entity
except ValueError: except ValueError:
pass pass
@ -582,9 +575,7 @@ class UserMethods:
pass pass
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
input_entity = self.session.get_input_entity(string) input_entity = await utils.maybe_async(self.session.get_input_entity(string))
if inspect.isawaitable(input_entity):
input_entity = await input_entity
return await self.get_entity(input_entity) return await self.get_entity(input_entity)
except ValueError: except ValueError:
pass pass

View File

@ -1557,3 +1557,10 @@ def _photo_size_byte_count(size):
return max(size.sizes) return max(size.sizes)
else: else:
return None return None
async def maybe_async(coro):
result = coro
if inspect.isawaitable(result):
result = await result
return result