Support CDN downloads (#4420)

Closes #4327.
This commit is contained in:
灰白草 2024-08-08 02:25:35 +08:00 committed by GitHub
parent 946f803de7
commit 75408483ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 27 deletions

View File

@ -27,12 +27,22 @@ MAX_CHUNK_SIZE = 512 * 1024
# 2021-01-15, users reported that `errors.TimeoutError` can occur while downloading files. # 2021-01-15, users reported that `errors.TimeoutError` can occur while downloading files.
TIMED_OUT_SLEEP = 1 TIMED_OUT_SLEEP = 1
class _CdnRedirect(Exception):
def __init__(self, cdn_redirect=None):
self.cdn_redirect = cdn_redirect
class _DirectDownloadIter(RequestIter): class _DirectDownloadIter(RequestIter):
async def _init( async def _init(
self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data, cdn_redirect=None):
):
self.request = functions.upload.GetFileRequest( self.request = functions.upload.GetFileRequest(
file, offset=offset, limit=request_size) file, offset=offset, limit=request_size)
self._client = self.client
self._cdn_redirect = cdn_redirect
if cdn_redirect is not None:
self.request = functions.upload.GetCdnFileRequest(cdn_redirect.file_token, offset=offset, limit=request_size)
self._client = await self.client._get_cdn_client(cdn_redirect)
self.total = file_size self.total = file_size
self._stride = stride self._stride = stride
@ -41,7 +51,7 @@ class _DirectDownloadIter(RequestIter):
self._msg_data = msg_data self._msg_data = msg_data
self._timed_out = False self._timed_out = False
self._exported = dc_id and self.client.session.dc_id != dc_id self._exported = dc_id and self._client.session.dc_id != dc_id
if not self._exported: if not self._exported:
# The used sender will also change if ``FileMigrateError`` occurs # The used sender will also change if ``FileMigrateError`` occurs
self._sender = self.client._sender self._sender = self.client._sender
@ -73,10 +83,16 @@ class _DirectDownloadIter(RequestIter):
async def _request(self): async def _request(self):
try: try:
result = await self.client._call(self._sender, self.request) result = await self._client._call(self._sender, self.request)
self._timed_out = False self._timed_out = False
if isinstance(result, types.upload.FileCdnRedirect): if isinstance(result, types.upload.FileCdnRedirect):
raise NotImplementedError # TODO Implement if self.client._mb_entity_cache.self_bot:
raise ValueError('FileCdnRedirect but the GetCdnFileRequest API access for bot users is restricted. Try to change api_id to avoid FileCdnRedirect')
raise _CdnRedirect(result)
if isinstance(result, types.upload.CdnFileReuploadNeeded):
await self.client._call(self.client._sender, functions.upload.ReuploadCdnFileRequest(file_token=self._cdn_redirect.file_token, request_token=result.request_token))
result = await self._client._call(self._sender, self.request)
return result.bytes
else: else:
return result.bytes return result.bytes
@ -516,7 +532,9 @@ class DownloadMethods:
dc_id: int = None, dc_id: int = None,
key: bytes = None, key: bytes = None,
iv: bytes = None, iv: bytes = None,
msg_data: tuple = None) -> typing.Optional[bytes]: msg_data: tuple = None,
cdn_redirect: types.upload.FileCdnRedirect = None
) -> typing.Optional[bytes]:
if not part_size_kb: if not part_size_kb:
if not file_size: if not file_size:
part_size_kb = 64 # Reasonable default part_size_kb = 64 # Reasonable default
@ -543,7 +561,7 @@ class DownloadMethods:
try: try:
async for chunk in self._iter_download( async for chunk in self._iter_download(
input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data): input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data, cdn_redirect=cdn_redirect):
if iv and key: if iv and key:
chunk = AES.decrypt_ige(chunk, key, iv) chunk = AES.decrypt_ige(chunk, key, iv)
r = f.write(chunk) r = f.write(chunk)
@ -561,6 +579,20 @@ class DownloadMethods:
if in_memory: if in_memory:
return f.getvalue() return f.getvalue()
except _CdnRedirect as e:
self._log[__name__].info('FileCdnRedirect to CDN data center %s', e.cdn_redirect.dc_id)
return await self._download_file(
input_location=input_location,
file=file,
part_size_kb=part_size_kb,
file_size=file_size,
progress_callback=progress_callback,
dc_id=e.cdn_redirect.dc_id,
key=e.cdn_redirect.encryption_key,
iv=e.cdn_redirect.encryption_iv,
msg_data=msg_data,
cdn_redirect=e.cdn_redirect
)
finally: finally:
if isinstance(file, str) or in_memory: if isinstance(file, str) or in_memory:
f.close() f.close()
@ -682,7 +714,8 @@ class DownloadMethods:
request_size: int = MAX_CHUNK_SIZE, request_size: int = MAX_CHUNK_SIZE,
file_size: int = None, file_size: int = None,
dc_id: int = None, dc_id: int = None,
msg_data: tuple = None msg_data: tuple = None,
cdn_redirect: types.upload.FileCdnRedirect = None
): ):
info = utils._get_file_info(file) info = utils._get_file_info(file)
if info.dc_id is not None: if info.dc_id is not None:
@ -733,6 +766,7 @@ class DownloadMethods:
request_size=request_size, request_size=request_size,
file_size=file_size, file_size=file_size,
msg_data=msg_data, msg_data=msg_data,
cdn_redirect=cdn_redirect
) )
# endregion # endregion

View File

@ -401,6 +401,7 @@ class TelegramBaseClient(abc.ABC):
# Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders # Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders
self._borrowed_senders = {} self._borrowed_senders = {}
self._borrow_sender_lock = asyncio.Lock() self._borrow_sender_lock = asyncio.Lock()
self._exported_sessions = {}
self._loop = None # only used as a sanity check self._loop = None # only used as a sanity check
self._updates_error = None self._updates_error = None
@ -785,7 +786,8 @@ class TelegramBaseClient(abc.ABC):
if cdn and not self._cdn_config: if cdn and not self._cdn_config:
cls._cdn_config = await self(functions.help.GetCdnConfigRequest()) cls._cdn_config = await self(functions.help.GetCdnConfigRequest())
for pk in cls._cdn_config.public_keys: for pk in cls._cdn_config.public_keys:
rsa.add_key(pk.public_key) if pk.dc_id == dc_id:
rsa.add_key(pk.public_key, old=False)
try: try:
return next( return next(
@ -798,10 +800,13 @@ class TelegramBaseClient(abc.ABC):
'Failed to get DC %s (cdn = %s) with use_ipv6 = %s; retrying ignoring IPv6 check', 'Failed to get DC %s (cdn = %s) with use_ipv6 = %s; retrying ignoring IPv6 check',
dc_id, cdn, self._use_ipv6 dc_id, cdn, self._use_ipv6
) )
return next( try:
dc for dc in cls._config.dc_options return next(
if dc.id == dc_id and bool(dc.cdn) == cdn dc for dc in cls._config.dc_options
) if dc.id == dc_id and bool(dc.cdn) == cdn
)
except StopIteration:
raise ValueError(f'Failed to get DC {dc_id} (cdn = {cdn})')
async def _create_exported_sender(self: 'TelegramClient', dc_id): async def _create_exported_sender(self: 'TelegramClient', dc_id):
""" """
@ -889,8 +894,6 @@ class TelegramBaseClient(abc.ABC):
async def _get_cdn_client(self: 'TelegramClient', cdn_redirect): async def _get_cdn_client(self: 'TelegramClient', cdn_redirect):
"""Similar to ._borrow_exported_client, but for CDNs""" """Similar to ._borrow_exported_client, but for CDNs"""
# TODO Implement
raise NotImplementedError
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
@ -899,18 +902,22 @@ class TelegramBaseClient(abc.ABC):
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session
self._log[__name__].info('Creating new CDN client') self._log[__name__].info('Creating new CDN client')
client = TelegramBaseClient( client = self.__class__(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._proxy,
timeout=self._sender.connection.get_timeout() timeout=self._timeout,
loop=self.loop
) )
# This will make use of the new RSA keys for this specific CDN. session.auth_key = self._sender.auth_key
# await client._sender.connect(self._connection(
# We won't be calling GetConfigRequest because it's only called session.server_address,
# when needed by ._get_dc, and also it's static so it's likely session.port,
# set already. Avoid invoking non-CDN methods by not syncing updates. session.dc_id,
client.connect(_sync_updates=False) loggers=self._log,
proxy=self._proxy,
local_addr=self._local_addr
))
return client return client
# endregion # endregion