Refetch msg if fileref expires while downloading docs

Closes #1301.
This commit is contained in:
Lonami Exo 2020-09-24 10:03:28 +02:00
parent 75fbd28d3e
commit c864ef7e16
2 changed files with 89 additions and 8 deletions

View File

@ -26,7 +26,7 @@ MAX_CHUNK_SIZE = 512 * 1024
class _DirectDownloadIter(RequestIter):
async def _init(
self, file, dc_id, offset, stride, chunk_size, request_size, file_size
self, file, dc_id, offset, stride, chunk_size, request_size, file_size, msg_data
):
self.request = functions.upload.GetFileRequest(
file, offset=offset, limit=request_size)
@ -35,6 +35,7 @@ class _DirectDownloadIter(RequestIter):
self._stride = stride
self._chunk_size = chunk_size
self._last_part = None
self._msg_data = msg_data
self._exported = dc_id and self.client.session.dc_id != dc_id
if not self._exported:
@ -80,6 +81,29 @@ class _DirectDownloadIter(RequestIter):
self._exported = True
return await self._request()
except errors.FilerefUpgradeNeededError as e:
# Only implemented for documents which are the ones that may take that long to download
if not self._msg_data \
or not isinstance(self.request.location, types.InputDocumentFileLocation) \
or self.request.location.thumb_size != '':
raise
self.client._log[__name__].info('File ref expired during download; refetching message')
chat, msg_id = self._msg_data
msg = await self.client.get_messages(chat, ids=msg_id)
if not isinstance(msg.media, types.MessageMediaDocument):
raise
document = msg.media.document
# Message media may have been edited for something else
if document.id != self.request.location.id:
raise
self.request.location.file_reference = document.file_reference
return await self._request()
async def close(self):
if not self._sender:
return
@ -344,10 +368,16 @@ class DownloadMethods:
await client.download_media(message, progress_callback=callback)
"""
# Downloading large documents may be slow enough to require a new file reference
# to be obtained mid-download. Store (input chat, message id) so that the message
# can be re-fetched.
msg_data = None
# TODO This won't work for messageService
if isinstance(message, types.Message):
date = message.date
media = message.media
msg_data = (message.input_chat, message.id) if message.input_chat else None
else:
date = datetime.datetime.now()
media = message
@ -365,7 +395,7 @@ class DownloadMethods:
)
elif isinstance(media, (types.MessageMediaDocument, types.Document)):
return await self._download_document(
media, file, date, thumb, progress_callback
media, file, date, thumb, progress_callback, msg_data
)
elif isinstance(media, types.MessageMediaContact) and thumb is None:
return self._download_contact(
@ -439,6 +469,29 @@ class DownloadMethods:
data = await client.download_file(input_file, bytes)
print(data[:16])
"""
return await self._download_file(
input_location,
file,
part_size_kb=part_size_kb,
file_size=file_size,
progress_callback=progress_callback,
dc_id=dc_id,
key=key,
iv=iv,
)
async def _download_file(
self: 'TelegramClient',
input_location: 'hints.FileLike',
file: 'hints.OutFileLike' = None,
*,
part_size_kb: float = None,
file_size: int = None,
progress_callback: 'hints.ProgressCallback' = None,
dc_id: int = None,
key: bytes = None,
iv: bytes = None,
msg_data: tuple = None) -> typing.Optional[bytes]:
if not part_size_kb:
if not file_size:
part_size_kb = 64 # Reasonable default
@ -464,8 +517,8 @@ class DownloadMethods:
f = file
try:
async for chunk in self.iter_download(
input_location, request_size=part_size, dc_id=dc_id):
async for chunk in self._iter_download(
input_location, request_size=part_size, dc_id=dc_id, msg_data=msg_data):
if iv and key:
chunk = AES.decrypt_ige(chunk, key, iv)
r = f.write(chunk)
@ -582,6 +635,30 @@ class DownloadMethods:
await stream.close()
assert len(header) == 32
"""
return self._iter_download(
file,
offset=offset,
stride=stride,
limit=limit,
chunk_size=chunk_size,
request_size=request_size,
file_size=file_size,
dc_id=dc_id,
)
def _iter_download(
self: 'TelegramClient',
file: 'hints.FileLike',
*,
offset: int = 0,
stride: int = None,
limit: int = None,
chunk_size: int = None,
request_size: int = MAX_CHUNK_SIZE,
file_size: int = None,
dc_id: int = None,
msg_data: tuple = None
):
info = utils._get_file_info(file)
if info.dc_id is not None:
dc_id = info.dc_id
@ -628,7 +705,8 @@ class DownloadMethods:
stride=stride,
chunk_size=chunk_size,
request_size=request_size,
file_size=file_size
file_size=file_size,
msg_data=msg_data,
)
# endregion
@ -748,7 +826,7 @@ class DownloadMethods:
return kind, possible_names
async def _download_document(
self, document, file, date, thumb, progress_callback):
self, document, file, date, thumb, progress_callback, msg_data):
"""Specialized version of .download_media() for documents."""
if isinstance(document, types.MessageMediaDocument):
document = document.document
@ -768,7 +846,7 @@ class DownloadMethods:
if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)):
return self._download_cached_photo_size(size, file)
result = await self.download_file(
result = await self._download_file(
types.InputDocumentFileLocation(
id=document.id,
access_hash=document.access_hash,
@ -777,7 +855,8 @@ class DownloadMethods:
),
file,
file_size=size.size if size else document.size,
progress_callback=progress_callback
progress_callback=progress_callback,
msg_data=msg_data,
)
return result if file is bytes else file

View File

@ -761,6 +761,8 @@ class Message(ChatGetter, SenderGetter, TLObject, abc.ABC):
with the ``message`` already set.
"""
if self._client:
# Passing the entire message is important, in case it has to be
# refetched for a fresh file reference.
return await self._client.download_media(self, *args, **kwargs)
async def click(self, i=None, j=None,