Cache session_state and all_dcs right after connect

This commit is contained in:
Lonami Exo 2021-09-19 17:51:05 +02:00
parent 93dd2a186a
commit 545e9d69ce
5 changed files with 39 additions and 97 deletions

View File

@ -39,10 +39,7 @@ class _DirectDownloadIter(requestiter.RequestIter):
self._msg_data = msg_data
self._timed_out = False
# TODO should cache current session state
state = await self.client.session.get_state()
self._exported = dc_id and state.dc_id != dc_id
self._exported = dc_id and self.client._session_state.dc_id != dc_id
if not self._exported:
# The used sender will also change if ``FileMigrateError`` occurs
self._sender = self.client._sender

View File

@ -597,10 +597,7 @@ async def edit_message(
)
# Invoke `messages.editInlineBotMessage` from the right datacenter.
# Otherwise, Telegram will error with `MESSAGE_ID_INVALID` and do nothing.
# TODO should cache current session state
state = await self.session.get_state()
exported = state.dc_id != entity.dc_id
exported = self._session_state.dc_id != entity.dc_id
if exported:
try:
sender = await self._borrow_exported_sender(entity.dc_id)

View File

@ -142,6 +142,11 @@ def init(
# TODO Session should probably return all cached
# info of entities, not just the input versions
self.session = session
# Cache session data for convenient access
self._session_state = None
self._all_dcs = None
self._entity_cache = entitycache.EntityCache()
self.api_id = int(api_id)
self.api_hash = api_hash
@ -296,10 +301,19 @@ def set_flood_sleep_threshold(self, value):
async def connect(self: 'TelegramClient') -> None:
all_dc = {dc.id: dc for dc in await self.session.get_all_dc()}
state = await self.session.get_state()
self._all_dcs = {dc.id: dc for dc in await self.session.get_all_dc()}
self._session_state = await self.session.get_state() or SessionState(
user_id=0,
dc_id=DEFAULT_DC_ID,
bot=False,
pts=0,
qts=0,
date=0,
seq=0,
takeout_id=None,
)
dc = all_dc.get(state.dc_id) if state else None
dc = self._all_dcs.get(self._session_state.dc_id)
if dc is None:
dc = DataCenter(
id=DEFAULT_DC_ID,
@ -308,11 +322,11 @@ async def connect(self: 'TelegramClient') -> None:
port=DEFAULT_PORT,
auth=b'',
)
all_dc[dc.id] = dc
self._all_dcs[dc.id] = dc
# Update state (for catching up after a disconnection)
# TODO Get state from channels too
self._state_cache = statecache.StateCache(state, self._log)
self._state_cache = statecache.StateCache(self._session_state, self._log)
# Use known key, if any
self._sender.auth_key.key = dc.auth
@ -345,18 +359,18 @@ async def connect(self: 'TelegramClient') -> None:
continue
ip = int(ipaddress.ip_address(dc.ip_address))
if dc.id in all_dc:
all_dc[dc.id].port = dc.port
if dc.id in self._all_dcs:
self._all_dcs[dc.id].port = dc.port
if dc.ipv6:
all_dc[dc.id].ipv6 = ip
self._all_dcs[dc.id].ipv6 = ip
else:
all_dc[dc.id].ipv4 = ip
self._all_dcs[dc.id].ipv4 = ip
elif dc.ipv6:
all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
self._all_dcs[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
else:
all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
self._all_dcs[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
for dc in all_dc.values():
for dc in self._all_dcs.values():
await self.session.insert_dc(dc)
await self.session.save()
@ -419,11 +433,10 @@ async def _disconnect_coro(self: 'TelegramClient'):
pts, date = self._state_cache[None]
if pts and date:
state = await self.session.get_state()
if state:
state.pts = pts
state.date = date
await self.session.set_state(state)
if self._session_state:
self._session_state.pts = pts
self._session_state.date = date
await self.session.set_state(self._session_state)
await self.session.save()
async def _disconnect(self: 'TelegramClient'):
@ -442,76 +455,14 @@ async def _switch_dc(self: 'TelegramClient', new_dc):
Permanently switches the current connection to the new data center.
"""
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
dc = await _refresh_and_get_dc(self, new_dc)
state = await self.session.get_state()
if state is None:
state = SessionState(
user_id=0,
dc_id=dc.id,
bot=False,
pts=0,
qts=0,
date=0,
seq=0,
takeout_id=None,
)
else:
state.dc_id = dc.id
await self.session.set_state(dc.id)
self._session_state.dc_id = new_dc
await self.session.set_state(self._session_state)
await self.session.save()
await _disconnect(self)
return await self.connect()
async def _refresh_and_get_dc(self: 'TelegramClient', dc_id):
"""
Gets the Data Center (DC) associated to `dc_id`.
Also take this opportunity to refresh the addresses stored in the session if needed.
"""
cls = self.__class__
if not cls._config:
cls._config = await self(_tl.fn.help.GetConfig())
all_dc = {dc.id: dc for dc in await self.session.get_all_dc()}
for dc in cls._config.dc_options:
if dc.media_only or dc.tcpo_only or dc.cdn:
continue
ip = int(ipaddress.ip_address(dc.ip_address))
if dc.id in all_dc:
all_dc[dc.id].port = dc.port
if dc.ipv6:
all_dc[dc.id].ipv6 = ip
else:
all_dc[dc.id].ipv4 = ip
elif dc.ipv6:
all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
else:
all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
for dc in all_dc.values():
await self.session.insert_dc(dc)
await self.session.save()
try:
return next(
dc for dc in cls._config.dc_options
if dc.id == dc_id
and bool(dc.ipv6) == self._use_ipv6 and not dc.cdn
)
except StopIteration:
self._log[__name__].warning(
'Failed to get DC %swith use_ipv6 = %s; retrying ignoring IPv6 check',
dc_id, self._use_ipv6
)
return next(
dc for dc in cls._config.dc_options
if dc.id == dc_id and not dc.cdn
)
async def _create_exported_sender(self: 'TelegramClient', dc_id):
"""
Creates a new exported `MTProtoSender` for the given `dc_id` and
@ -519,14 +470,14 @@ async def _create_exported_sender(self: 'TelegramClient', dc_id):
"""
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization
dc = await _refresh_and_get_dc(self, dc_id)
dc = self._all_dcs[dc_id]
# Can't reuse self._sender._connection as it has its own seqno.
#
# If one were to do that, Telegram would reset the connection
# with no further clues.
sender = MTProtoSender(loggers=self._log)
await sender.connect(self._connection(
dc.ip_address,
str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
dc.port,
dc.id,
loggers=self._log,
@ -559,9 +510,9 @@ async def _borrow_exported_sender(self: 'TelegramClient', dc_id):
self._borrowed_senders[dc_id] = (state, sender)
elif state.need_connect():
dc = await _refresh_and_get_dc(self, dc_id)
dc = self._all_dcs[dc_id]
await sender.connect(self._connection(
dc.ip_address,
str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
dc.port,
dc.id,
loggers=self._log,

View File

@ -2704,9 +2704,6 @@ class TelegramClient:
# Current TelegramClient version
__version__ = version.__version__
# Cached server configuration (with .dc_options), can be "global"
_config = None
def __init__(
self: 'TelegramClient',
session: 'typing.Union[str, Session]',

View File

@ -36,7 +36,7 @@ class StateCache:
# each update in case they need to fetch missing entities.
self._logger = loggers[__name__]
if initial:
self._pts_date = initial.pts, initial.date
self._pts_date = initial.pts or None, initial.date or None
else:
self._pts_date = None, None