From 545e9d69ce3734a1285b84fabed6a136f7e3215f Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sun, 19 Sep 2021 17:51:05 +0200 Subject: [PATCH] Cache session_state and all_dcs right after connect --- telethon/_client/downloads.py | 5 +- telethon/_client/messages.py | 5 +- telethon/_client/telegrambaseclient.py | 121 ++++++++----------------- telethon/_client/telegramclient.py | 3 - telethon/_misc/statecache.py | 2 +- 5 files changed, 39 insertions(+), 97 deletions(-) diff --git a/telethon/_client/downloads.py b/telethon/_client/downloads.py index 99932cdc..50a383dd 100644 --- a/telethon/_client/downloads.py +++ b/telethon/_client/downloads.py @@ -39,10 +39,7 @@ class _DirectDownloadIter(requestiter.RequestIter): self._msg_data = msg_data self._timed_out = False - # TODO should cache current session state - state = await self.client.session.get_state() - - self._exported = dc_id and state.dc_id != dc_id + self._exported = dc_id and self.client._session_state.dc_id != dc_id if not self._exported: # The used sender will also change if ``FileMigrateError`` occurs self._sender = self.client._sender diff --git a/telethon/_client/messages.py b/telethon/_client/messages.py index 8305f60c..5b7ca110 100644 --- a/telethon/_client/messages.py +++ b/telethon/_client/messages.py @@ -597,10 +597,7 @@ async def edit_message( ) # Invoke `messages.editInlineBotMessage` from the right datacenter. # Otherwise, Telegram will error with `MESSAGE_ID_INVALID` and do nothing. - # TODO should cache current session state - state = await self.session.get_state() - - exported = state.dc_id != entity.dc_id + exported = self._session_state.dc_id != entity.dc_id if exported: try: sender = await self._borrow_exported_sender(entity.dc_id) diff --git a/telethon/_client/telegrambaseclient.py b/telethon/_client/telegrambaseclient.py index 5afd0328..960b074f 100644 --- a/telethon/_client/telegrambaseclient.py +++ b/telethon/_client/telegrambaseclient.py @@ -142,6 +142,11 @@ def init( # TODO Session should probably return all cached # info of entities, not just the input versions self.session = session + + # Cache session data for convenient access + self._session_state = None + self._all_dcs = None + self._entity_cache = entitycache.EntityCache() self.api_id = int(api_id) self.api_hash = api_hash @@ -296,10 +301,19 @@ def set_flood_sleep_threshold(self, value): async def connect(self: 'TelegramClient') -> None: - all_dc = {dc.id: dc for dc in await self.session.get_all_dc()} - state = await self.session.get_state() + self._all_dcs = {dc.id: dc for dc in await self.session.get_all_dc()} + self._session_state = await self.session.get_state() or SessionState( + user_id=0, + dc_id=DEFAULT_DC_ID, + bot=False, + pts=0, + qts=0, + date=0, + seq=0, + takeout_id=None, + ) - dc = all_dc.get(state.dc_id) if state else None + dc = self._all_dcs.get(self._session_state.dc_id) if dc is None: dc = DataCenter( id=DEFAULT_DC_ID, @@ -308,11 +322,11 @@ async def connect(self: 'TelegramClient') -> None: port=DEFAULT_PORT, auth=b'', ) - all_dc[dc.id] = dc + self._all_dcs[dc.id] = dc # Update state (for catching up after a disconnection) # TODO Get state from channels too - self._state_cache = statecache.StateCache(state, self._log) + self._state_cache = statecache.StateCache(self._session_state, self._log) # Use known key, if any self._sender.auth_key.key = dc.auth @@ -345,18 +359,18 @@ async def connect(self: 'TelegramClient') -> None: continue ip = int(ipaddress.ip_address(dc.ip_address)) - if dc.id in all_dc: - all_dc[dc.id].port = dc.port + if dc.id in self._all_dcs: + self._all_dcs[dc.id].port = dc.port if dc.ipv6: - all_dc[dc.id].ipv6 = ip + self._all_dcs[dc.id].ipv6 = ip else: - all_dc[dc.id].ipv4 = ip + self._all_dcs[dc.id].ipv4 = ip elif dc.ipv6: - all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'') + self._all_dcs[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'') else: - all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'') + self._all_dcs[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'') - for dc in all_dc.values(): + for dc in self._all_dcs.values(): await self.session.insert_dc(dc) await self.session.save() @@ -419,11 +433,10 @@ async def _disconnect_coro(self: 'TelegramClient'): pts, date = self._state_cache[None] if pts and date: - state = await self.session.get_state() - if state: - state.pts = pts - state.date = date - await self.session.set_state(state) + if self._session_state: + self._session_state.pts = pts + self._session_state.date = date + await self.session.set_state(self._session_state) await self.session.save() async def _disconnect(self: 'TelegramClient'): @@ -442,76 +455,14 @@ async def _switch_dc(self: 'TelegramClient', new_dc): Permanently switches the current connection to the new data center. """ self._log[__name__].info('Reconnecting to new data center %s', new_dc) - dc = await _refresh_and_get_dc(self, new_dc) - state = await self.session.get_state() - if state is None: - state = SessionState( - user_id=0, - dc_id=dc.id, - bot=False, - pts=0, - qts=0, - date=0, - seq=0, - takeout_id=None, - ) - else: - state.dc_id = dc.id - - await self.session.set_state(dc.id) + self._session_state.dc_id = new_dc + await self.session.set_state(self._session_state) await self.session.save() await _disconnect(self) return await self.connect() - -async def _refresh_and_get_dc(self: 'TelegramClient', dc_id): - """ - Gets the Data Center (DC) associated to `dc_id`. - - Also take this opportunity to refresh the addresses stored in the session if needed. - """ - cls = self.__class__ - if not cls._config: - cls._config = await self(_tl.fn.help.GetConfig()) - all_dc = {dc.id: dc for dc in await self.session.get_all_dc()} - for dc in cls._config.dc_options: - if dc.media_only or dc.tcpo_only or dc.cdn: - continue - - ip = int(ipaddress.ip_address(dc.ip_address)) - if dc.id in all_dc: - all_dc[dc.id].port = dc.port - if dc.ipv6: - all_dc[dc.id].ipv6 = ip - else: - all_dc[dc.id].ipv4 = ip - elif dc.ipv6: - all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'') - else: - all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'') - - for dc in all_dc.values(): - await self.session.insert_dc(dc) - await self.session.save() - - try: - return next( - dc for dc in cls._config.dc_options - if dc.id == dc_id - and bool(dc.ipv6) == self._use_ipv6 and not dc.cdn - ) - except StopIteration: - self._log[__name__].warning( - 'Failed to get DC %swith use_ipv6 = %s; retrying ignoring IPv6 check', - dc_id, self._use_ipv6 - ) - return next( - dc for dc in cls._config.dc_options - if dc.id == dc_id and not dc.cdn - ) - async def _create_exported_sender(self: 'TelegramClient', dc_id): """ Creates a new exported `MTProtoSender` for the given `dc_id` and @@ -519,14 +470,14 @@ async def _create_exported_sender(self: 'TelegramClient', dc_id): """ # Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt # for clearly showing how to export the authorization - dc = await _refresh_and_get_dc(self, dc_id) + dc = self._all_dcs[dc_id] # Can't reuse self._sender._connection as it has its own seqno. # # If one were to do that, Telegram would reset the connection # with no further clues. sender = MTProtoSender(loggers=self._log) await sender.connect(self._connection( - dc.ip_address, + str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)), dc.port, dc.id, loggers=self._log, @@ -559,9 +510,9 @@ async def _borrow_exported_sender(self: 'TelegramClient', dc_id): self._borrowed_senders[dc_id] = (state, sender) elif state.need_connect(): - dc = await _refresh_and_get_dc(self, dc_id) + dc = self._all_dcs[dc_id] await sender.connect(self._connection( - dc.ip_address, + str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)), dc.port, dc.id, loggers=self._log, diff --git a/telethon/_client/telegramclient.py b/telethon/_client/telegramclient.py index 62511cd8..10db3557 100644 --- a/telethon/_client/telegramclient.py +++ b/telethon/_client/telegramclient.py @@ -2704,9 +2704,6 @@ class TelegramClient: # Current TelegramClient version __version__ = version.__version__ - # Cached server configuration (with .dc_options), can be "global" - _config = None - def __init__( self: 'TelegramClient', session: 'typing.Union[str, Session]', diff --git a/telethon/_misc/statecache.py b/telethon/_misc/statecache.py index 7f3ddf59..c1a6d7c9 100644 --- a/telethon/_misc/statecache.py +++ b/telethon/_misc/statecache.py @@ -36,7 +36,7 @@ class StateCache: # each update in case they need to fetch missing entities. self._logger = loggers[__name__] if initial: - self._pts_date = initial.pts, initial.date + self._pts_date = initial.pts or None, initial.date or None else: self._pts_date = None, None