From 13913099db44e0f5e6507487549de904314c6661 Mon Sep 17 00:00:00 2001 From: humbertogontijo <96875138+humbertogontijo@users.noreply.github.com> Date: Sat, 26 Jul 2025 09:31:59 -0300 Subject: [PATCH] Refactor maybe_async --- telethon/client/auth.py | 4 +- telethon/client/downloads.py | 12 ++-- telethon/client/telegrambaseclient.py | 83 +++++++++------------------ telethon/client/updates.py | 20 ++----- telethon/client/users.py | 17 ++---- telethon/utils.py | 7 +++ 6 files changed, 51 insertions(+), 92 deletions(-) diff --git a/telethon/client/auth.py b/telethon/client/auth.py index 8428a192..0c9eddb0 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -552,9 +552,7 @@ class AuthMethods: self._authorized = False await self.disconnect() - delete = self.session.delete() - if inspect.isawaitable(delete): - await delete + await utils.maybe_async(self.session.delete()) self.session = None return True diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 809e5ca3..a3e7dd41 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -63,12 +63,12 @@ class _DirectDownloadIter(RequestIter): config = await self.client(functions.help.GetConfigRequest()) for option in config.dc_options: if option.ip_address == self.client.session.server_address: - set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port) - if inspect.isawaitable(set_dc): - await set_dc - save = self.client.session.save() - if inspect.isawaitable(save): - await save + await utils.maybe_async( + self.client.session.set_dc( + option.id, option.ip_address, option.port + ) + ) + await utils.maybe_async(self.client.session.save()) break # TODO Figure out why the session may have the wrong DC ID diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 14d40b7c..3c2ec9df 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -1,5 +1,4 @@ import abc -import inspect import re import asyncio import collections @@ -10,7 +9,7 @@ import typing import datetime import pathlib -from .. import version, helpers, __name__ as __base_name__ +from .. import utils, version, helpers, __name__ as __base_name__ from ..crypto import rsa from ..extensions import markdown from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy @@ -540,16 +539,14 @@ class TelegramBaseClient(abc.ABC): # ':' in session.server_address is True if it's an IPv6 address if (not self.session.server_address or (':' in self.session.server_address) != self._use_ipv6): - set_dc = self.session.set_dc( - DEFAULT_DC_ID, - DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, - DEFAULT_PORT + await utils.maybe_async( + self.session.set_dc( + DEFAULT_DC_ID, + DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP, + DEFAULT_PORT + ) ) - if inspect.isawaitable(set_dc): - await set_dc - save = self.session.save() - if inspect.isawaitable(save): - await save + await utils.maybe_async(self.session.save()) if not await self._sender.connect(self._connection( self.session.server_address, @@ -563,19 +560,13 @@ class TelegramBaseClient(abc.ABC): return self.session.auth_key = self._sender.auth_key - save = self.session.save() - if inspect.isawaitable(save): - await save + await utils.maybe_async(self.session.save()) try: # See comment when saving entities to understand this hack - self_entity = self.session.get_input_entity(0) - if inspect.isawaitable(self_entity): - self_entity = await self_entity + self_entity = await utils.maybe_async(self.session.get_input_entity(0)) self_id = self_entity.access_hash - self_user = self.session.get_input_entity(self_id) - if inspect.isawaitable(self_user): - self_user = await self_user + self_user = await utils.maybe_async(self.session.get_input_entity(self_id)) self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash) except ValueError: pass @@ -584,9 +575,7 @@ class TelegramBaseClient(abc.ABC): ss = SessionState(0, 0, False, 0, 0, 0, 0, None) cs = [] - update_states = self.session.get_update_states() - if inspect.isawaitable(update_states): - update_states = await update_states + update_states = await utils.maybe_async(self.session.get_update_states()) for entity_id, state in update_states: if entity_id == 0: # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls @@ -597,9 +586,7 @@ class TelegramBaseClient(abc.ABC): self._message_box.load(ss, cs) for state in cs: try: - entity = self.session.get_input_entity(state.channel_id) - if inspect.isawaitable(entity): - entity = await entity + entity = await utils.maybe_async(self.session.get_input_entity(state.channel_id)) except ValueError: self._log[__name__].warning( 'No access_hash in cache for channel %s, will not catch up', state.channel_id) @@ -709,23 +696,21 @@ class TelegramBaseClient(abc.ABC): # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # It doesn't matter if we put users in the list of chats. if self._mb_entity_cache.self_id: - process_entities = self.session.process_entities( - types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []) + await utils.maybe_async( + self.session.process_entities( + types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], []) + ) ) - if inspect.isawaitable(process_entities): - await process_entities ss, cs = self._message_box.session_state() - update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) - if inspect.isawaitable(update_state): - await update_state + await utils.maybe_async(self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))) now = datetime.datetime.now() # any datetime works; channels don't need it for channel_id, pts in cs.items(): - update_state = self.session.set_update_state( - channel_id, types.updates.State(pts, 0, now, 0, unread_count=0) + await utils.maybe_async( + self.session.set_update_state( + channel_id, types.updates.State(pts, 0, now, 0, unread_count=0) + ) ) - if inspect.isawaitable(update_state): - await update_state async def _disconnect_coro(self: 'TelegramClient'): if self.session is None: @@ -759,9 +744,7 @@ class TelegramBaseClient(abc.ABC): await self._save_states_and_entities() - close = self.session.close() - if inspect.isawaitable(close): - await close + await utils.maybe_async(self.session.close()) async def _disconnect(self: 'TelegramClient'): """ @@ -782,16 +765,12 @@ class TelegramBaseClient(abc.ABC): self._log[__name__].info('Reconnecting to new data center %s', new_dc) dc = await self._get_dc(new_dc) - set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port) - if inspect.isawaitable(set_dc): - await set_dc + await utils.maybe_async(self.session.set_dc(dc.id, dc.ip_address, dc.port)) # auth_key's are associated with a server, which has now changed # so it's not valid anymore. Set to None to force recreating it. self._sender.auth_key.key = None self.session.auth_key = None - save = self.session.save() - if inspect.isawaitable(save): - await save + await utils.maybe_async(self.session.save()) await self._disconnect() return await self.connect() @@ -801,9 +780,7 @@ class TelegramBaseClient(abc.ABC): new authorization key. This means we are not authorized. """ self.session.auth_key = auth_key - save = self.session.save() - if inspect.isawaitable(save): - await save + await utils.maybe_async(self.session.save()) # endregion @@ -929,12 +906,8 @@ class TelegramBaseClient(abc.ABC): session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) - session = self.session.clone() - if inspect.isawaitable(session): - session = await session - set_dc = session.set_dc(dc.id, dc.ip_address, dc.port) - if inspect.isawaitable(set_dc): - await set_dc + session = await utils.maybe_async(self.session.clone()) + await utils.maybe_async(session.set_dc(dc.id, dc.ip_address, dc.port)) self._exported_sessions[cdn_redirect.dc_id] = session self._log[__name__].info('Creating new CDN client') diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 9aa57e14..27410ce8 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -343,9 +343,7 @@ class UpdateMethods: if updates: self._log[__name__].info('Got difference for account updates') - _preprocess_updates = self._preprocess_updates(updates, users, chats) - if inspect.isawaitable(_preprocess_updates): - _preprocess_updates = await _preprocess_updates + _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats)) updates_to_dispatch.extend(_preprocess_updates) continue @@ -444,9 +442,7 @@ class UpdateMethods: if updates: self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id) - _preprocess_updates = self._preprocess_updates(updates, users, chats) - if inspect.isawaitable(_preprocess_updates): - _preprocess_updates = await _preprocess_updates + _preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats)) updates_to_dispatch.extend(_preprocess_updates) continue @@ -468,9 +464,7 @@ class UpdateMethods: except GapError: continue # get(_channel)_difference will start returning requests - _preprocess_updates = self._preprocess_updates(processed, users, chats) - if inspect.isawaitable(_preprocess_updates): - _preprocess_updates = await _preprocess_updates + _preprocess_updates = await utils.maybe_async(self._preprocess_updates(processed, users, chats)) updates_to_dispatch.extend(_preprocess_updates) except asyncio.CancelledError: pass @@ -481,9 +475,7 @@ class UpdateMethods: async def _preprocess_updates(self, updates, users, chats): self._mb_entity_cache.extend(users, chats) - process_entities = self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)) - if inspect.isawaitable(process_entities): - await process_entities + await utils.maybe_async(self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats))) entities = {utils.get_peer_id(x): x for x in itertools.chain(users, chats)} for u in updates: @@ -528,9 +520,7 @@ class UpdateMethods: # it every minute instead. No-op if there's nothing new. await self._save_states_and_entities() - save = self.session.save() - if inspect.isawaitable(save): - await save + await utils.maybe_async(self.session.save()) async def _dispatch_update(self: 'TelegramClient', update): # TODO only used for AlbumHack, and MessageBox is not really designed for this diff --git a/telethon/client/users.py b/telethon/client/users.py index 88b344c6..994856bb 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -1,6 +1,5 @@ import asyncio import datetime -import inspect import itertools import time import typing @@ -81,9 +80,7 @@ class UserMethods: exceptions.append(e) results.append(None) continue - process_entities = self.session.process_entities(result) - if inspect.isawaitable(process_entities): - await process_entities + await utils.maybe_async(self.session.process_entities(result)) exceptions.append(None) results.append(result) request_index += 1 @@ -93,9 +90,7 @@ class UserMethods: return results else: result = await future - process_entities = self.session.process_entities(result) - if inspect.isawaitable(process_entities): - await process_entities + await utils.maybe_async(self.session.process_entities(result)) return result except (errors.ServerError, errors.RpcCallFailError, errors.RpcMcgetFailError, errors.InterdcCallErrorError, @@ -440,9 +435,7 @@ class UserMethods: # No InputPeer, cached peer, or known string. Fetch from disk cache try: - input_entity = self.session.get_input_entity(peer) - if inspect.isawaitable(input_entity): - input_entity = await input_entity + input_entity = await utils.maybe_async(self.session.get_input_entity(peer)) return input_entity except ValueError: pass @@ -582,9 +575,7 @@ class UserMethods: pass try: # Nobody with this username, maybe it's an exact name/title - input_entity = self.session.get_input_entity(string) - if inspect.isawaitable(input_entity): - input_entity = await input_entity + input_entity = await utils.maybe_async(self.session.get_input_entity(string)) return await self.get_entity(input_entity) except ValueError: pass diff --git a/telethon/utils.py b/telethon/utils.py index c0c48d4d..d4553b6d 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -1557,3 +1557,10 @@ def _photo_size_byte_count(size): return max(size.sizes) else: return None + + +async def maybe_async(coro): + result = coro + if inspect.isawaitable(result): + result = await result + return result