From 15ef302428c9d31865950f74287db947be6cdd4c Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sun, 10 Jun 2018 21:30:16 +0200 Subject: [PATCH] Implement _switch_dc/fix missing first request --- telethon/client/telegrambaseclient.py | 57 ++++++++++++--------------- telethon/network/mtprotosender.py | 11 +++--- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 2e76cddd..f57afd8b 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -96,8 +96,9 @@ class TelegramBaseClient(abc.ABC): # Current TelegramClient version __version__ = version.__version__ - # Server configuration (with .dc_options) + # Cached server configuration (with .dc_options), can be "global" _config = None + _cdn_config = None # region Initialization @@ -219,9 +220,14 @@ class TelegramBaseClient(abc.ABC): """ Connects to Telegram. """ + had_auth = self.session.auth_key is not None await self._sender.connect( self.session.server_address, self.session.port) + if not had_auth: + self.session.auth_key = self._sender.state.auth_key + self.session.save() + def is_connected(self): """ Returns ``True`` if the user has connected. @@ -237,22 +243,20 @@ class TelegramBaseClient(abc.ABC): # self.session.set_update_state(0, self.updates.get_update_state(0)) self.session.close() - def _switch_dc(self, new_dc): + async def _switch_dc(self, new_dc): """ Permanently switches the current connection to the new data center. """ - # TODO Implement - raise NotImplementedError - dc = self._get_dc(new_dc) - __log__.info('Reconnecting to new data center %s', dc) + __log__.info('Reconnecting to new data center %s', new_dc) + dc = await self._get_dc(new_dc) self.session.set_dc(dc.id, dc.ip_address, dc.port) # auth_key's are associated with a server, which has now changed # so it's not valid anymore. Set to None to force recreating it. - self.session.auth_key = None + self.session.auth_key = self._sender.state.auth_key = None self.session.save() - self.disconnect() - return self.connect() + await self.disconnect() + return await self.connect() # endregion @@ -260,31 +264,20 @@ class TelegramBaseClient(abc.ABC): async def _get_dc(self, dc_id, cdn=False): """Gets the Data Center (DC) associated to 'dc_id'""" - if not TelegramBaseClient._config: - TelegramBaseClient._config =\ - await self(functions.help.GetConfigRequest()) + cls = self.__class__ + if not cls._config: + cls._config = await self(functions.help.GetConfigRequest()) - try: - if cdn: - # Ensure we have the latest keys for the CDNs - result = await self(functions.help.GetCdnConfigRequest()) - for pk in result.public_keys: - rsa.add_key(pk.public_key) + if cdn and not self._cdn_config: + cls._cdn_config = await self(functions.help.GetCdnConfigRequest()) + for pk in cls._cdn_config.public_keys: + rsa.add_key(pk.public_key) - return next( - dc for dc in TelegramBaseClient._config.dc_options - if dc.id == dc_id - and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn - ) - except StopIteration: - if not cdn: - raise - - # New configuration, perhaps a new CDN was added? - TelegramBaseClient._config =\ - await self(functions.help.GetConfigRequest()) - - return self._get_dc(dc_id, cdn=cdn) + return next( + dc for dc in cls._config.dc_options + if dc.id == dc_id + and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn + ) async def _get_exported_client(self, dc_id): """Creates and connects a new TelegramBareClient for the desired DC. diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 328ac434..eb865f7a 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -215,6 +215,7 @@ class MTProtoSender: __log__.debug('Connection success!') if self.state.auth_key is None: + self._is_first_query = True _last_error = SecurityError() plain = MTProtoPlainSender(self._connection) for retry in range(1, self._retries + 1): @@ -233,14 +234,13 @@ class MTProtoSender: __log__.debug('Starting send loop') self._send_loop_handle = asyncio.ensure_future(self._send_loop()) + __log__.debug('Starting receive loop') + self._recv_loop_handle = asyncio.ensure_future(self._recv_loop()) if self._is_first_query: __log__.debug('Running first query') self._is_first_query = False - async with self._send_lock: - self.send(self._first_query) + await self.send(self._first_query) - __log__.debug('Starting receive loop') - self._recv_loop_handle = asyncio.ensure_future(self._recv_loop()) __log__.info('Connection to {} complete!'.format(self._ip)) async def _reconnect(self): @@ -327,7 +327,8 @@ class MTProtoSender: else: self._send_queue.put_nowait(m) - __log__.debug('Outgoing messages sent!') + __log__.debug('Outgoing messages {} sent!' + .format(', '.join(str(m.msg_id) for m in messages))) async def _recv_loop(self): """