Cache session_state and all_dcs right after connect

This commit is contained in:
Lonami Exo 2021-09-19 17:51:05 +02:00
parent 93dd2a186a
commit 545e9d69ce
5 changed files with 39 additions and 97 deletions

View File

@ -39,10 +39,7 @@ class _DirectDownloadIter(requestiter.RequestIter):
self._msg_data = msg_data self._msg_data = msg_data
self._timed_out = False self._timed_out = False
# TODO should cache current session state self._exported = dc_id and self.client._session_state.dc_id != dc_id
state = await self.client.session.get_state()
self._exported = dc_id and state.dc_id != dc_id
if not self._exported: if not self._exported:
# The used sender will also change if ``FileMigrateError`` occurs # The used sender will also change if ``FileMigrateError`` occurs
self._sender = self.client._sender self._sender = self.client._sender

View File

@ -597,10 +597,7 @@ async def edit_message(
) )
# Invoke `messages.editInlineBotMessage` from the right datacenter. # Invoke `messages.editInlineBotMessage` from the right datacenter.
# Otherwise, Telegram will error with `MESSAGE_ID_INVALID` and do nothing. # Otherwise, Telegram will error with `MESSAGE_ID_INVALID` and do nothing.
# TODO should cache current session state exported = self._session_state.dc_id != entity.dc_id
state = await self.session.get_state()
exported = state.dc_id != entity.dc_id
if exported: if exported:
try: try:
sender = await self._borrow_exported_sender(entity.dc_id) sender = await self._borrow_exported_sender(entity.dc_id)

View File

@ -142,6 +142,11 @@ def init(
# TODO Session should probably return all cached # TODO Session should probably return all cached
# info of entities, not just the input versions # info of entities, not just the input versions
self.session = session self.session = session
# Cache session data for convenient access
self._session_state = None
self._all_dcs = None
self._entity_cache = entitycache.EntityCache() self._entity_cache = entitycache.EntityCache()
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
@ -296,10 +301,19 @@ def set_flood_sleep_threshold(self, value):
async def connect(self: 'TelegramClient') -> None: async def connect(self: 'TelegramClient') -> None:
all_dc = {dc.id: dc for dc in await self.session.get_all_dc()} self._all_dcs = {dc.id: dc for dc in await self.session.get_all_dc()}
state = await self.session.get_state() self._session_state = await self.session.get_state() or SessionState(
user_id=0,
dc_id=DEFAULT_DC_ID,
bot=False,
pts=0,
qts=0,
date=0,
seq=0,
takeout_id=None,
)
dc = all_dc.get(state.dc_id) if state else None dc = self._all_dcs.get(self._session_state.dc_id)
if dc is None: if dc is None:
dc = DataCenter( dc = DataCenter(
id=DEFAULT_DC_ID, id=DEFAULT_DC_ID,
@ -308,11 +322,11 @@ async def connect(self: 'TelegramClient') -> None:
port=DEFAULT_PORT, port=DEFAULT_PORT,
auth=b'', auth=b'',
) )
all_dc[dc.id] = dc self._all_dcs[dc.id] = dc
# Update state (for catching up after a disconnection) # Update state (for catching up after a disconnection)
# TODO Get state from channels too # TODO Get state from channels too
self._state_cache = statecache.StateCache(state, self._log) self._state_cache = statecache.StateCache(self._session_state, self._log)
# Use known key, if any # Use known key, if any
self._sender.auth_key.key = dc.auth self._sender.auth_key.key = dc.auth
@ -345,18 +359,18 @@ async def connect(self: 'TelegramClient') -> None:
continue continue
ip = int(ipaddress.ip_address(dc.ip_address)) ip = int(ipaddress.ip_address(dc.ip_address))
if dc.id in all_dc: if dc.id in self._all_dcs:
all_dc[dc.id].port = dc.port self._all_dcs[dc.id].port = dc.port
if dc.ipv6: if dc.ipv6:
all_dc[dc.id].ipv6 = ip self._all_dcs[dc.id].ipv6 = ip
else: else:
all_dc[dc.id].ipv4 = ip self._all_dcs[dc.id].ipv4 = ip
elif dc.ipv6: elif dc.ipv6:
all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'') self._all_dcs[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
else: else:
all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'') self._all_dcs[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
for dc in all_dc.values(): for dc in self._all_dcs.values():
await self.session.insert_dc(dc) await self.session.insert_dc(dc)
await self.session.save() await self.session.save()
@ -419,11 +433,10 @@ async def _disconnect_coro(self: 'TelegramClient'):
pts, date = self._state_cache[None] pts, date = self._state_cache[None]
if pts and date: if pts and date:
state = await self.session.get_state() if self._session_state:
if state: self._session_state.pts = pts
state.pts = pts self._session_state.date = date
state.date = date await self.session.set_state(self._session_state)
await self.session.set_state(state)
await self.session.save() await self.session.save()
async def _disconnect(self: 'TelegramClient'): async def _disconnect(self: 'TelegramClient'):
@ -442,76 +455,14 @@ async def _switch_dc(self: 'TelegramClient', new_dc):
Permanently switches the current connection to the new data center. Permanently switches the current connection to the new data center.
""" """
self._log[__name__].info('Reconnecting to new data center %s', new_dc) self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await _refresh_and_get_dc(self, new_dc)
state = await self.session.get_state() self._session_state.dc_id = new_dc
if state is None: await self.session.set_state(self._session_state)
state = SessionState(
user_id=0,
dc_id=dc.id,
bot=False,
pts=0,
qts=0,
date=0,
seq=0,
takeout_id=None,
)
else:
state.dc_id = dc.id
await self.session.set_state(dc.id)
await self.session.save() await self.session.save()
await _disconnect(self) await _disconnect(self)
return await self.connect() return await self.connect()
async def _refresh_and_get_dc(self: 'TelegramClient', dc_id):
"""
Gets the Data Center (DC) associated to `dc_id`.
Also take this opportunity to refresh the addresses stored in the session if needed.
"""
cls = self.__class__
if not cls._config:
cls._config = await self(_tl.fn.help.GetConfig())
all_dc = {dc.id: dc for dc in await self.session.get_all_dc()}
for dc in cls._config.dc_options:
if dc.media_only or dc.tcpo_only or dc.cdn:
continue
ip = int(ipaddress.ip_address(dc.ip_address))
if dc.id in all_dc:
all_dc[dc.id].port = dc.port
if dc.ipv6:
all_dc[dc.id].ipv6 = ip
else:
all_dc[dc.id].ipv4 = ip
elif dc.ipv6:
all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
else:
all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
for dc in all_dc.values():
await self.session.insert_dc(dc)
await self.session.save()
try:
return next(
dc for dc in cls._config.dc_options
if dc.id == dc_id
and bool(dc.ipv6) == self._use_ipv6 and not dc.cdn
)
except StopIteration:
self._log[__name__].warning(
'Failed to get DC %swith use_ipv6 = %s; retrying ignoring IPv6 check',
dc_id, self._use_ipv6
)
return next(
dc for dc in cls._config.dc_options
if dc.id == dc_id and not dc.cdn
)
async def _create_exported_sender(self: 'TelegramClient', dc_id): async def _create_exported_sender(self: 'TelegramClient', dc_id):
""" """
Creates a new exported `MTProtoSender` for the given `dc_id` and Creates a new exported `MTProtoSender` for the given `dc_id` and
@ -519,14 +470,14 @@ async def _create_exported_sender(self: 'TelegramClient', dc_id):
""" """
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt # Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization # for clearly showing how to export the authorization
dc = await _refresh_and_get_dc(self, dc_id) dc = self._all_dcs[dc_id]
# Can't reuse self._sender._connection as it has its own seqno. # Can't reuse self._sender._connection as it has its own seqno.
# #
# If one were to do that, Telegram would reset the connection # If one were to do that, Telegram would reset the connection
# with no further clues. # with no further clues.
sender = MTProtoSender(loggers=self._log) sender = MTProtoSender(loggers=self._log)
await sender.connect(self._connection( await sender.connect(self._connection(
dc.ip_address, str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
dc.port, dc.port,
dc.id, dc.id,
loggers=self._log, loggers=self._log,
@ -559,9 +510,9 @@ async def _borrow_exported_sender(self: 'TelegramClient', dc_id):
self._borrowed_senders[dc_id] = (state, sender) self._borrowed_senders[dc_id] = (state, sender)
elif state.need_connect(): elif state.need_connect():
dc = await _refresh_and_get_dc(self, dc_id) dc = self._all_dcs[dc_id]
await sender.connect(self._connection( await sender.connect(self._connection(
dc.ip_address, str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
dc.port, dc.port,
dc.id, dc.id,
loggers=self._log, loggers=self._log,

View File

@ -2704,9 +2704,6 @@ class TelegramClient:
# Current TelegramClient version # Current TelegramClient version
__version__ = version.__version__ __version__ = version.__version__
# Cached server configuration (with .dc_options), can be "global"
_config = None
def __init__( def __init__(
self: 'TelegramClient', self: 'TelegramClient',
session: 'typing.Union[str, Session]', session: 'typing.Union[str, Session]',

View File

@ -36,7 +36,7 @@ class StateCache:
# each update in case they need to fetch missing entities. # each update in case they need to fetch missing entities.
self._logger = loggers[__name__] self._logger = loggers[__name__]
if initial: if initial:
self._pts_date = initial.pts, initial.date self._pts_date = initial.pts or None, initial.date or None
else: else:
self._pts_date = None, None self._pts_date = None, None