mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-03 05:04:33 +03:00
Cache session_state and all_dcs right after connect
This commit is contained in:
parent
93dd2a186a
commit
545e9d69ce
|
@ -39,10 +39,7 @@ class _DirectDownloadIter(requestiter.RequestIter):
|
||||||
self._msg_data = msg_data
|
self._msg_data = msg_data
|
||||||
self._timed_out = False
|
self._timed_out = False
|
||||||
|
|
||||||
# TODO should cache current session state
|
self._exported = dc_id and self.client._session_state.dc_id != dc_id
|
||||||
state = await self.client.session.get_state()
|
|
||||||
|
|
||||||
self._exported = dc_id and state.dc_id != dc_id
|
|
||||||
if not self._exported:
|
if not self._exported:
|
||||||
# The used sender will also change if ``FileMigrateError`` occurs
|
# The used sender will also change if ``FileMigrateError`` occurs
|
||||||
self._sender = self.client._sender
|
self._sender = self.client._sender
|
||||||
|
|
|
@ -597,10 +597,7 @@ async def edit_message(
|
||||||
)
|
)
|
||||||
# Invoke `messages.editInlineBotMessage` from the right datacenter.
|
# Invoke `messages.editInlineBotMessage` from the right datacenter.
|
||||||
# Otherwise, Telegram will error with `MESSAGE_ID_INVALID` and do nothing.
|
# Otherwise, Telegram will error with `MESSAGE_ID_INVALID` and do nothing.
|
||||||
# TODO should cache current session state
|
exported = self._session_state.dc_id != entity.dc_id
|
||||||
state = await self.session.get_state()
|
|
||||||
|
|
||||||
exported = state.dc_id != entity.dc_id
|
|
||||||
if exported:
|
if exported:
|
||||||
try:
|
try:
|
||||||
sender = await self._borrow_exported_sender(entity.dc_id)
|
sender = await self._borrow_exported_sender(entity.dc_id)
|
||||||
|
|
|
@ -142,6 +142,11 @@ def init(
|
||||||
# TODO Session should probably return all cached
|
# TODO Session should probably return all cached
|
||||||
# info of entities, not just the input versions
|
# info of entities, not just the input versions
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
|
# Cache session data for convenient access
|
||||||
|
self._session_state = None
|
||||||
|
self._all_dcs = None
|
||||||
|
|
||||||
self._entity_cache = entitycache.EntityCache()
|
self._entity_cache = entitycache.EntityCache()
|
||||||
self.api_id = int(api_id)
|
self.api_id = int(api_id)
|
||||||
self.api_hash = api_hash
|
self.api_hash = api_hash
|
||||||
|
@ -296,10 +301,19 @@ def set_flood_sleep_threshold(self, value):
|
||||||
|
|
||||||
|
|
||||||
async def connect(self: 'TelegramClient') -> None:
|
async def connect(self: 'TelegramClient') -> None:
|
||||||
all_dc = {dc.id: dc for dc in await self.session.get_all_dc()}
|
self._all_dcs = {dc.id: dc for dc in await self.session.get_all_dc()}
|
||||||
state = await self.session.get_state()
|
self._session_state = await self.session.get_state() or SessionState(
|
||||||
|
user_id=0,
|
||||||
|
dc_id=DEFAULT_DC_ID,
|
||||||
|
bot=False,
|
||||||
|
pts=0,
|
||||||
|
qts=0,
|
||||||
|
date=0,
|
||||||
|
seq=0,
|
||||||
|
takeout_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
dc = all_dc.get(state.dc_id) if state else None
|
dc = self._all_dcs.get(self._session_state.dc_id)
|
||||||
if dc is None:
|
if dc is None:
|
||||||
dc = DataCenter(
|
dc = DataCenter(
|
||||||
id=DEFAULT_DC_ID,
|
id=DEFAULT_DC_ID,
|
||||||
|
@ -308,11 +322,11 @@ async def connect(self: 'TelegramClient') -> None:
|
||||||
port=DEFAULT_PORT,
|
port=DEFAULT_PORT,
|
||||||
auth=b'',
|
auth=b'',
|
||||||
)
|
)
|
||||||
all_dc[dc.id] = dc
|
self._all_dcs[dc.id] = dc
|
||||||
|
|
||||||
# Update state (for catching up after a disconnection)
|
# Update state (for catching up after a disconnection)
|
||||||
# TODO Get state from channels too
|
# TODO Get state from channels too
|
||||||
self._state_cache = statecache.StateCache(state, self._log)
|
self._state_cache = statecache.StateCache(self._session_state, self._log)
|
||||||
|
|
||||||
# Use known key, if any
|
# Use known key, if any
|
||||||
self._sender.auth_key.key = dc.auth
|
self._sender.auth_key.key = dc.auth
|
||||||
|
@ -345,18 +359,18 @@ async def connect(self: 'TelegramClient') -> None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ip = int(ipaddress.ip_address(dc.ip_address))
|
ip = int(ipaddress.ip_address(dc.ip_address))
|
||||||
if dc.id in all_dc:
|
if dc.id in self._all_dcs:
|
||||||
all_dc[dc.id].port = dc.port
|
self._all_dcs[dc.id].port = dc.port
|
||||||
if dc.ipv6:
|
if dc.ipv6:
|
||||||
all_dc[dc.id].ipv6 = ip
|
self._all_dcs[dc.id].ipv6 = ip
|
||||||
else:
|
else:
|
||||||
all_dc[dc.id].ipv4 = ip
|
self._all_dcs[dc.id].ipv4 = ip
|
||||||
elif dc.ipv6:
|
elif dc.ipv6:
|
||||||
all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
|
self._all_dcs[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
|
||||||
else:
|
else:
|
||||||
all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
|
self._all_dcs[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
|
||||||
|
|
||||||
for dc in all_dc.values():
|
for dc in self._all_dcs.values():
|
||||||
await self.session.insert_dc(dc)
|
await self.session.insert_dc(dc)
|
||||||
|
|
||||||
await self.session.save()
|
await self.session.save()
|
||||||
|
@ -419,11 +433,10 @@ async def _disconnect_coro(self: 'TelegramClient'):
|
||||||
|
|
||||||
pts, date = self._state_cache[None]
|
pts, date = self._state_cache[None]
|
||||||
if pts and date:
|
if pts and date:
|
||||||
state = await self.session.get_state()
|
if self._session_state:
|
||||||
if state:
|
self._session_state.pts = pts
|
||||||
state.pts = pts
|
self._session_state.date = date
|
||||||
state.date = date
|
await self.session.set_state(self._session_state)
|
||||||
await self.session.set_state(state)
|
|
||||||
await self.session.save()
|
await self.session.save()
|
||||||
|
|
||||||
async def _disconnect(self: 'TelegramClient'):
|
async def _disconnect(self: 'TelegramClient'):
|
||||||
|
@ -442,76 +455,14 @@ async def _switch_dc(self: 'TelegramClient', new_dc):
|
||||||
Permanently switches the current connection to the new data center.
|
Permanently switches the current connection to the new data center.
|
||||||
"""
|
"""
|
||||||
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
|
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
|
||||||
dc = await _refresh_and_get_dc(self, new_dc)
|
|
||||||
|
|
||||||
state = await self.session.get_state()
|
self._session_state.dc_id = new_dc
|
||||||
if state is None:
|
await self.session.set_state(self._session_state)
|
||||||
state = SessionState(
|
|
||||||
user_id=0,
|
|
||||||
dc_id=dc.id,
|
|
||||||
bot=False,
|
|
||||||
pts=0,
|
|
||||||
qts=0,
|
|
||||||
date=0,
|
|
||||||
seq=0,
|
|
||||||
takeout_id=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
state.dc_id = dc.id
|
|
||||||
|
|
||||||
await self.session.set_state(dc.id)
|
|
||||||
await self.session.save()
|
await self.session.save()
|
||||||
|
|
||||||
await _disconnect(self)
|
await _disconnect(self)
|
||||||
return await self.connect()
|
return await self.connect()
|
||||||
|
|
||||||
|
|
||||||
async def _refresh_and_get_dc(self: 'TelegramClient', dc_id):
|
|
||||||
"""
|
|
||||||
Gets the Data Center (DC) associated to `dc_id`.
|
|
||||||
|
|
||||||
Also take this opportunity to refresh the addresses stored in the session if needed.
|
|
||||||
"""
|
|
||||||
cls = self.__class__
|
|
||||||
if not cls._config:
|
|
||||||
cls._config = await self(_tl.fn.help.GetConfig())
|
|
||||||
all_dc = {dc.id: dc for dc in await self.session.get_all_dc()}
|
|
||||||
for dc in cls._config.dc_options:
|
|
||||||
if dc.media_only or dc.tcpo_only or dc.cdn:
|
|
||||||
continue
|
|
||||||
|
|
||||||
ip = int(ipaddress.ip_address(dc.ip_address))
|
|
||||||
if dc.id in all_dc:
|
|
||||||
all_dc[dc.id].port = dc.port
|
|
||||||
if dc.ipv6:
|
|
||||||
all_dc[dc.id].ipv6 = ip
|
|
||||||
else:
|
|
||||||
all_dc[dc.id].ipv4 = ip
|
|
||||||
elif dc.ipv6:
|
|
||||||
all_dc[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
|
|
||||||
else:
|
|
||||||
all_dc[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
|
|
||||||
|
|
||||||
for dc in all_dc.values():
|
|
||||||
await self.session.insert_dc(dc)
|
|
||||||
await self.session.save()
|
|
||||||
|
|
||||||
try:
|
|
||||||
return next(
|
|
||||||
dc for dc in cls._config.dc_options
|
|
||||||
if dc.id == dc_id
|
|
||||||
and bool(dc.ipv6) == self._use_ipv6 and not dc.cdn
|
|
||||||
)
|
|
||||||
except StopIteration:
|
|
||||||
self._log[__name__].warning(
|
|
||||||
'Failed to get DC %swith use_ipv6 = %s; retrying ignoring IPv6 check',
|
|
||||||
dc_id, self._use_ipv6
|
|
||||||
)
|
|
||||||
return next(
|
|
||||||
dc for dc in cls._config.dc_options
|
|
||||||
if dc.id == dc_id and not dc.cdn
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _create_exported_sender(self: 'TelegramClient', dc_id):
|
async def _create_exported_sender(self: 'TelegramClient', dc_id):
|
||||||
"""
|
"""
|
||||||
Creates a new exported `MTProtoSender` for the given `dc_id` and
|
Creates a new exported `MTProtoSender` for the given `dc_id` and
|
||||||
|
@ -519,14 +470,14 @@ async def _create_exported_sender(self: 'TelegramClient', dc_id):
|
||||||
"""
|
"""
|
||||||
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
||||||
# for clearly showing how to export the authorization
|
# for clearly showing how to export the authorization
|
||||||
dc = await _refresh_and_get_dc(self, dc_id)
|
dc = self._all_dcs[dc_id]
|
||||||
# Can't reuse self._sender._connection as it has its own seqno.
|
# Can't reuse self._sender._connection as it has its own seqno.
|
||||||
#
|
#
|
||||||
# If one were to do that, Telegram would reset the connection
|
# If one were to do that, Telegram would reset the connection
|
||||||
# with no further clues.
|
# with no further clues.
|
||||||
sender = MTProtoSender(loggers=self._log)
|
sender = MTProtoSender(loggers=self._log)
|
||||||
await sender.connect(self._connection(
|
await sender.connect(self._connection(
|
||||||
dc.ip_address,
|
str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
|
||||||
dc.port,
|
dc.port,
|
||||||
dc.id,
|
dc.id,
|
||||||
loggers=self._log,
|
loggers=self._log,
|
||||||
|
@ -559,9 +510,9 @@ async def _borrow_exported_sender(self: 'TelegramClient', dc_id):
|
||||||
self._borrowed_senders[dc_id] = (state, sender)
|
self._borrowed_senders[dc_id] = (state, sender)
|
||||||
|
|
||||||
elif state.need_connect():
|
elif state.need_connect():
|
||||||
dc = await _refresh_and_get_dc(self, dc_id)
|
dc = self._all_dcs[dc_id]
|
||||||
await sender.connect(self._connection(
|
await sender.connect(self._connection(
|
||||||
dc.ip_address,
|
str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
|
||||||
dc.port,
|
dc.port,
|
||||||
dc.id,
|
dc.id,
|
||||||
loggers=self._log,
|
loggers=self._log,
|
||||||
|
|
|
@ -2704,9 +2704,6 @@ class TelegramClient:
|
||||||
# Current TelegramClient version
|
# Current TelegramClient version
|
||||||
__version__ = version.__version__
|
__version__ = version.__version__
|
||||||
|
|
||||||
# Cached server configuration (with .dc_options), can be "global"
|
|
||||||
_config = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self: 'TelegramClient',
|
self: 'TelegramClient',
|
||||||
session: 'typing.Union[str, Session]',
|
session: 'typing.Union[str, Session]',
|
||||||
|
|
|
@ -36,7 +36,7 @@ class StateCache:
|
||||||
# each update in case they need to fetch missing entities.
|
# each update in case they need to fetch missing entities.
|
||||||
self._logger = loggers[__name__]
|
self._logger = loggers[__name__]
|
||||||
if initial:
|
if initial:
|
||||||
self._pts_date = initial.pts, initial.date
|
self._pts_date = initial.pts or None, initial.date or None
|
||||||
else:
|
else:
|
||||||
self._pts_date = None, None
|
self._pts_date = None, None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user