Improve upload_file by properly supporting streaming files

This commit is contained in:
Lonami Exo 2020-07-15 14:35:42 +02:00
parent bfb8de2736
commit de17a19168
2 changed files with 87 additions and 42 deletions

View File

@ -5,7 +5,6 @@ import os
import pathlib import pathlib
import re import re
import typing import typing
import inspect
from io import BytesIO from io import BytesIO
from ..crypto import AES from ..crypto import AES
@ -95,6 +94,7 @@ class UploadMethods:
*, *,
caption: typing.Union[str, typing.Sequence[str]] = None, caption: typing.Union[str, typing.Sequence[str]] = None,
force_document: bool = False, force_document: bool = False,
file_size: int = None,
clear_draft: bool = False, clear_draft: bool = False,
progress_callback: 'hints.ProgressCallback' = None, progress_callback: 'hints.ProgressCallback' = None,
reply_to: 'hints.MessageIDLike' = None, reply_to: 'hints.MessageIDLike' = None,
@ -175,6 +175,13 @@ class UploadMethods:
the extension of an image file or a video file, it will be the extension of an image file or a video file, it will be
sent as such. Otherwise always as a document. sent as such. Otherwise always as a document.
file_size (`int`, optional):
The size of the file to be uploaded if it needs to be uploaded,
which will be determined automatically if not specified.
If the file size can't be determined beforehand, the entire
file will be read in-memory to find out how large it is.
clear_draft (`bool`, optional): clear_draft (`bool`, optional):
Whether the existing draft should be cleared or not. Whether the existing draft should be cleared or not.
@ -358,6 +365,7 @@ class UploadMethods:
file_handle, media, image = await self._file_to_media( file_handle, media, image = await self._file_to_media(
file, force_document=force_document, file, force_document=force_document,
file_size=file_size,
progress_callback=progress_callback, progress_callback=progress_callback,
attributes=attributes, allow_cache=allow_cache, thumb=thumb, attributes=attributes, allow_cache=allow_cache, thumb=thumb,
voice_note=voice_note, video_note=video_note, voice_note=voice_note, video_note=video_note,
@ -449,6 +457,7 @@ class UploadMethods:
file: 'hints.FileLike', file: 'hints.FileLike',
*, *,
part_size_kb: float = None, part_size_kb: float = None,
file_size: int = None,
file_name: str = None, file_name: str = None,
use_cache: type = None, use_cache: type = None,
key: bytes = None, key: bytes = None,
@ -480,6 +489,13 @@ class UploadMethods:
Chunk size when uploading files. The larger, the less Chunk size when uploading files. The larger, the less
requests will be made (up to 512KB maximum). requests will be made (up to 512KB maximum).
file_size (`int`, optional):
The size of the file to be uploaded, which will be determined
automatically if not specified.
If the file size can't be determined beforehand, the entire
file will be read in-memory to find out how large it is.
file_name (`str`, optional): file_name (`str`, optional):
The file name which will be used on the resulting InputFile. The file name which will be used on the resulting InputFile.
If not specified, the name will be taken from the ``file`` If not specified, the name will be taken from the ``file``
@ -527,34 +543,42 @@ class UploadMethods:
if not file_name and getattr(file, 'name', None): if not file_name and getattr(file, 'name', None):
file_name = file.name file_name = file.name
if isinstance(file, str): if file_size is not None:
pass # do nothing as it's already kwown
elif isinstance(file, str):
file_size = os.path.getsize(file) file_size = os.path.getsize(file)
stream = open(file, 'rb')
close_stream = True
elif isinstance(file, bytes): elif isinstance(file, bytes):
file_size = len(file) file_size = len(file)
stream = io.BytesIO(file)
close_stream = True
else: else:
# `aiofiles` shouldn't base `IOBase` because they change the if not callable(getattr(file, 'read', None)):
# methods' definition. `seekable` would be `async` but since raise TypeError('file description should have a `read` method')
# we won't get to check that, there's no need to maybe-await.
if isinstance(file, io.IOBase) and file.seekable(): if callable(getattr(file, 'seekable', None)):
pos = file.tell() seekable = await helpers._maybe_await(file.seekable())
else: else:
pos = None seekable = False
# TODO Don't load the entire file in memory always if seekable:
data = file.read() pos = await helpers._maybe_await(file.tell())
if inspect.isawaitable(data): await helpers._maybe_await(file.seek(0, os.SEEK_END))
data = await data file_size = await helpers._maybe_await(file.tell())
await helpers._maybe_await(file.seek(pos, os.SEEK_SET))
if pos is not None: stream = file
file.seek(pos) close_stream = False
else:
self._log[__name__].warning(
'Could not determine file size beforehand so the entire '
'file will be read in-memory')
if not isinstance(data, bytes): data = await helpers._maybe_await(file.read())
raise TypeError( stream = io.BytesIO(data)
'file descriptor returned {}, not bytes (you must ' close_stream = True
'open the file in bytes mode)'.format(type(data))) file_size = len(data)
file = data
file_size = len(file)
# File will now either be a string or bytes # File will now either be a string or bytes
if not part_size_kb: if not part_size_kb:
@ -584,35 +608,46 @@ class UploadMethods:
# Determine whether the file is too big (over 10MB) or not # Determine whether the file is too big (over 10MB) or not
# Telegram does make a distinction between smaller or larger files # Telegram does make a distinction between smaller or larger files
is_large = file_size > 10 * 1024 * 1024 is_big = file_size > 10 * 1024 * 1024
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
if not is_large:
# Calculate the MD5 hash before anything else.
# As this needs to be done always for small files,
# might as well do it before anything else and
# check the cache.
if isinstance(file, str):
with open(file, 'rb') as stream:
file = stream.read()
hash_md5.update(file)
part_count = (file_size + part_size - 1) // part_size part_count = (file_size + part_size - 1) // part_size
self._log[__name__].info('Uploading file of %d bytes in %d chunks of %d', self._log[__name__].info('Uploading file of %d bytes in %d chunks of %d',
file_size, part_count, part_size) file_size, part_count, part_size)
with open(file, 'rb') if isinstance(file, str) else BytesIO(file)\ pos = 0
as stream: try:
for part_index in range(part_count): for part_index in range(part_count):
# Read the file by in chunks of size part_size # Read the file by in chunks of size part_size
part = stream.read(part_size) part = await helpers._maybe_await(stream.read(part_size))
# encryption part if needed if not isinstance(part, bytes):
raise TypeError(
'file descriptor returned {}, not bytes (you must '
'open the file in bytes mode)'.format(type(part)))
# `file_size` could be wrong in which case `part` may not be
# `part_size` before reaching the end.
if len(part) != part_size and part_index < part_count - 1:
raise ValueError(
'read less than {} before reaching the end; either '
'`file_size` or `read` are wrong'.format(part_size))
pos += len(part)
if not is_big:
# Bit odd that MD5 is only needed for small files and not
# big ones with more chance for corruption, but that's
# what Telegram wants.
hash_md5.update(part)
# Encryption part if needed
if key and iv: if key and iv:
part = AES.encrypt_ige(part, key, iv) part = AES.encrypt_ige(part, key, iv)
# The SavePartRequest is different depending on whether # The SavePartRequest is different depending on whether
# the file is too large or not (over or less than 10MB) # the file is too large or not (over or less than 10MB)
if is_large: if is_big:
request = functions.upload.SaveBigFilePartRequest( request = functions.upload.SaveBigFilePartRequest(
file_id, part_index, part_count, part) file_id, part_index, part_count, part)
else: else:
@ -624,14 +659,15 @@ class UploadMethods:
self._log[__name__].debug('Uploaded %d/%d', self._log[__name__].debug('Uploaded %d/%d',
part_index + 1, part_count) part_index + 1, part_count)
if progress_callback: if progress_callback:
r = progress_callback(stream.tell(), file_size) await helpers._maybe_await(progress_callback(pos, file_size))
if inspect.isawaitable(r):
await r
else: else:
raise RuntimeError( raise RuntimeError(
'Failed to upload file part {}.'.format(part_index)) 'Failed to upload file part {}.'.format(part_index))
finally:
if close_stream:
await helpers._maybe_await(stream.close())
if is_large: if is_big:
return types.InputFileBig(file_id, part_count, file_name) return types.InputFileBig(file_id, part_count, file_name)
else: else:
return custom.InputSizedFile( return custom.InputSizedFile(
@ -641,7 +677,7 @@ class UploadMethods:
# endregion # endregion
async def _file_to_media( async def _file_to_media(
self, file, force_document=False, self, file, force_document=False, file_size=None,
progress_callback=None, attributes=None, thumb=None, progress_callback=None, attributes=None, thumb=None,
allow_cache=True, voice_note=False, video_note=False, allow_cache=True, voice_note=False, video_note=False,
supports_streaming=False, mime_type=None, as_image=None): supports_streaming=False, mime_type=None, as_image=None):
@ -686,6 +722,7 @@ class UploadMethods:
elif not isinstance(file, str) or os.path.isfile(file): elif not isinstance(file, str) or os.path.isfile(file):
file_handle = await self.upload_file( file_handle = await self.upload_file(
_resize_photo_if_needed(file, as_image), _resize_photo_if_needed(file, as_image),
file_size=file_size,
progress_callback=progress_callback progress_callback=progress_callback
) )
elif re.match('https?://', file): elif re.match('https?://', file):
@ -725,7 +762,7 @@ class UploadMethods:
else: else:
if isinstance(thumb, pathlib.Path): if isinstance(thumb, pathlib.Path):
thumb = str(thumb.absolute()) thumb = str(thumb.absolute())
thumb = await self.upload_file(thumb) thumb = await self.upload_file(thumb, file_size=file_size)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
file=file_handle, file=file_handle,

View File

@ -3,6 +3,7 @@ import asyncio
import enum import enum
import os import os
import struct import struct
import inspect
from hashlib import sha1 from hashlib import sha1
@ -107,6 +108,13 @@ def retry_range(retries):
yield 1 + attempt yield 1 + attempt
async def _maybe_await(value):
if inspect.isawaitable(value):
return await value
else:
return value
async def _cancel(log, **tasks): async def _cancel(log, **tasks):
""" """
Helper to cancel one or more tasks gracefully, logging exceptions. Helper to cancel one or more tasks gracefully, logging exceptions.