Refetch msg if fileref expires while downloading docs

Closes #1301.
This commit is contained in:
Lonami Exo 2020-09-24 10:03:28 +02:00
parent 75fbd28d3e
commit c864ef7e16
2 changed files with 89 additions and 8 deletions

View File

@ -26,7 +26,7 @@ MAX_CHUNK_SIZE = 512 * 1024
class _DirectDownloadIter(RequestIter): class _DirectDownloadIter(RequestIter):
async def _init( async def _init(
self, file, dc_id, offset, stride, chunk_size, request_size, file_size self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data
): ):
self.request = functions.upload.GetFileRequest( self.request = functions.upload.GetFileRequest(
file, offset=offset, limit=request_size) file, offset=offset, limit=request_size)
@ -35,6 +35,7 @@ class _DirectDownloadIter(RequestIter):
self._stride = stride self._stride = stride
self._chunk_size = chunk_size self._chunk_size = chunk_size
self._last_part = None self._last_part = None
self._msg_data = msg_data
self._exported = dc_id and self.client.session.dc_id != dc_id self._exported = dc_id and self.client.session.dc_id != dc_id
if not self._exported: if not self._exported:
@ -80,6 +81,29 @@ class _DirectDownloadIter(RequestIter):
self._exported = True self._exported = True
return await self._request() return await self._request()
except errors.FilerefUpgradeNeededError as e:
# Only implemented for documents, which are the files that may take long enough to download for the file reference to expire
if not self._msg_data \
or not isinstance(self.request.location, types.InputDocumentFileLocation) \
or self.request.location.thumb_size != '':
raise
self.client._log[__name__].info('File ref expired during download; refetching message')
chat, msg_id = self._msg_data
msg = await self.client.get_messages(chat, ids=msg_id)
if not isinstance(msg.media, types.MessageMediaDocument):
raise
document = msg.media.document
# The message's media may have been edited and replaced with a different document
if document.id != self.request.location.id:
raise
self.request.location.file_reference = document.file_reference
return await self._request()
async def close(self): async def close(self):
if not self._sender: if not self._sender:
return return
@ -344,10 +368,16 @@ class DownloadMethods:
await client.download_media(message, progress_callback=callback) await client.download_media(message, progress_callback=callback)
""" """
# Downloading large documents may be slow enough to require a new file reference
# to be obtained mid-download. Store (input chat, message id) so that the message
# can be re-fetched.
msg_data = None
# TODO This won't work for messageService # TODO This won't work for messageService
if isinstance(message, types.Message): if isinstance(message, types.Message):
date = message.date date = message.date
media = message.media media = message.media
msg_data = (message.input_chat, message.id) if message.input_chat else None
else: else:
date = datetime.datetime.now() date = datetime.datetime.now()
media = message media = message
@ -365,7 +395,7 @@ class DownloadMethods:
) )
elif isinstance(media, (types.MessageMediaDocument, types.Document)): elif isinstance(media, (types.MessageMediaDocument, types.Document)):
return await self._download_document( return await self._download_document(
media, file, date, thumb, progress_callback media, file, date, thumb, progress_callback, msg_data
) )
elif isinstance(media, types.MessageMediaContact) and thumb is None: elif isinstance(media, types.MessageMediaContact) and thumb is None:
return self._download_contact( return self._download_contact(
@ -439,6 +469,29 @@ class DownloadMethods:
data = await client.download_file(input_file, bytes) data = await client.download_file(input_file, bytes)
print(data[:16]) print(data[:16])
""" """
return await self._download_file(
input_location,
file,
part_size_kb=part_size_kb,
file_size=file_size,
progress_callback=progress_callback,
dc_id=dc_id,
key=key,
iv=iv,
)
async def _download_file(
self: 'TelegramClient',
input_location: 'hints.FileLike',
file: 'hints.OutFileLike' = None,
*,
part_size_kb: float = None,
file_size: int = None,
progress_callback: 'hints.ProgressCallback' = None,
dc_id: int = None,
key: bytes = None,
iv: bytes = None,
msg_data: tuple = None) -> typing.Optional[bytes]:
if not part_size_kb: if not part_size_kb:
if not file_size: if not file_size:
part_size_kb = 64 # Reasonable default part_size_kb = 64 # Reasonable default
@ -464,8 +517,8 @@ class DownloadMethods:
f = file f = file
try: try:
async for chunk in self.iter_download( async for chunk in self._iter_download(
input_location, request_size=part_size, dc_id=dc_id): input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data):
if iv and key: if iv and key:
chunk = AES.decrypt_ige(chunk, key, iv) chunk = AES.decrypt_ige(chunk, key, iv)
r = f.write(chunk) r = f.write(chunk)
@ -582,6 +635,30 @@ class DownloadMethods:
await stream.close() await stream.close()
assert len(header) == 32 assert len(header) == 32
""" """
return self._iter_download(
file,
offset=offset,
stride=stride,
limit=limit,
chunk_size=chunk_size,
request_size=request_size,
file_size=file_size,
dc_id=dc_id,
)
def _iter_download(
self: 'TelegramClient',
file: 'hints.FileLike',
*,
offset: int = 0,
stride: int = None,
limit: int = None,
chunk_size: int = None,
request_size: int = MAX_CHUNK_SIZE,
file_size: int = None,
dc_id: int = None,
msg_data: tuple = None
):
info = utils._get_file_info(file) info = utils._get_file_info(file)
if info.dc_id is not None: if info.dc_id is not None:
dc_id = info.dc_id dc_id = info.dc_id
@ -628,7 +705,8 @@ class DownloadMethods:
stride=stride, stride=stride,
chunk_size=chunk_size, chunk_size=chunk_size,
request_size=request_size, request_size=request_size,
file_size=file_size file_size=file_size,
msg_data=msg_data,
) )
# endregion # endregion
@ -748,7 +826,7 @@ class DownloadMethods:
return kind, possible_names return kind, possible_names
async def _download_document( async def _download_document(
self, document, file, date, thumb, progress_callback): self, document, file, date, thumb, progress_callback, msg_data):
"""Specialized version of .download_media() for documents.""" """Specialized version of .download_media() for documents."""
if isinstance(document, types.MessageMediaDocument): if isinstance(document, types.MessageMediaDocument):
document = document.document document = document.document
@ -768,7 +846,7 @@ class DownloadMethods:
if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)): if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)):
return self._download_cached_photo_size(size, file) return self._download_cached_photo_size(size, file)
result = await self.download_file( result = await self._download_file(
types.InputDocumentFileLocation( types.InputDocumentFileLocation(
id=document.id, id=document.id,
access_hash=document.access_hash, access_hash=document.access_hash,
@ -777,7 +855,8 @@ class DownloadMethods:
), ),
file, file,
file_size=size.size if size else document.size, file_size=size.size if size else document.size,
progress_callback=progress_callback progress_callback=progress_callback,
msg_data=msg_data,
) )
return result if file is bytes else file return result if file is bytes else file

View File

@ -761,6 +761,8 @@ class Message(ChatGetter, SenderGetter, TLObject, abc.ABC):
with the ``message`` already set. with the ``message`` already set.
""" """
if self._client: if self._client:
# Passing the entire message is important, in case it has to be
# refetched for a fresh file reference.
return await self._client.download_media(self, *args, **kwargs) return await self._client.download_media(self, *args, **kwargs)
async def click(self, i=None, j=None, async def click(self, i=None, j=None,