Revert "Make sessions async"

This reverts commit d2de0f3aca.
This commit is contained in:
Lonami Exo 2022-08-30 12:32:21 +02:00
parent 88b2b9372d
commit 7d21b40401
9 changed files with 78 additions and 61 deletions

View File

@ -614,7 +614,7 @@ class AuthMethods:
self._authorized = False self._authorized = False
await self.disconnect() await self.disconnect()
await self.session.delete() self.session.delete()
return True return True
async def edit_2fa( async def edit_2fa(

View File

@ -55,7 +55,7 @@ class _DirectDownloadIter(RequestIter):
if option.ip_address == self.client.session.server_address: if option.ip_address == self.client.session.server_address:
self.client.session.set_dc( self.client.session.set_dc(
option.id, option.ip_address, option.port) option.id, option.ip_address, option.port)
await self.client.session.save() self.client.session.save()
break break
# TODO Figure out why the session may have the wrong DC ID # TODO Figure out why the session may have the wrong DC ID
@ -402,7 +402,7 @@ class DownloadMethods:
if isinstance(message.action, if isinstance(message.action,
types.MessageActionChatEditPhoto): types.MessageActionChatEditPhoto):
media = media.photo media = media.photo
if isinstance(media, types.MessageMediaWebPage): if isinstance(media, types.MessageMediaWebPage):
if isinstance(media.webpage, types.WebPage): if isinstance(media.webpage, types.WebPage):
media = media.webpage.document or media.webpage.photo media = media.webpage.document or media.webpage.photo

View File

@ -1019,7 +1019,7 @@ class MessageMethods:
async def edit_message( async def edit_message(
self: 'TelegramClient', self: 'TelegramClient',
entity: 'typing.Union[hints.EntityLike, types.Message]', entity: 'typing.Union[hints.EntityLike, types.Message]',
message: 'hints.MessageIDLike' = None, message: 'hints.MessageLike' = None,
text: str = None, text: str = None,
*, *,
parse_mode: str = (), parse_mode: str = (),

View File

@ -398,6 +398,11 @@ class TelegramBaseClient(abc.ABC):
self._authorized = None # None = unknown, False = no, True = yes self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection)
# TODO Get state from channels too
self._state_cache = StateCache(
self.session.get_update_state(0), self._log)
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
@ -535,13 +540,13 @@ class TelegramBaseClient(abc.ABC):
return return
self.session.auth_key = self._sender.auth_key self.session.auth_key = self._sender.auth_key
await self.session.save() self.session.save()
if self._catch_up: if self._catch_up:
ss = SessionState(0, 0, False, 0, 0, 0, 0, None) ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
cs = [] cs = []
for entity_id, state in await self.session.get_update_states(): for entity_id, state in self.session.get_update_states():
if entity_id == 0: if entity_id == 0:
# TODO current session doesn't store self-user info but adding that is breaking on downstream session impls # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None) ss = SessionState(0, 0, False, state.pts, state.qts, int(state.date.timestamp()), state.seq, None)
@ -550,7 +555,7 @@ class TelegramBaseClient(abc.ABC):
self._message_box.load(ss, cs) self._message_box.load(ss, cs)
for state in cs: for state in cs:
entity = await self.session.get_input_entity(state.channel_id) entity = self.session.get_input_entity(state.channel_id)
if entity: if entity:
self._mb_entity_cache.put(Entity(EntityType.CHANNEL, entity.channel_id, entity.access_hash)) self._mb_entity_cache.put(Entity(EntityType.CHANNEL, entity.channel_id, entity.access_hash))
@ -670,15 +675,15 @@ class TelegramBaseClient(abc.ABC):
# Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities. # Piggy-back on an arbitrary TL type with users and chats so the session can understand to read the entities.
# It doesn't matter if we put users in the list of chats. # It doesn't matter if we put users in the list of chats.
await self.session.process_entities(types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], [])) self.session.process_entities(types.contacts.ResolvedPeer(None, [e._as_input_peer() for e in entities], []))
ss, cs = self._message_box.session_state() ss, cs = self._message_box.session_state()
await self.session.set_update_state(0, types.updates.State(**ss, unread_count=0)) self.session.set_update_state(0, types.updates.State(**ss, unread_count=0))
now = datetime.datetime.now() # any datetime works; channels don't need it now = datetime.datetime.now() # any datetime works; channels don't need it
for channel_id, pts in cs.items(): for channel_id, pts in cs.items():
await self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0)) self.session.set_update_state(channel_id, types.updates.State(pts, 0, now, 0, unread_count=0))
await self.session.close() self.session.close()
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
""" """
@ -704,17 +709,17 @@ class TelegramBaseClient(abc.ABC):
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None self._sender.auth_key.key = None
self.session.auth_key = None self.session.auth_key = None
await self.session.save() self.session.save()
await self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
async def _auth_key_callback(self: 'TelegramClient', auth_key): def _auth_key_callback(self: 'TelegramClient', auth_key):
""" """
Callback from the sender whenever it needed to generate a Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized. new authorization key. This means we are not authorized.
""" """
self.session.auth_key = auth_key self.session.auth_key = auth_key
await self.session.save() self.session.save()
# endregion # endregion

View File

@ -397,11 +397,7 @@ class UpdateMethods:
# inserted because this is a rather expensive operation # inserted because this is a rather expensive operation
# (default's sqlite3 takes ~0.1s to commit changes). Do # (default's sqlite3 takes ~0.1s to commit changes). Do
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
try: self.session.save()
await self.session.save()
except OSError as e:
# No big deal if this cannot be immediately saved
self._log[__name__].warning('Could not perform the periodic save of session data: %s: %s', type(e), e)
async def _dispatch_update(self: 'TelegramClient', update): async def _dispatch_update(self: 'TelegramClient', update):
# TODO only used for AlbumHack, and MessageBox is not really designed for this # TODO only used for AlbumHack, and MessageBox is not really designed for this

View File

@ -71,11 +71,7 @@ class UserMethods:
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
try: self.session.process_entities(result)
await self.session.process_entities(result)
except OSError as e:
self._log[__name__].warning(
'Failed to save possibly new entities to the session: %s: %s', type(e), e)
self._entity_cache.add(result) self._entity_cache.add(result)
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
@ -86,14 +82,7 @@ class UserMethods:
return results return results
else: else:
result = await future result = await future
# This is called pretty often, and it's okay if it fails every now and then. self.session.process_entities(result)
# It only means certain entities won't be saved.
try:
await self.session.process_entities(result)
except OSError as e:
self._log[__name__].warning(
'Failed to save possibly new entities to the session: %s: %s', type(e), e)
self._entity_cache.add(result) self._entity_cache.add(result)
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
@ -438,7 +427,7 @@ class UserMethods:
# No InputPeer, cached peer, or known string. Fetch from disk cache # No InputPeer, cached peer, or known string. Fetch from disk cache
try: try:
return await self.session.get_input_entity(peer) return self.session.get_input_entity(peer)
except ValueError: except ValueError:
pass pass
@ -578,7 +567,7 @@ class UserMethods:
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
return await self.get_entity( return await self.get_entity(
await self.session.get_input_entity(string)) self.session.get_input_entity(string))
except ValueError: except ValueError:
pass pass

View File

@ -296,7 +296,7 @@ class MTProtoSender:
# notify whenever we change it. This is crucial when we # notify whenever we change it. This is crucial when we
# switch to different data centers. # switch to different data centers.
if self._auth_key_callback: if self._auth_key_callback:
await self._auth_key_callback(self.auth_key) self._auth_key_callback(self.auth_key)
self._log.debug('auth_key generation success!') self._log.debug('auth_key generation success!')
return True return True

View File

@ -79,7 +79,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def get_update_state(self, entity_id): def get_update_state(self, entity_id):
""" """
Returns the ``UpdateState`` associated with the given `entity_id`. Returns the ``UpdateState`` associated with the given `entity_id`.
If the `entity_id` is 0, it should return the ``UpdateState`` for If the `entity_id` is 0, it should return the ``UpdateState`` for
@ -89,7 +89,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def set_update_state(self, entity_id, state): def set_update_state(self, entity_id, state):
""" """
Sets the given ``UpdateState`` for the specified `entity_id`, which Sets the given ``UpdateState`` for the specified `entity_id`, which
should be 0 if the ``UpdateState`` is the "general" state (and not should be 0 if the ``UpdateState`` is the "general" state (and not
@ -103,15 +103,14 @@ class Session(ABC):
Returns an iterable over all known pairs of ``(entity ID, update state)``. Returns an iterable over all known pairs of ``(entity ID, update state)``.
""" """
@abstractmethod def close(self):
async def close(self):
""" """
Called on client disconnection. Should be used to Called on client disconnection. Should be used to
free any used resources. Can be left empty if none. free any used resources. Can be left empty if none.
""" """
@abstractmethod @abstractmethod
async def save(self): def save(self):
""" """
Called whenever important properties change. It should Called whenever important properties change. It should
make persist the relevant session information to disk. make persist the relevant session information to disk.
@ -119,15 +118,22 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def delete(self): def delete(self):
""" """
Called upon client.log_out(). Should delete the stored Called upon client.log_out(). Should delete the stored
information from disk since it's not valid anymore. information from disk since it's not valid anymore.
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def list_sessions(cls):
"""
Lists available sessions. Not used by the library itself.
"""
return []
@abstractmethod @abstractmethod
async def process_entities(self, tlo): def process_entities(self, tlo):
""" """
Processes the input ``TLObject`` or ``list`` and saves Processes the input ``TLObject`` or ``list`` and saves
whatever information is relevant (e.g., ID or access hash). whatever information is relevant (e.g., ID or access hash).
@ -135,7 +141,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def get_input_entity(self, key): def get_input_entity(self, key):
""" """
Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``). Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``).
The library uses this method whenever an ``InputPeer`` is needed The library uses this method whenever an ``InputPeer`` is needed
@ -143,3 +149,24 @@ class Session(ABC):
to use a cached username to avoid extra RPC). to use a cached username to avoid extra RPC).
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def cache_file(self, md5_digest, file_size, instance):
"""
Caches the given file information persistently, so that it
doesn't need to be re-uploaded in case the file is used again.
The ``instance`` will be either an ``InputPhoto`` or ``InputDocument``,
both with an ``.id`` and ``.access_hash`` attributes.
"""
raise NotImplementedError
@abstractmethod
def get_file(self, md5_digest, file_size, cls):
"""
Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size``
match an existing saved record. The class will either be an
``InputPhoto`` or ``InputDocument``, both with two parameters
``id`` and ``access_hash`` in that order.
"""
raise NotImplementedError

View File

@ -71,22 +71,22 @@ class MemorySession(Session):
def takeout_id(self, value): def takeout_id(self, value):
self._takeout_id = value self._takeout_id = value
async def get_update_state(self, entity_id): def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None) return self._update_states.get(entity_id, None)
async def set_update_state(self, entity_id, state): def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state self._update_states[entity_id] = state
async def get_update_states(self): async def get_update_states(self):
return self._update_states.items() return self._update_states.items()
async def close(self): def close(self):
pass pass
async def save(self): def save(self):
pass pass
async def delete(self): def delete(self):
pass pass
@staticmethod @staticmethod
@ -147,31 +147,31 @@ class MemorySession(Session):
rows.append(row) rows.append(row)
return rows return rows
async def process_entities(self, tlo): def process_entities(self, tlo):
self._entities |= set(self._entities_to_rows(tlo)) self._entities |= set(self._entities_to_rows(tlo))
async def get_entity_rows_by_phone(self, phone): def get_entity_rows_by_phone(self, phone):
try: try:
return next((id, hash) for id, hash, _, found_phone, _ return next((id, hash) for id, hash, _, found_phone, _
in self._entities if found_phone == phone) in self._entities if found_phone == phone)
except StopIteration: except StopIteration:
pass pass
async def get_entity_rows_by_username(self, username): def get_entity_rows_by_username(self, username):
try: try:
return next((id, hash) for id, hash, found_username, _, _ return next((id, hash) for id, hash, found_username, _, _
in self._entities if found_username == username) in self._entities if found_username == username)
except StopIteration: except StopIteration:
pass pass
async def get_entity_rows_by_name(self, name): def get_entity_rows_by_name(self, name):
try: try:
return next((id, hash) for id, hash, _, _, found_name return next((id, hash) for id, hash, _, _, found_name
in self._entities if found_name == name) in self._entities if found_name == name)
except StopIteration: except StopIteration:
pass pass
async def get_entity_rows_by_id(self, id, exact=True): def get_entity_rows_by_id(self, id, exact=True):
try: try:
if exact: if exact:
return next((id, hash) for found_id, hash, _, _, _ return next((id, hash) for found_id, hash, _, _, _
@ -187,7 +187,7 @@ class MemorySession(Session):
except StopIteration: except StopIteration:
pass pass
async def get_input_entity(self, key): def get_input_entity(self, key):
try: try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
@ -207,21 +207,21 @@ class MemorySession(Session):
if isinstance(key, str): if isinstance(key, str):
phone = utils.parse_phone(key) phone = utils.parse_phone(key)
if phone: if phone:
result = await self.get_entity_rows_by_phone(phone) result = self.get_entity_rows_by_phone(phone)
else: else:
username, invite = utils.parse_username(key) username, invite = utils.parse_username(key)
if username and not invite: if username and not invite:
result = await self.get_entity_rows_by_username(username) result = self.get_entity_rows_by_username(username)
else: else:
tup = utils.resolve_invite_link(key)[1] tup = utils.resolve_invite_link(key)[1]
if tup: if tup:
result = await self.get_entity_rows_by_id(tup, exact=False) result = self.get_entity_rows_by_id(tup, exact=False)
elif isinstance(key, int): elif isinstance(key, int):
result = await self.get_entity_rows_by_id(key, exact) result = self.get_entity_rows_by_id(key, exact)
if not result and isinstance(key, str): if not result and isinstance(key, str):
result = await self.get_entity_rows_by_name(key) result = self.get_entity_rows_by_name(key)
if result: if result:
entity_id, entity_hash = result # unpack resulting tuple entity_id, entity_hash = result # unpack resulting tuple
@ -236,14 +236,14 @@ class MemorySession(Session):
else: else:
raise ValueError('Could not find input entity with key ', key) raise ValueError('Could not find input entity with key ', key)
async def cache_file(self, md5_digest, file_size, instance): def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)): if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance)) raise TypeError('Cannot cache %s instance' % type(instance))
key = (md5_digest, file_size, _SentFileType.from_type(type(instance))) key = (md5_digest, file_size, _SentFileType.from_type(type(instance)))
value = (instance.id, instance.access_hash) value = (instance.id, instance.access_hash)
self._files[key] = value self._files[key] = value
async def get_file(self, md5_digest, file_size, cls): def get_file(self, md5_digest, file_size, cls):
key = (md5_digest, file_size, _SentFileType.from_type(cls)) key = (md5_digest, file_size, _SentFileType.from_type(cls))
try: try:
return cls(*self._files[key]) return cls(*self._files[key])