diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py
index 93453322..ab6d3bbb 100644
--- a/telethon/telegram_bare_client.py
+++ b/telethon/telegram_bare_client.py
@@ -158,9 +158,6 @@ class TelegramBareClient:
         # See https://core.telegram.org/api/invoking#saving-client-info.
         self._first_request = True
 
-        # Uploaded files cache so subsequent calls are instant
-        self._upload_cache = {}
-
         # Constantly read for results and updates from within the main client,
         # if the user has left enabled such option.
         self._spawn_read_thread = spawn_read_thread
@@ -639,6 +636,7 @@ class TelegramBareClient:
                 file = file.read()
             file_size = len(file)
 
+        # File will now either be a string or bytes
         if not part_size_kb:
             part_size_kb = get_appropriated_part_size(file_size)
 
@@ -649,18 +647,40 @@ class TelegramBareClient:
         if part_size % 1024 != 0:
             raise ValueError('The part size must be evenly divisible by 1024')
 
+        # Set a default file name if None was specified
+        file_id = utils.generate_random_long()
+        if not file_name:
+            if isinstance(file, str):
+                file_name = os.path.basename(file)
+            else:
+                file_name = str(file_id)
+
         # Determine whether the file is too big (over 10MB) or not
         # Telegram does make a distinction between smaller or larger files
         is_large = file_size > 10 * 1024 * 1024
+        if not is_large:
+            # Calculate the MD5 hash before anything else.
+            # As this needs to be done always for small files,
+            # might as well do it before anything else and
+            # check the cache.
+            if isinstance(file, str):
+                with open(file, 'rb') as stream:
+                    file = stream.read()
+            hash_md5 = md5(file)
+            tuple_ = self.session.get_file(hash_md5.digest(), file_size)
+            if tuple_:
+                __log__.info('File was already cached, not uploading again')
+                return InputFile(name=file_name,
+                                 md5_checksum=tuple_[0], id=tuple_[2], parts=tuple_[3])
+        else:
+            hash_md5 = None
+
         part_count = (file_size + part_size - 1) // part_size
-
-        file_id = utils.generate_random_long()
-        hash_md5 = md5()
-
         __log__.info('Uploading file of %d bytes in %d chunks of %d',
                      file_size, part_count, part_size)
-        stream = open(file, 'rb') if isinstance(file, str) else BytesIO(file)
-        try:
+
+        with open(file, 'rb') if isinstance(file, str) else BytesIO(file) \
+                as stream:
             for part_index in range(part_count):
                 # Read the file by in chunks of size part_size
                 part = stream.read(part_size)
@@ -675,29 +695,19 @@ class TelegramBareClient:
 
                 result = self(request)
                 if result:
-                    __log__.debug('Uploaded %d/%d', part_index, part_count)
-                    if not is_large:
-                        # No need to update the hash if it's a large file
-                        hash_md5.update(part)
-
+                    __log__.debug('Uploaded %d/%d', part_index + 1, part_count)
                     if progress_callback:
                         progress_callback(stream.tell(), file_size)
                 else:
                     raise RuntimeError(
                         'Failed to upload file part {}.'.format(part_index))
-        finally:
-            stream.close()
-
-        # Set a default file name if None was specified
-        if not file_name:
-            if isinstance(file, str):
-                file_name = os.path.basename(file)
-            else:
-                file_name = str(file_id)
 
         if is_large:
             return InputFileBig(file_id, part_count, file_name)
         else:
+            self.session.cache_file(
+                hash_md5.digest(), file_size, file_id, part_count)
+
             return InputFile(file_id, part_count, file_name,
                              md5_checksum=hash_md5.hexdigest())
 
diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py
index 7d17cad1..7b8a84fa 100644
--- a/telethon/telegram_client.py
+++ b/telethon/telegram_client.py
@@ -759,13 +759,8 @@ class TelegramClient(TelegramBareClient):
                 for ext in ('.png', '.jpg', '.gif', '.jpeg')
             )
 
-        file_hash = hash(file)
-        if file_hash in self._upload_cache:
-            file_handle = self._upload_cache[file_hash]
-        else:
-            self._upload_cache[file_hash] = file_handle = self.upload_file(
-                file, progress_callback=progress_callback
-            )
+        file_handle = self.upload_file(
+            file, progress_callback=progress_callback)
 
         if as_photo and not force_document:
             media = InputMediaUploadedPhoto(file_handle, caption)
@@ -835,14 +830,6 @@ class TelegramClient(TelegramBareClient):
                               reply_to=reply_to,
                               is_voice_note=())  # empty tuple is enough
 
-    def clear_file_cache(self):
-        """Calls to .send_file() will cache the remote location of the
-        uploaded files so that subsequent files can be immediate, so
-        uploading the same file path will result in using the cached
-        version. To avoid this a call to this method should be made.
-        """
-        self._upload_cache.clear()
-
     # endregion
 
     # region Downloading media requests
diff --git a/telethon/tl/session.py b/telethon/tl/session.py
index 3fa13d23..59794f16 100644
--- a/telethon/tl/session.py
+++ b/telethon/tl/session.py
@@ -2,7 +2,6 @@ import json
 import os
 import platform
 import sqlite3
-import struct
 import time
 from base64 import b64decode
 from os.path import isfile as file_exists
@@ -16,7 +15,7 @@ from ..tl.types import (
 )
 
 EXTENSION = '.session'
-CURRENT_VERSION = 1  # database version
+CURRENT_VERSION = 2  # database version
 
 
 class Session:
@@ -93,6 +92,8 @@ class Session:
             version = c.fetchone()[0]
             if version != CURRENT_VERSION:
                 self._upgrade_database(old=version)
+                c.execute("delete from version")
+                c.execute("insert into version values (?)", (CURRENT_VERSION,))
                 self.save()
 
             # These values will be saved
@@ -125,6 +126,17 @@ class Session:
                     name text
                 ) without rowid"""
             )
+            # Save file_size along with md5_digest
+            # to make collisions even more unlikely.
+            c.execute(
+                """create table sent_files (
+                    md5_digest blob,
+                    file_size integer,
+                    file_id integer,
+                    part_count integer,
+                    primary key(md5_digest, file_size)
+                ) without rowid"""
+            )
             # Migrating from JSON -> new table and may have entities
             if entities:
                 c.executemany(
@@ -158,7 +170,17 @@ class Session:
         return []  # No entities
 
     def _upgrade_database(self, old):
-        pass
+        if old == 1:
+            self._conn.execute(
+                """create table sent_files (
+                    md5_digest blob,
+                    file_size integer,
+                    file_id integer,
+                    part_count integer,
+                    primary key(md5_digest, file_size)
+                ) without rowid"""
+            )
+            old = 2
 
     # Data from sessions should be kept as properties
     # not to fetch the database every time we need it
@@ -370,3 +392,19 @@ class Session:
                 return InputPeerChannel(i, h)
         else:
             raise ValueError('Could not find input entity with key ', key)
+
+    # File processing
+
+    def get_file(self, md5_digest, file_size):
+        return self._conn.execute(
+            'select * from sent_files '
+            'where md5_digest = ? and file_size = ?', (md5_digest, file_size)
+        ).fetchone()
+
+    def cache_file(self, md5_digest, file_size, file_id, part_count):
+        with self._db_lock:
+            self._conn.execute(
+                'insert into sent_files values (?,?,?,?)',
+                (md5_digest, file_size, file_id, part_count)
+            )
+            self.save()
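As a usage illustration, below is a minimal standalone sketch of what the new
Session.get_file/Session.cache_file pair does, using plain sqlite3 and the
sent_files DDL from the diff. The Session class's locking and save() plumbing
are omitted, and the file_id value is invented for the example.

    import sqlite3
    from hashlib import md5

    conn = sqlite3.connect(':memory:')
    conn.execute(
        """create table sent_files (
            md5_digest blob,
            file_size integer,
            file_id integer,
            part_count integer,
            primary key(md5_digest, file_size)
        ) without rowid"""
    )

    data = b'some file contents'
    digest = md5(data).digest()

    # cache_file(): store the server-side id after a successful upload
    # (1234567890 stands in for a real Telegram file_id)
    conn.execute('insert into sent_files values (?,?,?,?)',
                 (digest, len(data), 1234567890, 1))

    # get_file(): checked before uploading; a hit must match both the
    # MD5 digest and the file size
    row = conn.execute(
        'select * from sent_files where md5_digest = ? and file_size = ?',
        (digest, len(data))
    ).fetchone()
    assert row == (digest, len(data), 1234567890, 1)

Making (md5_digest, file_size) the composite primary key means a false cache
hit requires colliding on both the hash and the exact byte count, which is the
rationale stated in the schema comment.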
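And a quick worked example of the part arithmetic that upload_file relies on.
The numbers are arbitrary; get_appropriated_part_size (Telethon's own helper,
not shown in the diff) would normally pick the part size.

    file_size = 3 * 1024 * 1024 + 1    # 3 MB plus one byte (arbitrary)
    part_size = 512 * 1024             # must be a multiple of 1024

    assert part_size % 1024 == 0       # otherwise upload_file raises ValueError

    # Ceiling division: a final, partially filled chunk still needs a part
    part_count = (file_size + part_size - 1) // part_size
    assert part_count == 7             # six full 512 KB parts plus the last byte

    # Only files over 10 MB skip the MD5/session-cache path
    is_large = file_size > 10 * 1024 * 1024
    assert not is_large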