diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index d2e84ee6..af86c0f1 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -633,13 +633,6 @@ class TelegramBareClient: with open(file, 'rb') as stream: file = stream.read() hash_md5 = md5(file) - tuple_ = self.session.get_file(hash_md5.digest(), file_size) - if tuple_ and allow_cache: - __log__.info('File was already cached, not uploading again') - return InputFile(name=file_name, - md5_checksum=tuple_[0], id=tuple_[2], parts=tuple_[3]) - elif tuple_ and not allow_cache: - self.session.clear_file(hash_md5.digest(), file_size) else: hash_md5 = None @@ -673,9 +666,6 @@ class TelegramBareClient: if is_large: return InputFileBig(file_id, part_count, file_name) else: - self.session.cache_file( - hash_md5.digest(), file_size, file_id, part_count) - return InputFile(file_id, part_count, file_name, md5_checksum=hash_md5.hexdigest()) diff --git a/telethon/tl/session.py b/telethon/tl/session.py index 1dbf99c5..5d89a5f7 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -5,6 +5,7 @@ import sqlite3 import struct import time from base64 import b64decode +from enum import Enum from os.path import isfile as file_exists from threading import Lock @@ -12,11 +13,26 @@ from .. import utils from ..tl import TLObject from ..tl.types import ( PeerUser, PeerChat, PeerChannel, - InputPeerUser, InputPeerChat, InputPeerChannel + InputPeerUser, InputPeerChat, InputPeerChannel, + InputPhoto, InputDocument ) EXTENSION = '.session' -CURRENT_VERSION = 2 # database version +CURRENT_VERSION = 3 # database version + + +class _SentFileType(Enum): + DOCUMENT = 0 + PHOTO = 1 + + @staticmethod + def from_type(cls): + if cls == InputDocument: + return _SentFileType.DOCUMENT + elif cls == InputPhoto: + return _SentFileType.PHOTO + else: + raise ValueError('The cls must be either InputDocument/InputPhoto') class Session: @@ -130,9 +146,10 @@ class Session: """sent_files ( md5_digest blob, file_size integer, - file_id integer, - part_count integer, - primary key(md5_digest, file_size) + type integer, + id integer, + hash integer, + primary key(md5_digest, file_size, type) )""" ) c.execute("insert into version values (?)", (CURRENT_VERSION,)) @@ -171,18 +188,22 @@ class Session: def _upgrade_database(self, old): c = self._conn.cursor() - if old == 1: - self._create_table(c,"""sent_files ( - md5_digest blob, - file_size integer, - file_id integer, - part_count integer, - primary key(md5_digest, file_size) - )""") - old = 2 + # old == 1 doesn't have the old sent_files so no need to drop + if old == 2: + # Old cache from old sent_files lasts then a day anyway, drop + c.execute('drop table sent_files') + self._create_table(c, """sent_files ( + md5_digest blob, + file_size integer, + type integer, + id integer, + hash integer, + primary key(md5_digest, file_size, type) + )""") c.close() - def _create_table(self, c, *definitions): + @staticmethod + def _create_table(c, *definitions): """ Creates a table given its definition 'name (columns). If the sqlite version is >= 3.8.2, it will use "without rowid". @@ -420,24 +441,25 @@ class Session: # File processing - def get_file(self, md5_digest, file_size): - return self._conn.execute( - 'select * from sent_files ' - 'where md5_digest = ? and file_size = ?', (md5_digest, file_size) + def get_file(self, md5_digest, file_size, cls): + tuple_ = self._conn.execute( + 'select id, hash from sent_files ' + 'where md5_digest = ? and file_size = ? and type = ?', + (md5_digest, file_size, _SentFileType.from_type(cls)) ).fetchone() + if tuple_: + # Both allowed classes have (id, access_hash) as parameters + return cls(tuple_[0], tuple_[1]) + + def cache_file(self, md5_digest, file_size, instance): + if not isinstance(instance, (InputDocument, InputPhoto)): + raise TypeError('Cannot cache %s instance' % type(instance)) - def cache_file(self, md5_digest, file_size, file_id, part_count): with self._db_lock: self._conn.execute( - 'insert into sent_files values (?,?,?,?)', - (md5_digest, file_size, file_id, part_count) - ) - self.save() - - def clear_file(self, md5_digest, file_size): - with self._db_lock: - self._conn.execute( - 'delete from sent_files where ' - 'md5_digest = ? and file_size = ?', (md5_digest, file_size) - ) + 'insert into sent_files values (?,?,?,?,?)', ( + md5_digest, file_size, + _SentFileType.from_type(type(instance)), + instance.id, instance.access_hash + )) self.save()