mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-11 03:56:36 +03:00
Fix parallel downloads when using exported senders
This commit is contained in:
parent
90ea4ba8db
commit
491302bb32
|
@ -202,6 +202,7 @@ class DownloadMethods(UserMethods):
|
|||
|
||||
# The used sender will change if ``FileMigrateError`` occurs
|
||||
sender = self._sender
|
||||
exported = False
|
||||
input_location = utils.get_input_location(input_location)
|
||||
|
||||
__log__.info('Downloading file in chunks of %d bytes', part_size)
|
||||
|
@ -217,7 +218,8 @@ class DownloadMethods(UserMethods):
|
|||
raise NotImplementedError
|
||||
except errors.FileMigrateError as e:
|
||||
__log__.info('File lives in another DC')
|
||||
sender = await self._get_exported_sender(e.new_dc)
|
||||
sender = await self._borrow_exported_sender(e.new_dc)
|
||||
exported = True
|
||||
continue
|
||||
|
||||
offset += part_size
|
||||
|
@ -233,7 +235,9 @@ class DownloadMethods(UserMethods):
|
|||
if progress_callback:
|
||||
progress_callback(f.tell(), file_size)
|
||||
finally:
|
||||
if sender != self._sender:
|
||||
if exported:
|
||||
await self._return_exported_sender(sender)
|
||||
elif sender != self._sender:
|
||||
await sender.disconnect()
|
||||
if isinstance(file, str) or in_memory:
|
||||
f.close()
|
||||
|
|
|
@ -213,9 +213,11 @@ class TelegramBaseClient(abc.ABC):
|
|||
auto_reconnect_callback=self._handle_auto_reconnect
|
||||
)
|
||||
|
||||
# Cache :tl:`ExportedAuthorization` as ``dc_id: MTProtoState``
|
||||
# to easily import them when getting an exported sender.
|
||||
self._exported_auths = {}
|
||||
# Cache ``{dc_id: (n, MTProtoSender)}`` for all borrowed senders,
|
||||
# being ``n`` the amount of borrows a given sender has; once ``n``
|
||||
# reaches ``0`` it should be disconnected and removed.
|
||||
self._borrowed_senders = {}
|
||||
self._borrow_sender_lock = asyncio.Lock()
|
||||
|
||||
# Save whether the user is authorized here (a.k.a. logged in)
|
||||
self._authorized = None # None = We don't know yet
|
||||
|
@ -369,36 +371,65 @@ class TelegramBaseClient(abc.ABC):
|
|||
and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn
|
||||
)
|
||||
|
||||
async def _get_exported_sender(self, dc_id):
|
||||
async def _create_exported_sender(self, dc_id):
|
||||
"""
|
||||
Returns a cached `MTProtoSender` for the given `dc_id`, or creates
|
||||
a new one if it doesn't exist yet, and imports a freshly exported
|
||||
authorization key for it to be usable.
|
||||
Creates a new exported `MTProtoSender` for the given `dc_id` and
|
||||
returns it. This method should be used by `_borrow_exported_sender`.
|
||||
"""
|
||||
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
||||
# for clearly showing how to export the authorization
|
||||
auth = self._exported_auths.get(dc_id)
|
||||
dc = await self._get_dc(dc_id)
|
||||
state = MTProtoState(auth)
|
||||
state = MTProtoState(None)
|
||||
# Can't reuse self._sender._connection as it has its own seqno.
|
||||
#
|
||||
# If one were to do that, Telegram would reset the connection
|
||||
# with no further clues.
|
||||
sender = MTProtoSender(state, self._connection.clone(), self._loop)
|
||||
await sender.connect(dc.ip_address, dc.port)
|
||||
if not auth:
|
||||
__log__.info('Exporting authorization for data center %s', dc)
|
||||
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
|
||||
req = self._init_with(functions.auth.ImportAuthorizationRequest(
|
||||
id=auth.id, bytes=auth.bytes
|
||||
))
|
||||
await sender.send(req)
|
||||
self._exported_auths[dc_id] = sender.state.auth_key
|
||||
return sender
|
||||
|
||||
async def _borrow_exported_sender(self, dc_id):
|
||||
"""
|
||||
Borrows a connected `MTProtoSender` for the given `dc_id`.
|
||||
If it's not cached, creates a new one if it doesn't exist yet,
|
||||
and imports a freshly exported authorization key for it to be usable.
|
||||
|
||||
Once its job is over it should be `_return_exported_sender`.
|
||||
"""
|
||||
async with self._borrow_sender_lock:
|
||||
n, sender = self._borrowed_senders.get(dc_id, (0, None))
|
||||
if not sender:
|
||||
sender = await self._create_exported_sender(dc_id)
|
||||
sender.dc_id = dc_id
|
||||
|
||||
self._borrowed_senders[dc_id] = (n + 1, sender)
|
||||
|
||||
return sender
|
||||
|
||||
async def _return_exported_sender(self, sender):
|
||||
"""
|
||||
Returns a borrowed exported sender. If all borrows have
|
||||
been returned, the sender is cleanly disconnected.
|
||||
"""
|
||||
async with self._borrow_sender_lock:
|
||||
dc_id = sender.dc_id
|
||||
n, _ = self._borrowed_senders[dc_id]
|
||||
n -= 1
|
||||
if n > 0:
|
||||
self._borrowed_senders[dc_id] = (n, sender)
|
||||
else:
|
||||
__log__.info('Disconnecting borrowed sender for DC %d', dc_id)
|
||||
await sender.disconnect()
|
||||
del self._borrowed_senders[dc_id]
|
||||
|
||||
async def _get_cdn_client(self, cdn_redirect):
|
||||
"""Similar to ._get_exported_client, but for CDNs"""
|
||||
"""Similar to ._borrow_exported_client, but for CDNs"""
|
||||
# TODO Implement
|
||||
raise NotImplementedError
|
||||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||
|
|
Loading…
Reference in New Issue
Block a user