mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-04 04:00:18 +03:00
Refactor maybe_async
This commit is contained in:
parent
4742414247
commit
13913099db
|
@ -552,9 +552,7 @@ class AuthMethods:
|
|||
self._authorized = False
|
||||
|
||||
await self.disconnect()
|
||||
delete = self.session.delete()
|
||||
if inspect.isawaitable(delete):
|
||||
await delete
|
||||
await utils.maybe_async(self.session.delete())
|
||||
self.session = None
|
||||
return True
|
||||
|
||||
|
|
|
@ -63,12 +63,12 @@ class _DirectDownloadIter(RequestIter):
|
|||
config = await self.client(functions.help.GetConfigRequest())
|
||||
for option in config.dc_options:
|
||||
if option.ip_address == self.client.session.server_address:
|
||||
set_dc = self.client.session.set_dc(option.id, option.ip_address, option.port)
|
||||
if inspect.isawaitable(set_dc):
|
||||
await set_dc
|
||||
save = self.client.session.save()
|
||||
if inspect.isawaitable(save):
|
||||
await save
|
||||
await utils.maybe_async(
|
||||
self.client.session.set_dc(
|
||||
option.id, option.ip_address, option.port
|
||||
)
|
||||
)
|
||||
await utils.maybe_async(self.client.session.save())
|
||||
break
|
||||
|
||||
# TODO Figure out why the session may have the wrong DC ID
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import abc
|
||||
import inspect
|
||||
import re
|
||||
import asyncio
|
||||
import collections
|
||||
|
@ -10,7 +9,7 @@ import typing
|
|||
import datetime
|
||||
import pathlib
|
||||
|
||||
from .. import version, helpers, __name__ as __base_name__
|
||||
from .. import utils, version, helpers, __name__ as __base_name__
|
||||
from ..crypto import rsa
|
||||
from ..extensions import markdown
|
||||
from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
|
||||
|
@ -540,16 +539,14 @@ class TelegramBaseClient(abc.ABC):
|
|||
# ':' in session.server_address is True if it's an IPv6 address
|
||||
if (not self.session.server_address or
|
||||
(':' in self.session.server_address) != self._use_ipv6):
|
||||
set_dc = self.session.set_dc(
|
||||
await utils.maybe_async(
|
||||
self.session.set_dc(
|
||||
DEFAULT_DC_ID,
|
||||
DEFAULT_IPV6_IP if self._use_ipv6 else DEFAULT_IPV4_IP,
|
||||
DEFAULT_PORT
|
||||
)
|
||||
if inspect.isawaitable(set_dc):
|
||||
await set_dc
|
||||
save = self.session.save()
|
||||
if inspect.isawaitable(save):
|
||||
await save
|
||||
)
|
||||
await utils.maybe_async(self.session.save())
|
||||
|
||||
if not await self._sender.connect(self._connection(
|
||||
self.session.server_address,
|
||||
|
@ -563,19 +560,13 @@ class TelegramBaseClient(abc.ABC):
|
|||
return
|
||||
|
||||
self.session.auth_key = self._sender.auth_key
|
||||
save = self.session.save()
|
||||
if inspect.isawaitable(save):
|
||||
await save
|
||||
await utils.maybe_async(self.session.save())
|
||||
|
||||
try:
|
||||
# See comment when saving entities to understand this hack
|
||||
self_entity = self.session.get_input_entity(0)
|
||||
if inspect.isawaitable(self_entity):
|
||||
self_entity = await self_entity
|
||||
self_entity = await utils.maybe_async(self.session.get_input_entity(0))
|
||||
self_id = self_entity.access_hash
|
||||
self_user = self.session.get_input_entity(self_id)
|
||||
if inspect.isawaitable(self_user):
|
||||
self_user = await self_user
|
||||
self_user = await utils.maybe_async(self.session.get_input_entity(self_id))
|
||||
self._mb_entity_cache.set_self_user(self_id, None, self_user.access_hash)
|
||||
except ValueError:
|
||||
pass
|
||||
|
@ -584,9 +575,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
|
||||
cs = []
|
||||
|
||||
update_states = self.session.get_update_states()
|
||||
if inspect.isawaitable(update_states):
|
||||
update_states = await update_states
|
||||
update_states = await utils.maybe_async(self.session.get_update_states())
|
||||
for entity_id, state in update_states:
|
||||
if entity_id == 0:
|
||||
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
|
||||
|
@ -597,9 +586,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
self._message_box.load(ss, cs)
|
||||
for state in cs:
|
||||
try:
|
||||
entity = self.session.get_input_entity(state.channel_id)
|
||||
if inspect.isawaitable(entity):
|
||||
entity = await entity
|
||||
entity = await utils.maybe_async(self.session.get_input_entity(state.channel_id))
|
||||
except ValueError:
|
||||
self._log[__name__].warning(
|
||||
'No access_hash in cache for channel %s, will not catch up', state.channel_id)
|
||||
|
@ -709,23 +696,21 @@ class TelegramBaseClient(abc.ABC):
|
|||
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
|
||||
# It doesn't matter if we put users in the list of chats.
|
||||
if self._mb_entity_cache.self_id:
|
||||
process_entities = self.session.process_entities(
|
||||
await utils.maybe_async(
|
||||
self.session.process_entities(
|
||||
types.contacts.ResolvedPeer(None, [types.InputPeerUser(0, self._mb_entity_cache.self_id)], [])
|
||||
)
|
||||
if inspect.isawaitable(process_entities):
|
||||
await process_entities
|
||||
)
|
||||
|
||||
ss, cs = self._message_box.session_state()
|
||||
update_state = self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))
|
||||
if inspect.isawaitable(update_state):
|
||||
await update_state
|
||||
await utils.maybe_async(self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)))
|
||||
now = datetime.datetime.now() # any datetime works; channels don't need it
|
||||
for channel_id, pts in cs.items():
|
||||
update_state = self.session.set_update_state(
|
||||
await utils.maybe_async(
|
||||
self.session.set_update_state(
|
||||
channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)
|
||||
)
|
||||
if inspect.isawaitable(update_state):
|
||||
await update_state
|
||||
)
|
||||
|
||||
async def _disconnect_coro(self: 'TelegramClient'):
|
||||
if self.session is None:
|
||||
|
@ -759,9 +744,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
|
||||
await self._save_states_and_entities()
|
||||
|
||||
close = self.session.close()
|
||||
if inspect.isawaitable(close):
|
||||
await close
|
||||
await utils.maybe_async(self.session.close())
|
||||
|
||||
async def _disconnect(self: 'TelegramClient'):
|
||||
"""
|
||||
|
@ -782,16 +765,12 @@ class TelegramBaseClient(abc.ABC):
|
|||
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
|
||||
dc = await self._get_dc(new_dc)
|
||||
|
||||
set_dc = self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
if inspect.isawaitable(set_dc):
|
||||
await set_dc
|
||||
await utils.maybe_async(self.session.set_dc(dc.id, dc.ip_address, dc.port))
|
||||
# auth_key's are associated with a server, which has now changed
|
||||
# so it's not valid anymore. Set to None to force recreating it.
|
||||
self._sender.auth_key.key = None
|
||||
self.session.auth_key = None
|
||||
save = self.session.save()
|
||||
if inspect.isawaitable(save):
|
||||
await save
|
||||
await utils.maybe_async(self.session.save())
|
||||
await self._disconnect()
|
||||
return await self.connect()
|
||||
|
||||
|
@ -801,9 +780,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
new authorization key. This means we are not authorized.
|
||||
"""
|
||||
self.session.auth_key = auth_key
|
||||
save = self.session.save()
|
||||
if inspect.isawaitable(save):
|
||||
await save
|
||||
await utils.maybe_async(self.session.save())
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -929,12 +906,8 @@ class TelegramBaseClient(abc.ABC):
|
|||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||
if not session:
|
||||
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||
session = self.session.clone()
|
||||
if inspect.isawaitable(session):
|
||||
session = await session
|
||||
set_dc = session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
if inspect.isawaitable(set_dc):
|
||||
await set_dc
|
||||
session = await utils.maybe_async(self.session.clone())
|
||||
await utils.maybe_async(session.set_dc(dc.id, dc.ip_address, dc.port))
|
||||
self._exported_sessions[cdn_redirect.dc_id] = session
|
||||
|
||||
self._log[__name__].info('Creating new CDN client')
|
||||
|
|
|
@ -343,9 +343,7 @@ class UpdateMethods:
|
|||
if updates:
|
||||
self._log[__name__].info('Got difference for account updates')
|
||||
|
||||
_preprocess_updates = self._preprocess_updates(updates, users, chats)
|
||||
if inspect.isawaitable(_preprocess_updates):
|
||||
_preprocess_updates = await _preprocess_updates
|
||||
_preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats))
|
||||
updates_to_dispatch.extend(_preprocess_updates)
|
||||
continue
|
||||
|
||||
|
@ -444,9 +442,7 @@ class UpdateMethods:
|
|||
if updates:
|
||||
self._log[__name__].info('Got difference for channel %d updates', get_diff.channel.channel_id)
|
||||
|
||||
_preprocess_updates = self._preprocess_updates(updates, users, chats)
|
||||
if inspect.isawaitable(_preprocess_updates):
|
||||
_preprocess_updates = await _preprocess_updates
|
||||
_preprocess_updates = await utils.maybe_async(self._preprocess_updates(updates, users, chats))
|
||||
updates_to_dispatch.extend(_preprocess_updates)
|
||||
continue
|
||||
|
||||
|
@ -468,9 +464,7 @@ class UpdateMethods:
|
|||
except GapError:
|
||||
continue # get(_channel)_difference will start returning requests
|
||||
|
||||
_preprocess_updates = self._preprocess_updates(processed, users, chats)
|
||||
if inspect.isawaitable(_preprocess_updates):
|
||||
_preprocess_updates = await _preprocess_updates
|
||||
_preprocess_updates = await utils.maybe_async(self._preprocess_updates(processed, users, chats))
|
||||
updates_to_dispatch.extend(_preprocess_updates)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
@ -481,9 +475,7 @@ class UpdateMethods:
|
|||
|
||||
async def _preprocess_updates(self, updates, users, chats):
|
||||
self._mb_entity_cache.extend(users, chats)
|
||||
process_entities = self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats))
|
||||
if inspect.isawaitable(process_entities):
|
||||
await process_entities
|
||||
await utils.maybe_async(self.session.process_entities(types.contacts.ResolvedPeer(None, users, chats)))
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(users, chats)}
|
||||
for u in updates:
|
||||
|
@ -528,9 +520,7 @@ class UpdateMethods:
|
|||
# it every minute instead. No-op if there's nothing new.
|
||||
await self._save_states_and_entities()
|
||||
|
||||
save = self.session.save()
|
||||
if inspect.isawaitable(save):
|
||||
await save
|
||||
await utils.maybe_async(self.session.save())
|
||||
|
||||
async def _dispatch_update(self: 'TelegramClient', update):
|
||||
# TODO only used for AlbumHack, and MessageBox is not really designed for this
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import inspect
|
||||
import itertools
|
||||
import time
|
||||
import typing
|
||||
|
@ -81,9 +80,7 @@ class UserMethods:
|
|||
exceptions.append(e)
|
||||
results.append(None)
|
||||
continue
|
||||
process_entities = self.session.process_entities(result)
|
||||
if inspect.isawaitable(process_entities):
|
||||
await process_entities
|
||||
await utils.maybe_async(self.session.process_entities(result))
|
||||
exceptions.append(None)
|
||||
results.append(result)
|
||||
request_index += 1
|
||||
|
@ -93,9 +90,7 @@ class UserMethods:
|
|||
return results
|
||||
else:
|
||||
result = await future
|
||||
process_entities = self.session.process_entities(result)
|
||||
if inspect.isawaitable(process_entities):
|
||||
await process_entities
|
||||
await utils.maybe_async(self.session.process_entities(result))
|
||||
return result
|
||||
except (errors.ServerError, errors.RpcCallFailError,
|
||||
errors.RpcMcgetFailError, errors.InterdcCallErrorError,
|
||||
|
@ -440,9 +435,7 @@ class UserMethods:
|
|||
|
||||
# No InputPeer, cached peer, or known string. Fetch from disk cache
|
||||
try:
|
||||
input_entity = self.session.get_input_entity(peer)
|
||||
if inspect.isawaitable(input_entity):
|
||||
input_entity = await input_entity
|
||||
input_entity = await utils.maybe_async(self.session.get_input_entity(peer))
|
||||
return input_entity
|
||||
except ValueError:
|
||||
pass
|
||||
|
@ -582,9 +575,7 @@ class UserMethods:
|
|||
pass
|
||||
try:
|
||||
# Nobody with this username, maybe it's an exact name/title
|
||||
input_entity = self.session.get_input_entity(string)
|
||||
if inspect.isawaitable(input_entity):
|
||||
input_entity = await input_entity
|
||||
input_entity = await utils.maybe_async(self.session.get_input_entity(string))
|
||||
return await self.get_entity(input_entity)
|
||||
except ValueError:
|
||||
pass
|
||||
|
|
|
@ -1557,3 +1557,10 @@ def _photo_size_byte_count(size):
|
|||
return max(size.sizes)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def maybe_async(coro):
|
||||
result = coro
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue
Block a user