Implement _switch_dc/fix missing first request

This commit is contained in:
Lonami Exo 2018-06-10 21:30:16 +02:00
parent 4a491e45ce
commit 15ef302428
2 changed files with 31 additions and 37 deletions

View File

@ -96,8 +96,9 @@ class TelegramBaseClient(abc.ABC):
# Current TelegramClient version # Current TelegramClient version
__version__ = version.__version__ __version__ = version.__version__
# Server configuration (with .dc_options) # Cached server configuration (with .dc_options), can be "global"
_config = None _config = None
_cdn_config = None
# region Initialization # region Initialization
@ -219,9 +220,14 @@ class TelegramBaseClient(abc.ABC):
""" """
Connects to Telegram. Connects to Telegram.
""" """
had_auth = self.session.auth_key is not None
await self._sender.connect( await self._sender.connect(
self.session.server_address, self.session.port) self.session.server_address, self.session.port)
if not had_auth:
self.session.auth_key = self._sender.state.auth_key
self.session.save()
def is_connected(self): def is_connected(self):
""" """
Returns ``True`` if the user has connected. Returns ``True`` if the user has connected.
@ -237,22 +243,20 @@ class TelegramBaseClient(abc.ABC):
# self.session.set_update_state(0, self.updates.get_update_state(0)) # self.session.set_update_state(0, self.updates.get_update_state(0))
self.session.close() self.session.close()
def _switch_dc(self, new_dc): async def _switch_dc(self, new_dc):
""" """
Permanently switches the current connection to the new data center. Permanently switches the current connection to the new data center.
""" """
# TODO Implement __log__.info('Reconnecting to new data center %s', new_dc)
raise NotImplementedError dc = await self._get_dc(new_dc)
dc = self._get_dc(new_dc)
__log__.info('Reconnecting to new data center %s', dc)
self.session.set_dc(dc.id, dc.ip_address, dc.port) self.session.set_dc(dc.id, dc.ip_address, dc.port)
# auth_key's are associated with a server, which has now changed # auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it. # so it's not valid anymore. Set to None to force recreating it.
self.session.auth_key = None self.session.auth_key = self._sender.state.auth_key = None
self.session.save() self.session.save()
self.disconnect() await self.disconnect()
return self.connect() return await self.connect()
# endregion # endregion
@ -260,31 +264,20 @@ class TelegramBaseClient(abc.ABC):
async def _get_dc(self, dc_id, cdn=False): async def _get_dc(self, dc_id, cdn=False):
"""Gets the Data Center (DC) associated to 'dc_id'""" """Gets the Data Center (DC) associated to 'dc_id'"""
if not TelegramBaseClient._config: cls = self.__class__
TelegramBaseClient._config =\ if not cls._config:
await self(functions.help.GetConfigRequest()) cls._config = await self(functions.help.GetConfigRequest())
try: if cdn and not self._cdn_config:
if cdn: cls._cdn_config = await self(functions.help.GetCdnConfigRequest())
# Ensure we have the latest keys for the CDNs for pk in cls._cdn_config.public_keys:
result = await self(functions.help.GetCdnConfigRequest()) rsa.add_key(pk.public_key)
for pk in result.public_keys:
rsa.add_key(pk.public_key)
return next( return next(
dc for dc in TelegramBaseClient._config.dc_options dc for dc in cls._config.dc_options
if dc.id == dc_id if dc.id == dc_id
and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn
) )
except StopIteration:
if not cdn:
raise
# New configuration, perhaps a new CDN was added?
TelegramBaseClient._config =\
await self(functions.help.GetConfigRequest())
return self._get_dc(dc_id, cdn=cdn)
async def _get_exported_client(self, dc_id): async def _get_exported_client(self, dc_id):
"""Creates and connects a new TelegramBareClient for the desired DC. """Creates and connects a new TelegramBareClient for the desired DC.

View File

@ -215,6 +215,7 @@ class MTProtoSender:
__log__.debug('Connection success!') __log__.debug('Connection success!')
if self.state.auth_key is None: if self.state.auth_key is None:
self._is_first_query = True
_last_error = SecurityError() _last_error = SecurityError()
plain = MTProtoPlainSender(self._connection) plain = MTProtoPlainSender(self._connection)
for retry in range(1, self._retries + 1): for retry in range(1, self._retries + 1):
@ -233,14 +234,13 @@ class MTProtoSender:
__log__.debug('Starting send loop') __log__.debug('Starting send loop')
self._send_loop_handle = asyncio.ensure_future(self._send_loop()) self._send_loop_handle = asyncio.ensure_future(self._send_loop())
__log__.debug('Starting receive loop')
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
if self._is_first_query: if self._is_first_query:
__log__.debug('Running first query') __log__.debug('Running first query')
self._is_first_query = False self._is_first_query = False
async with self._send_lock: await self.send(self._first_query)
self.send(self._first_query)
__log__.debug('Starting receive loop')
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
__log__.info('Connection to {} complete!'.format(self._ip)) __log__.info('Connection to {} complete!'.format(self._ip))
async def _reconnect(self): async def _reconnect(self):
@ -327,7 +327,8 @@ class MTProtoSender:
else: else:
self._send_queue.put_nowait(m) self._send_queue.put_nowait(m)
__log__.debug('Outgoing messages sent!') __log__.debug('Outgoing messages {} sent!'
.format(', '.join(str(m.msg_id) for m in messages)))
async def _recv_loop(self): async def _recv_loop(self):
""" """