diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index eab49589..705a1abb 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -27,39 +27,49 @@ MAX_CHUNK_SIZE = 512 * 1024 # 2021-01-15, users reported that `errors.TimeoutError` can occur while downloading files. TIMED_OUT_SLEEP = 1 + +class _CdnRedirect(Exception): + def __init__(self, cdn_redirect=None): + self.cdn_redirect = cdn_redirect + + class _DirectDownloadIter(RequestIter): async def _init( - self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data - ): + self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data, cdn_redirect=None): self.request = functions.upload.GetFileRequest( - file, offset=offset, limit=request_size) - + file, offset=offset, limit=request_size) + self._client = self.client + self.cdn_redirect = cdn_redirect + if cdn_redirect is not None: + self.request = functions.upload.GetCdnFileRequest(cdn_redirect.file_token, offset=offset, limit=request_size) + self._client = await self.client._get_cdn_client(cdn_redirect) + self.total = file_size self._stride = stride self._chunk_size = chunk_size self._last_part = None self._msg_data = msg_data self._timed_out = False - - self._exported = dc_id and self.client.session.dc_id != dc_id + + self._exported = dc_id and self._client.session.dc_id != dc_id if not self._exported: # The used sender will also change if ``FileMigrateError`` occurs - self._sender = self.client._sender + self._sender = self._client._sender else: try: - self._sender = await self.client._borrow_exported_sender(dc_id) + self._sender = await self._client._borrow_exported_sender(dc_id) except errors.DcIdInvalidError: # Can't export a sender for the ID we are currently in - config = await self.client(functions.help.GetConfigRequest()) + config = await self._client(functions.help.GetConfigRequest()) for option in config.dc_options: - if option.ip_address == self.client.session.server_address: - self.client.session.set_dc( + if option.ip_address == self._client.session.server_address: + self._client.session.set_dc( option.id, option.ip_address, option.port) - self.client.session.save() + self._client.session.save() break # TODO Figure out why the session may have the wrong DC ID - self._sender = self.client._sender + self._sender = self._client._sender self._exported = False async def _load_next_chunk(self): @@ -73,10 +83,15 @@ class _DirectDownloadIter(RequestIter): async def _request(self): try: - result = await self.client._call(self._sender, self.request) + result = await self._client._call(self._sender, self.request) self._timed_out = False if isinstance(result, types.upload.FileCdnRedirect): - raise NotImplementedError # TODO Implement + # raise NotImplementedError # TODO Implement + raise _CdnRedirect(result) + if isinstance(result, types.upload.CdnFileReuploadNeeded): + result = await self._client._call(self._sender, functions.upload.reuploadCdnFile(file_token=self.cdn_redirect.file_token, request_token=result.request_token)) + result = await self._client._call(self._sender, self.request) + return result.bytes else: return result.bytes @@ -516,7 +531,9 @@ class DownloadMethods: dc_id: int = None, key: bytes = None, iv: bytes = None, - msg_data: tuple = None) -> typing.Optional[bytes]: + msg_data: tuple = None, + cdn_redirect: types.upload.FileCdnRedirect = None + ) -> typing.Optional[bytes]: if not part_size_kb: if not file_size: part_size_kb = 64 # Reasonable default @@ -543,7 +560,7 @@ class DownloadMethods: try: async for chunk in self._iter_download( - input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data): + input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data, cdn_redirect=cdn_redirect): if iv and key: chunk = AES.decrypt_ige(chunk, key, iv) r = f.write(chunk) @@ -561,6 +578,20 @@ class DownloadMethods: if in_memory: return f.getvalue() + except _CdnRedirect as e: + self._log[__name__].info('FileCdnRedirect') + return await self._download_file( + input_location=input_location, + file=file, + part_size_kb=part_size_kb, + file_size=file_size, + progress_callback=progress_callback, + dc_id=e.cdn_redirect.dc_id, + key=e.cdn_redirect.encryption_key, + iv=e.cdn_redirect.encryption_iv, + msg_data=msg_data, + cdn_redirect=e.cdn_redirect + ) finally: if isinstance(file, str) or in_memory: f.close() @@ -682,7 +713,8 @@ class DownloadMethods: request_size: int = MAX_CHUNK_SIZE, file_size: int = None, dc_id: int = None, - msg_data: tuple = None + msg_data: tuple = None, + cdn_redirect: types.upload.FileCdnRedirect = None ): info = utils._get_file_info(file) if info.dc_id is not None: @@ -733,6 +765,7 @@ class DownloadMethods: request_size=request_size, file_size=file_size, msg_data=msg_data, + cdn_redirect=cdn_redirect ) # endregion diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 6f469f41..6bb6aa70 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -401,6 +401,7 @@ class TelegramBaseClient(abc.ABC): # Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders self._borrowed_senders = {} self._borrow_sender_lock = asyncio.Lock() + self._exported_sessions = {} self._loop = None # only used as a sanity check self._updates_error = None @@ -785,7 +786,8 @@ class TelegramBaseClient(abc.ABC): if cdn and not self._cdn_config: cls._cdn_config = await self(functions.help.GetCdnConfigRequest()) for pk in cls._cdn_config.public_keys: - rsa.add_key(pk.public_key) + if pk.dc_id == dc_id: + rsa.add_key(pk.public_key, old=False) try: return next( @@ -890,7 +892,8 @@ class TelegramBaseClient(abc.ABC): async def _get_cdn_client(self: 'TelegramClient', cdn_redirect): """Similar to ._borrow_exported_client, but for CDNs""" # TODO Implement - raise NotImplementedError + # raise NotImplementedError + from .telegramclient import TelegramClient session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) @@ -899,10 +902,11 @@ class TelegramBaseClient(abc.ABC): self._exported_sessions[cdn_redirect.dc_id] = session self._log[__name__].info('Creating new CDN client') - client = TelegramBaseClient( + client = TelegramClient( session, self.api_id, self.api_hash, - proxy=self._sender.connection.conn.proxy, - timeout=self._sender.connection.get_timeout() + proxy=self._proxy, + timeout=self._timeout, + loop=self.loop ) # This will make use of the new RSA keys for this specific CDN. @@ -910,7 +914,12 @@ class TelegramBaseClient(abc.ABC): # We won't be calling GetConfigRequest because it's only called # when needed by ._get_dc, and also it's static so it's likely # set already. Avoid invoking non-CDN methods by not syncing updates. - client.connect(_sync_updates=False) + + self_id = self._mb_entity_cache.self_id + self_user = self.session.get_input_entity(self_id) + client._mb_entity_cache.set_self_user(self_id, True, self_user.access_hash) + + await client.start() return client # endregion