From de17a19168d2a1c69c7bc431f46b2c47b61787d3 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 15 Jul 2020 14:35:42 +0200 Subject: [PATCH] Improve upload_file by properly supporting streaming files --- telethon/client/uploads.py | 121 ++++++++++++++++++++++++------------- telethon/helpers.py | 8 +++ 2 files changed, 87 insertions(+), 42 deletions(-) diff --git a/telethon/client/uploads.py b/telethon/client/uploads.py index c8564ad1..47fb4b44 100644 --- a/telethon/client/uploads.py +++ b/telethon/client/uploads.py @@ -5,7 +5,6 @@ import os import pathlib import re import typing -import inspect from io import BytesIO from ..crypto import AES @@ -95,6 +94,7 @@ class UploadMethods: *, caption: typing.Union[str, typing.Sequence[str]] = None, force_document: bool = False, + file_size: int = None, clear_draft: bool = False, progress_callback: 'hints.ProgressCallback' = None, reply_to: 'hints.MessageIDLike' = None, @@ -175,6 +175,13 @@ class UploadMethods: the extension of an image file or a video file, it will be sent as such. Otherwise always as a document. + file_size (`int`, optional): + The size of the file to be uploaded if it needs to be uploaded, + which will be determined automatically if not specified. + + If the file size can't be determined beforehand, the entire + file will be read in-memory to find out how large it is. + clear_draft (`bool`, optional): Whether the existing draft should be cleared or not. @@ -358,6 +365,7 @@ class UploadMethods: file_handle, media, image = await self._file_to_media( file, force_document=force_document, + file_size=file_size, progress_callback=progress_callback, attributes=attributes, allow_cache=allow_cache, thumb=thumb, voice_note=voice_note, video_note=video_note, @@ -449,6 +457,7 @@ class UploadMethods: file: 'hints.FileLike', *, part_size_kb: float = None, + file_size: int = None, file_name: str = None, use_cache: type = None, key: bytes = None, @@ -480,6 +489,13 @@ class UploadMethods: Chunk size when uploading files. 
The larger, the less requests will be made (up to 512KB maximum). + file_size (`int`, optional): + The size of the file to be uploaded, which will be determined + automatically if not specified. + + If the file size can't be determined beforehand, the entire + file will be read in-memory to find out how large it is. + file_name (`str`, optional): The file name which will be used on the resulting InputFile. If not specified, the name will be taken from the ``file`` @@ -527,34 +543,42 @@ class UploadMethods: if not file_name and getattr(file, 'name', None): file_name = file.name - if isinstance(file, str): + if file_size is not None: + pass # do nothing as it's already known + elif isinstance(file, str): file_size = os.path.getsize(file) + stream = open(file, 'rb') + close_stream = True elif isinstance(file, bytes): file_size = len(file) + stream = io.BytesIO(file) + close_stream = True else: - # `aiofiles` shouldn't base `IOBase` because they change the - # methods' definition. `seekable` would be `async` but since - # we won't get to check that, there's no need to maybe-await. 
- if isinstance(file, io.IOBase) and file.seekable(): - pos = file.tell() + if not callable(getattr(file, 'read', None)): + raise TypeError('file description should have a `read` method') + + if callable(getattr(file, 'seekable', None)): + seekable = await helpers._maybe_await(file.seekable()) else: - pos = None + seekable = False - # TODO Don't load the entire file in memory always - data = file.read() - if inspect.isawaitable(data): - data = await data + if seekable: + pos = await helpers._maybe_await(file.tell()) + await helpers._maybe_await(file.seek(0, os.SEEK_END)) + file_size = await helpers._maybe_await(file.tell()) + await helpers._maybe_await(file.seek(pos, os.SEEK_SET)) - if pos is not None: - file.seek(pos) + stream = file + close_stream = False + else: + self._log[__name__].warning( + 'Could not determine file size beforehand so the entire ' + 'file will be read in-memory') - if not isinstance(data, bytes): - raise TypeError( - 'file descriptor returned {}, not bytes (you must ' - 'open the file in bytes mode)'.format(type(data))) - - file = data - file_size = len(file) + data = await helpers._maybe_await(file.read()) + stream = io.BytesIO(data) + close_stream = True + file_size = len(data) # File will now either be a string or bytes if not part_size_kb: @@ -584,35 +608,46 @@ class UploadMethods: # Determine whether the file is too big (over 10MB) or not # Telegram does make a distinction between smaller or larger files - is_large = file_size > 10 * 1024 * 1024 + is_big = file_size > 10 * 1024 * 1024 hash_md5 = hashlib.md5() - if not is_large: - # Calculate the MD5 hash before anything else. - # As this needs to be done always for small files, - # might as well do it before anything else and - # check the cache. 
- if isinstance(file, str): - with open(file, 'rb') as stream: - file = stream.read() - hash_md5.update(file) part_count = (file_size + part_size - 1) // part_size self._log[__name__].info('Uploading file of %d bytes in %d chunks of %d', file_size, part_count, part_size) - with open(file, 'rb') if isinstance(file, str) else BytesIO(file)\ - as stream: + pos = 0 + try: for part_index in range(part_count): # Read the file by in chunks of size part_size - part = stream.read(part_size) + part = await helpers._maybe_await(stream.read(part_size)) - # encryption part if needed + if not isinstance(part, bytes): + raise TypeError( + 'file descriptor returned {}, not bytes (you must ' + 'open the file in bytes mode)'.format(type(part))) + + # `file_size` could be wrong in which case `part` may not be + # `part_size` before reaching the end. + if len(part) != part_size and part_index < part_count - 1: + raise ValueError( + 'read less than {} before reaching the end; either ' + '`file_size` or `read` are wrong'.format(part_size)) + + pos += len(part) + + if not is_big: + # Bit odd that MD5 is only needed for small files and not + # big ones with more chance for corruption, but that's + # what Telegram wants. 
+ hash_md5.update(part) + + # Encryption part if needed if key and iv: part = AES.encrypt_ige(part, key, iv) # The SavePartRequest is different depending on whether # the file is too large or not (over or less than 10MB) - if is_large: + if is_big: request = functions.upload.SaveBigFilePartRequest( file_id, part_index, part_count, part) else: @@ -624,14 +659,15 @@ class UploadMethods: self._log[__name__].debug('Uploaded %d/%d', part_index + 1, part_count) if progress_callback: - r = progress_callback(stream.tell(), file_size) - if inspect.isawaitable(r): - await r + await helpers._maybe_await(progress_callback(pos, file_size)) else: raise RuntimeError( 'Failed to upload file part {}.'.format(part_index)) + finally: + if close_stream: + await helpers._maybe_await(stream.close()) - if is_large: + if is_big: return types.InputFileBig(file_id, part_count, file_name) else: return custom.InputSizedFile( @@ -641,7 +677,7 @@ class UploadMethods: # endregion async def _file_to_media( - self, file, force_document=False, + self, file, force_document=False, file_size=None, progress_callback=None, attributes=None, thumb=None, allow_cache=True, voice_note=False, video_note=False, supports_streaming=False, mime_type=None, as_image=None): @@ -686,6 +722,7 @@ class UploadMethods: elif not isinstance(file, str) or os.path.isfile(file): file_handle = await self.upload_file( _resize_photo_if_needed(file, as_image), + file_size=file_size, progress_callback=progress_callback ) elif re.match('https?://', file): @@ -725,7 +762,7 @@ class UploadMethods: else: if isinstance(thumb, pathlib.Path): thumb = str(thumb.absolute()) - thumb = await self.upload_file(thumb) + thumb = await self.upload_file(thumb, file_size=file_size) media = types.InputMediaUploadedDocument( file=file_handle, diff --git a/telethon/helpers.py b/telethon/helpers.py index 664d504e..55eb1b79 100644 --- a/telethon/helpers.py +++ b/telethon/helpers.py @@ -3,6 +3,7 @@ import asyncio import enum import os import struct 
+import inspect from hashlib import sha1 @@ -107,6 +108,13 @@ def retry_range(retries): yield 1 + attempt +async def _maybe_await(value): + if inspect.isawaitable(value): + return await value + else: + return value + + async def _cancel(log, **tasks): """ Helper to cancel one or more tasks gracefully, logging exceptions.