Make sessions async

SQLiteSession is not updated, don't try to use it
This commit is contained in:
Tulir Asokan 2021-12-20 18:52:02 +02:00 committed by Lonami Exo
parent 43f629f665
commit d2de0f3aca
9 changed files with 54 additions and 80 deletions

View File

@ -597,7 +597,7 @@ class AuthMethods:
self._state_cache.reset() self._state_cache.reset()
await self.disconnect() await self.disconnect()
self.session.delete() await self.session.delete()
return True return True
async def edit_2fa( async def edit_2fa(

View File

@ -55,7 +55,7 @@ class _DirectDownloadIter(RequestIter):
if option.ip_address == self.client.session.server_address: if option.ip_address == self.client.session.server_address:
self.client.session.set_dc( self.client.session.set_dc(
option.id, option.ip_address, option.port) option.id, option.ip_address, option.port)
self.client.session.save() await self.client.session.save()
break break
# TODO Figure out why the session may have the wrong DC ID # TODO Figure out why the session may have the wrong DC ID
@ -402,7 +402,7 @@ class DownloadMethods:
if isinstance(message.action, if isinstance(message.action,
types.MessageActionChatEditPhoto): types.MessageActionChatEditPhoto):
media = media.photo media = media.photo
if isinstance(media, types.MessageMediaWebPage): if isinstance(media, types.MessageMediaWebPage):
if isinstance(media.webpage, types.WebPage): if isinstance(media.webpage, types.WebPage):
media = media.webpage.document or media.webpage.photo media = media.webpage.document or media.webpage.photo

View File

@ -1019,7 +1019,7 @@ class MessageMethods:
async def edit_message( async def edit_message(
self: 'TelegramClient', self: 'TelegramClient',
entity: 'typing.Union[hints.EntityLike, types.Message]', entity: 'typing.Union[hints.EntityLike, types.Message]',
message: 'hints.MessageLike' = None, message: 'hints.MessageIDLike' = None,
text: str = None, text: str = None,
*, *,
parse_mode: str = (), parse_mode: str = (),

View File

@ -412,10 +412,7 @@ class TelegramBaseClient(abc.ABC):
self._authorized = None # None = unknown, False = no, True = yes self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection) self._state_cache = StateCache(None, self._log)
# TODO Get state from channels too
self._state_cache = StateCache(
self.session.get_update_state(0), self._log)
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
@ -522,6 +519,11 @@ class TelegramBaseClient(abc.ABC):
except OSError: except OSError:
print('Failed to connect') print('Failed to connect')
""" """
# Update state (for catching up after a disconnection)
# TODO Get state from channels too
self._state_cache = StateCache(
await self.session.get_update_state(0), self._log)
if not await self._sender.connect(self._connection( if not await self._sender.connect(self._connection(
self.session.server_address, self.session.server_address,
self.session.port, self.session.port,
@ -534,7 +536,7 @@ class TelegramBaseClient(abc.ABC):
return return
self.session.auth_key = self._sender.auth_key self.session.auth_key = self._sender.auth_key
self.session.save() await self.session.save()
self._init_request.query = functions.help.GetConfigRequest() self._init_request.query = functions.help.GetConfigRequest()
@ -644,7 +646,7 @@ class TelegramBaseClient(abc.ABC):
pts, date = self._state_cache[None] pts, date = self._state_cache[None]
if pts and date: if pts and date:
self.session.set_update_state(0, types.updates.State( await self.session.set_update_state(0, types.updates.State(
pts=pts, pts=pts,
qts=0, qts=0,
date=date, date=date,
@ -652,7 +654,7 @@ class TelegramBaseClient(abc.ABC):
unread_count=0 unread_count=0
)) ))
self.session.close() await self.session.close()
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
""" """
@ -677,17 +679,17 @@ class TelegramBaseClient(abc.ABC):
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self._sender.auth_key.key = None self._sender.auth_key.key = None
self.session.auth_key = None self.session.auth_key = None
self.session.save() await self.session.save()
await self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
def _auth_key_callback(self: 'TelegramClient', auth_key): async def _auth_key_callback(self: 'TelegramClient', auth_key):
""" """
Callback from the sender whenever it needed to generate a Callback from the sender whenever it needed to generate a
new authorization key. This means we are not authorized. new authorization key. This means we are not authorized.
""" """
self.session.auth_key = auth_key self.session.auth_key = auth_key
self.session.save() await self.session.save()
# endregion # endregion
@ -812,7 +814,7 @@ class TelegramBaseClient(abc.ABC):
if not session: if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone() session = self.session.clone()
await session.set_dc(dc.id, dc.ip_address, dc.port) session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
self._log[__name__].info('Creating new CDN client') self._log[__name__].info('Creating new CDN client')

View File

@ -255,7 +255,7 @@ class UpdateMethods:
state = d.intermediate_state state = d.intermediate_state
pts, date = state.pts, state.date pts, date = state.pts, state.date
self._handle_update(types.Updates( await self._handle_update(types.Updates(
users=d.users, users=d.users,
chats=d.chats, chats=d.chats,
date=state.date, date=state.date,
@ -300,8 +300,8 @@ class UpdateMethods:
# It is important to not make _handle_update async because we rely on # It is important to not make _handle_update async because we rely on
# the order that the updates arrive in to update the pts and date to # the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async. # be always-increasing. There is also no need to make this async.
def _handle_update(self: 'TelegramClient', update): async def _handle_update(self: 'TelegramClient', update):
self.session.process_entities(update) await self.session.process_entities(update)
self._entity_cache.add(update) self._entity_cache.add(update)
if isinstance(update, (types.Updates, types.UpdatesCombined)): if isinstance(update, (types.Updates, types.UpdatesCombined)):
@ -372,7 +372,7 @@ class UpdateMethods:
# inserted because this is a rather expensive operation # inserted because this is a rather expensive operation
# (default's sqlite3 takes ~0.1s to commit changes). Do # (default's sqlite3 takes ~0.1s to commit changes). Do
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
self.session.save() await self.session.save()
# We need to send some content-related request at least hourly # We need to send some content-related request at least hourly
# for Telegram to keep delivering updates, otherwise they will # for Telegram to keep delivering updates, otherwise they will

View File

@ -71,7 +71,7 @@ class UserMethods:
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
self.session.process_entities(result) await self.session.process_entities(result)
self._entity_cache.add(result) self._entity_cache.add(result)
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
@ -82,7 +82,7 @@ class UserMethods:
return results return results
else: else:
result = await future result = await future
self.session.process_entities(result) await self.session.process_entities(result)
self._entity_cache.add(result) self._entity_cache.add(result)
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
@ -427,7 +427,7 @@ class UserMethods:
# No InputPeer, cached peer, or known string. Fetch from disk cache # No InputPeer, cached peer, or known string. Fetch from disk cache
try: try:
return self.session.get_input_entity(peer) return await self.session.get_input_entity(peer)
except ValueError: except ValueError:
pass pass
@ -567,7 +567,7 @@ class UserMethods:
try: try:
# Nobody with this username, maybe it's an exact name/title # Nobody with this username, maybe it's an exact name/title
return await self.get_entity( return await self.get_entity(
self.session.get_input_entity(string)) await self.session.get_input_entity(string))
except ValueError: except ValueError:
pass pass

View File

@ -295,7 +295,7 @@ class MTProtoSender:
# notify whenever we change it. This is crucial when we # notify whenever we change it. This is crucial when we
# switch to different data centers. # switch to different data centers.
if self._auth_key_callback: if self._auth_key_callback:
self._auth_key_callback(self.auth_key) await self._auth_key_callback(self.auth_key)
self._log.debug('auth_key generation success!') self._log.debug('auth_key generation success!')
return True return True
@ -380,7 +380,7 @@ class MTProtoSender:
self._log.info('Broken authorization key; resetting') self._log.info('Broken authorization key; resetting')
self.auth_key.key = None self.auth_key.key = None
if self._auth_key_callback: if self._auth_key_callback:
self._auth_key_callback(None) await self._auth_key_callback(None)
ok = False ok = False
break break
@ -524,7 +524,7 @@ class MTProtoSender:
self._log.info('Broken authorization key; resetting') self._log.info('Broken authorization key; resetting')
self.auth_key.key = None self.auth_key.key = None
if self._auth_key_callback: if self._auth_key_callback:
self._auth_key_callback(None) await self._auth_key_callback(None)
await self._disconnect(error=e) await self._disconnect(error=e)
else: else:
@ -653,7 +653,7 @@ class MTProtoSender:
self._log.debug('Handling update %s', message.obj.__class__.__name__) self._log.debug('Handling update %s', message.obj.__class__.__name__)
if self._update_callback: if self._update_callback:
self._update_callback(message.obj) await self._update_callback(message.obj)
async def _handle_pong(self, message): async def _handle_pong(self, message):
""" """

View File

@ -79,7 +79,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_update_state(self, entity_id): async def get_update_state(self, entity_id):
""" """
Returns the ``UpdateState`` associated with the given `entity_id`. Returns the ``UpdateState`` associated with the given `entity_id`.
If the `entity_id` is 0, it should return the ``UpdateState`` for If the `entity_id` is 0, it should return the ``UpdateState`` for
@ -89,7 +89,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def set_update_state(self, entity_id, state): async def set_update_state(self, entity_id, state):
""" """
Sets the given ``UpdateState`` for the specified `entity_id`, which Sets the given ``UpdateState`` for the specified `entity_id`, which
should be 0 if the ``UpdateState`` is the "general" state (and not should be 0 if the ``UpdateState`` is the "general" state (and not
@ -98,14 +98,14 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def close(self): async def close(self):
""" """
Called on client disconnection. Should be used to Called on client disconnection. Should be used to
free any used resources. Can be left empty if none. free any used resources. Can be left empty if none.
""" """
@abstractmethod @abstractmethod
def save(self): async def save(self):
""" """
Called whenever important properties change. It should Called whenever important properties change. It should
make persist the relevant session information to disk. make persist the relevant session information to disk.
@ -113,22 +113,15 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def delete(self): async def delete(self):
""" """
Called upon client.log_out(). Should delete the stored Called upon client.log_out(). Should delete the stored
information from disk since it's not valid anymore. information from disk since it's not valid anymore.
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def list_sessions(cls):
"""
Lists available sessions. Not used by the library itself.
"""
return []
@abstractmethod @abstractmethod
def process_entities(self, tlo): async def process_entities(self, tlo):
""" """
Processes the input ``TLObject`` or ``list`` and saves Processes the input ``TLObject`` or ``list`` and saves
whatever information is relevant (e.g., ID or access hash). whatever information is relevant (e.g., ID or access hash).
@ -136,7 +129,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_input_entity(self, key): async def get_input_entity(self, key):
""" """
Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``). Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``).
The library uses this method whenever an ``InputPeer`` is needed The library uses this method whenever an ``InputPeer`` is needed
@ -144,24 +137,3 @@ class Session(ABC):
to use a cached username to avoid extra RPC). to use a cached username to avoid extra RPC).
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def cache_file(self, md5_digest, file_size, instance):
"""
Caches the given file information persistently, so that it
doesn't need to be re-uploaded in case the file is used again.
The ``instance`` will be either an ``InputPhoto`` or ``InputDocument``,
both with an ``.id`` and ``.access_hash`` attributes.
"""
raise NotImplementedError
@abstractmethod
def get_file(self, md5_digest, file_size, cls):
"""
Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size``
match an existing saved record. The class will either be an
``InputPhoto`` or ``InputDocument``, both with two parameters
``id`` and ``access_hash`` in that order.
"""
raise NotImplementedError

View File

@ -71,19 +71,19 @@ class MemorySession(Session):
def takeout_id(self, value): def takeout_id(self, value):
self._takeout_id = value self._takeout_id = value
def get_update_state(self, entity_id): async def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None) return self._update_states.get(entity_id, None)
def set_update_state(self, entity_id, state): async def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state self._update_states[entity_id] = state
def close(self): async def close(self):
pass pass
def save(self): async def save(self):
pass pass
def delete(self): async def delete(self):
pass pass
@staticmethod @staticmethod
@ -144,31 +144,31 @@ class MemorySession(Session):
rows.append(row) rows.append(row)
return rows return rows
def process_entities(self, tlo): async def process_entities(self, tlo):
self._entities |= set(self._entities_to_rows(tlo)) self._entities |= set(self._entities_to_rows(tlo))
def get_entity_rows_by_phone(self, phone): async def get_entity_rows_by_phone(self, phone):
try: try:
return next((id, hash) for id, hash, _, found_phone, _ return next((id, hash) for id, hash, _, found_phone, _
in self._entities if found_phone == phone) in self._entities if found_phone == phone)
except StopIteration: except StopIteration:
pass pass
def get_entity_rows_by_username(self, username): async def get_entity_rows_by_username(self, username):
try: try:
return next((id, hash) for id, hash, found_username, _, _ return next((id, hash) for id, hash, found_username, _, _
in self._entities if found_username == username) in self._entities if found_username == username)
except StopIteration: except StopIteration:
pass pass
def get_entity_rows_by_name(self, name): async def get_entity_rows_by_name(self, name):
try: try:
return next((id, hash) for id, hash, _, _, found_name return next((id, hash) for id, hash, _, _, found_name
in self._entities if found_name == name) in self._entities if found_name == name)
except StopIteration: except StopIteration:
pass pass
def get_entity_rows_by_id(self, id, exact=True): async def get_entity_rows_by_id(self, id, exact=True):
try: try:
if exact: if exact:
return next((id, hash) for found_id, hash, _, _, _ return next((id, hash) for found_id, hash, _, _, _
@ -184,7 +184,7 @@ class MemorySession(Session):
except StopIteration: except StopIteration:
pass pass
def get_input_entity(self, key): async def get_input_entity(self, key):
try: try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
@ -204,21 +204,21 @@ class MemorySession(Session):
if isinstance(key, str): if isinstance(key, str):
phone = utils.parse_phone(key) phone = utils.parse_phone(key)
if phone: if phone:
result = self.get_entity_rows_by_phone(phone) result = await self.get_entity_rows_by_phone(phone)
else: else:
username, invite = utils.parse_username(key) username, invite = utils.parse_username(key)
if username and not invite: if username and not invite:
result = self.get_entity_rows_by_username(username) result = await self.get_entity_rows_by_username(username)
else: else:
tup = utils.resolve_invite_link(key)[1] tup = utils.resolve_invite_link(key)[1]
if tup: if tup:
result = self.get_entity_rows_by_id(tup, exact=False) result = await self.get_entity_rows_by_id(tup, exact=False)
elif isinstance(key, int): elif isinstance(key, int):
result = self.get_entity_rows_by_id(key, exact) result = await self.get_entity_rows_by_id(key, exact)
if not result and isinstance(key, str): if not result and isinstance(key, str):
result = self.get_entity_rows_by_name(key) result = await self.get_entity_rows_by_name(key)
if result: if result:
entity_id, entity_hash = result # unpack resulting tuple entity_id, entity_hash = result # unpack resulting tuple
@ -233,14 +233,14 @@ class MemorySession(Session):
else: else:
raise ValueError('Could not find input entity with key ', key) raise ValueError('Could not find input entity with key ', key)
def cache_file(self, md5_digest, file_size, instance): async def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)): if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance)) raise TypeError('Cannot cache %s instance' % type(instance))
key = (md5_digest, file_size, _SentFileType.from_type(type(instance))) key = (md5_digest, file_size, _SentFileType.from_type(type(instance)))
value = (instance.id, instance.access_hash) value = (instance.id, instance.access_hash)
self._files[key] = value self._files[key] = value
def get_file(self, md5_digest, file_size, cls): async def get_file(self, md5_digest, file_size, cls):
key = (md5_digest, file_size, _SentFileType.from_type(cls)) key = (md5_digest, file_size, _SentFileType.from_type(cls))
try: try:
return cls(*self._files[key]) return cls(*self._files[key])