diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index eab49589..14dbe2ca 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -27,21 +27,31 @@ MAX_CHUNK_SIZE = 512 * 1024 # 2021-01-15, users reported that `errors.TimeoutError` can occur while downloading files. TIMED_OUT_SLEEP = 1 + +class _CdnRedirect(Exception): + def __init__(self, cdn_redirect=None): + self.cdn_redirect = cdn_redirect + + class _DirectDownloadIter(RequestIter): async def _init( - self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data - ): + self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data, cdn_redirect=None): self.request = functions.upload.GetFileRequest( - file, offset=offset, limit=request_size) - + file, offset=offset, limit=request_size) + self._client = self.client + self._cdn_redirect = cdn_redirect + if cdn_redirect is not None: + self.request = functions.upload.GetCdnFileRequest(cdn_redirect.file_token, offset=offset, limit=request_size) + self._client = await self.client._get_cdn_client(cdn_redirect) + self.total = file_size self._stride = stride self._chunk_size = chunk_size self._last_part = None self._msg_data = msg_data self._timed_out = False - - self._exported = dc_id and self.client.session.dc_id != dc_id + + self._exported = dc_id and self._client.session.dc_id != dc_id if not self._exported: # The used sender will also change if ``FileMigrateError`` occurs self._sender = self.client._sender @@ -73,10 +83,16 @@ class _DirectDownloadIter(RequestIter): async def _request(self): try: - result = await self.client._call(self._sender, self.request) + result = await self._client._call(self._sender, self.request) self._timed_out = False if isinstance(result, types.upload.FileCdnRedirect): - raise NotImplementedError # TODO Implement + if self.client._mb_entity_cache.self_bot: + raise ValueError('FileCdnRedirect but the GetCdnFileRequest API access for bot users is restricted. Try to change api_id to avoid FileCdnRedirect') + raise _CdnRedirect(result) + if isinstance(result, types.upload.CdnFileReuploadNeeded): + await self.client._call(self.client._sender, functions.upload.ReuploadCdnFileRequest(file_token=self._cdn_redirect.file_token, request_token=result.request_token)) + result = await self._client._call(self._sender, self.request) + return result.bytes else: return result.bytes @@ -516,7 +532,9 @@ class DownloadMethods: dc_id: int = None, key: bytes = None, iv: bytes = None, - msg_data: tuple = None) -> typing.Optional[bytes]: + msg_data: tuple = None, + cdn_redirect: types.upload.FileCdnRedirect = None + ) -> typing.Optional[bytes]: if not part_size_kb: if not file_size: part_size_kb = 64 # Reasonable default @@ -543,7 +561,7 @@ class DownloadMethods: try: async for chunk in self._iter_download( - input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data): + input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data, cdn_redirect=cdn_redirect): if iv and key: chunk = AES.decrypt_ige(chunk, key, iv) r = f.write(chunk) @@ -561,6 +579,20 @@ class DownloadMethods: if in_memory: return f.getvalue() + except _CdnRedirect as e: + self._log[__name__].info('FileCdnRedirect to CDN data center %s', e.cdn_redirect.dc_id) + return await self._download_file( + input_location=input_location, + file=file, + part_size_kb=part_size_kb, + file_size=file_size, + progress_callback=progress_callback, + dc_id=e.cdn_redirect.dc_id, + key=e.cdn_redirect.encryption_key, + iv=e.cdn_redirect.encryption_iv, + msg_data=msg_data, + cdn_redirect=e.cdn_redirect + ) finally: if isinstance(file, str) or in_memory: f.close() @@ -682,7 +714,8 @@ class DownloadMethods: request_size: int = MAX_CHUNK_SIZE, file_size: int = None, dc_id: int = None, - msg_data: tuple = None + msg_data: tuple = None, + cdn_redirect: types.upload.FileCdnRedirect = None ): info = utils._get_file_info(file) if info.dc_id is not None: @@ -733,6 +766,7 @@ class DownloadMethods: request_size=request_size, file_size=file_size, msg_data=msg_data, + cdn_redirect=cdn_redirect ) # endregion diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 6f469f41..df754d9b 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -401,6 +401,7 @@ class TelegramBaseClient(abc.ABC): # Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders self._borrowed_senders = {} self._borrow_sender_lock = asyncio.Lock() + self._exported_sessions = {} self._loop = None # only used as a sanity check self._updates_error = None @@ -785,7 +786,8 @@ class TelegramBaseClient(abc.ABC): if cdn and not self._cdn_config: cls._cdn_config = await self(functions.help.GetCdnConfigRequest()) for pk in cls._cdn_config.public_keys: - rsa.add_key(pk.public_key) + if pk.dc_id == dc_id: + rsa.add_key(pk.public_key, old=False) try: return next( @@ -798,10 +800,13 @@ class TelegramBaseClient(abc.ABC): 'Failed to get DC %s (cdn = %s) with use_ipv6 = %s; retrying ignoring IPv6 check', dc_id, cdn, self._use_ipv6 ) - return next( - dc for dc in cls._config.dc_options - if dc.id == dc_id and bool(dc.cdn) == cdn - ) + try: + return next( + dc for dc in cls._config.dc_options + if dc.id == dc_id and bool(dc.cdn) == cdn + ) + except StopIteration: + raise ValueError(f'Failed to get DC {dc_id} (cdn = {cdn})') async def _create_exported_sender(self: 'TelegramClient', dc_id): """ @@ -889,8 +894,6 @@ class TelegramBaseClient(abc.ABC): async def _get_cdn_client(self: 'TelegramClient', cdn_redirect): """Similar to ._borrow_exported_client, but for CDNs""" - # TODO Implement - raise NotImplementedError session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) @@ -899,18 +902,22 @@ class TelegramBaseClient(abc.ABC): self._exported_sessions[cdn_redirect.dc_id] = session self._log[__name__].info('Creating new CDN client') - client = TelegramBaseClient( + client = self.__class__( session, self.api_id, self.api_hash, - proxy=self._sender.connection.conn.proxy, - timeout=self._sender.connection.get_timeout() + proxy=self._proxy, + timeout=self._timeout, + loop=self.loop ) - # This will make use of the new RSA keys for this specific CDN. - # - # We won't be calling GetConfigRequest because it's only called - # when needed by ._get_dc, and also it's static so it's likely - # set already. Avoid invoking non-CDN methods by not syncing updates. - client.connect(_sync_updates=False) + session.auth_key = self._sender.auth_key + await client._sender.connect(self._connection( + session.server_address, + session.port, + session.dc_id, + loggers=self._log, + proxy=self._proxy, + local_addr=self._local_addr + )) return client # endregion