Merge pull request #549 from LonamiWebs/file-cache

This commit is contained in:
Lonami 2018-01-18 20:52:33 +01:00 committed by GitHub
commit 7c55d42287
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 382 additions and 305 deletions

View File

@ -3,19 +3,16 @@ import os
import threading import threading
import warnings import warnings
from datetime import timedelta, datetime from datetime import timedelta, datetime
from hashlib import md5
from io import BytesIO
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Lock from threading import Lock
from time import sleep from time import sleep
from . import helpers as utils, version from . import version
from .crypto import rsa, CdnDecrypter from .crypto import rsa
from .errors import ( from .errors import (
RPCError, BrokenAuthKeyError, ServerError, RPCError, BrokenAuthKeyError, ServerError, FloodWaitError,
FloodWaitError, FloodTestPhoneWaitError, FileMigrateError, FloodTestPhoneWaitError, TypeNotFoundError, UnauthorizedError,
TypeNotFoundError, UnauthorizedError, PhoneMigrateError, PhoneMigrateError, NetworkMigrateError, UserMigrateError
NetworkMigrateError, UserMigrateError
) )
from .network import authenticator, MtProtoSender, Connection, ConnectionMode from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .tl import TLObject, Session from .tl import TLObject, Session
@ -30,15 +27,8 @@ from .tl.functions.help import (
GetCdnConfigRequest, GetConfigRequest GetCdnConfigRequest, GetConfigRequest
) )
from .tl.functions.updates import GetStateRequest from .tl.functions.updates import GetStateRequest
from .tl.functions.upload import (
GetFileRequest, SaveBigFilePartRequest, SaveFilePartRequest
)
from .tl.types import InputFile, InputFileBig
from .tl.types.auth import ExportedAuthorization from .tl.types.auth import ExportedAuthorization
from .tl.types.upload import FileCdnRedirect
from .update_state import UpdateState from .update_state import UpdateState
from .utils import get_appropriated_part_size
DEFAULT_DC_ID = 4 DEFAULT_DC_ID = 4
DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV4_IP = '149.154.167.51'
@ -565,213 +555,6 @@ class TelegramBareClient:
# endregion # endregion
# region Uploading media
def upload_file(self,
file,
part_size_kb=None,
file_name=None,
allow_cache=True,
progress_callback=None):
"""Uploads the specified file and returns a handle (an instance
of InputFile or InputFileBig, as required) which can be later used.
Uploading a file will simply return a "handle" to the file stored
remotely in the Telegram servers, which can be later used on. This
will NOT upload the file to your own chat.
'file' may be either a file path, a byte array, or a stream.
Note that if the file is a stream it will need to be read
entirely into memory to tell its size first.
If 'progress_callback' is not None, it should be a function that
takes two parameters, (bytes_uploaded, total_bytes).
Default values for the optional parameters if left as None are:
part_size_kb = get_appropriated_part_size(file_size)
file_name = os.path.basename(file_path)
"""
if isinstance(file, (InputFile, InputFileBig)):
return file # Already uploaded
if isinstance(file, str):
file_size = os.path.getsize(file)
elif isinstance(file, bytes):
file_size = len(file)
else:
file = file.read()
file_size = len(file)
# File will now either be a string or bytes
if not part_size_kb:
part_size_kb = get_appropriated_part_size(file_size)
if part_size_kb > 512:
raise ValueError('The part size must be less or equal to 512KB')
part_size = int(part_size_kb * 1024)
if part_size % 1024 != 0:
raise ValueError('The part size must be evenly divisible by 1024')
# Set a default file name if None was specified
file_id = utils.generate_random_long()
if not file_name:
if isinstance(file, str):
file_name = os.path.basename(file)
else:
file_name = str(file_id)
# Determine whether the file is too big (over 10MB) or not
# Telegram does make a distinction between smaller or larger files
is_large = file_size > 10 * 1024 * 1024
if not is_large:
# Calculate the MD5 hash before anything else.
# As this needs to be done always for small files,
# might as well do it before anything else and
# check the cache.
if isinstance(file, str):
with open(file, 'rb') as stream:
file = stream.read()
hash_md5 = md5(file)
tuple_ = self.session.get_file(hash_md5.digest(), file_size)
if tuple_ and allow_cache:
__log__.info('File was already cached, not uploading again')
return InputFile(name=file_name,
md5_checksum=tuple_[0], id=tuple_[2], parts=tuple_[3])
elif tuple_ and not allow_cache:
self.session.clear_file(hash_md5.digest(), file_size)
else:
hash_md5 = None
part_count = (file_size + part_size - 1) // part_size
__log__.info('Uploading file of %d bytes in %d chunks of %d',
file_size, part_count, part_size)
with open(file, 'rb') if isinstance(file, str) else BytesIO(file) \
as stream:
for part_index in range(part_count):
# Read the file by in chunks of size part_size
part = stream.read(part_size)
# The SavePartRequest is different depending on whether
# the file is too large or not (over or less than 10MB)
if is_large:
request = SaveBigFilePartRequest(file_id, part_index,
part_count, part)
else:
request = SaveFilePartRequest(file_id, part_index, part)
result = self(request)
if result:
__log__.debug('Uploaded %d/%d', part_index + 1, part_count)
if progress_callback:
progress_callback(stream.tell(), file_size)
else:
raise RuntimeError(
'Failed to upload file part {}.'.format(part_index))
if is_large:
return InputFileBig(file_id, part_count, file_name)
else:
self.session.cache_file(
hash_md5.digest(), file_size, file_id, part_count)
return InputFile(file_id, part_count, file_name,
md5_checksum=hash_md5.hexdigest())
# endregion
# region Downloading media
def download_file(self,
input_location,
file,
part_size_kb=None,
file_size=None,
progress_callback=None):
"""Downloads the given InputFileLocation to file (a stream or str).
If 'progress_callback' is not None, it should be a function that
takes two parameters, (bytes_downloaded, total_bytes). Note that
'total_bytes' simply equals 'file_size', and may be None.
"""
if not part_size_kb:
if not file_size:
part_size_kb = 64 # Reasonable default
else:
part_size_kb = get_appropriated_part_size(file_size)
part_size = int(part_size_kb * 1024)
# https://core.telegram.org/api/files says:
# > part_size % 1024 = 0 (divisible by 1KB)
#
# But https://core.telegram.org/cdn (more recent) says:
# > limit must be divisible by 4096 bytes
# So we just stick to the 4096 limit.
if part_size % 4096 != 0:
raise ValueError('The part size must be evenly divisible by 4096.')
if isinstance(file, str):
# Ensure that we'll be able to download the media
utils.ensure_parent_dir_exists(file)
f = open(file, 'wb')
else:
f = file
# The used client will change if FileMigrateError occurs
client = self
cdn_decrypter = None
__log__.info('Downloading file in chunks of %d bytes', part_size)
try:
offset = 0
while True:
try:
if cdn_decrypter:
result = cdn_decrypter.get_file()
else:
result = client(GetFileRequest(
input_location, offset, part_size
))
if isinstance(result, FileCdnRedirect):
__log__.info('File lives in a CDN')
cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter(
client, self._get_cdn_client(result), result
)
except FileMigrateError as e:
__log__.info('File lives in another DC')
client = self._get_exported_client(e.new_dc)
continue
offset += part_size
# If we have received no data (0 bytes), the file is over
# So there is nothing left to download and write
if not result.bytes:
# Return some extra information, unless it's a CDN file
return getattr(result, 'type', '')
f.write(result.bytes)
__log__.debug('Saved %d more bytes', len(result.bytes))
if progress_callback:
progress_callback(f.tell(), file_size)
finally:
if client != self:
client.disconnect()
if cdn_decrypter:
try:
cdn_decrypter.client.disconnect()
except:
pass
if isinstance(file, str):
f.close()
# endregion
# region Updates handling # region Updates handling
def sync_updates(self): def sync_updates(self):

View File

@ -1,11 +1,21 @@
import hashlib
import itertools import itertools
import logging
import os import os
import sys import sys
import time import time
from collections import OrderedDict, UserList from collections import OrderedDict, UserList
from datetime import datetime, timedelta from datetime import datetime, timedelta
from io import BytesIO
from mimetypes import guess_type from mimetypes import guess_type
from .crypto import CdnDecrypter
from .tl.custom import InputSizedFile
from .tl.functions.upload import (
SaveBigFilePartRequest, SaveFilePartRequest, GetFileRequest
)
from .tl.types.upload import FileCdnRedirect
try: try:
import socks import socks
except ImportError: except ImportError:
@ -16,7 +26,8 @@ from . import helpers, utils
from .errors import ( from .errors import (
RPCError, UnauthorizedError, PhoneCodeEmptyError, PhoneCodeExpiredError, RPCError, UnauthorizedError, PhoneCodeEmptyError, PhoneCodeExpiredError,
PhoneCodeHashEmptyError, PhoneCodeInvalidError, LocationInvalidError, PhoneCodeHashEmptyError, PhoneCodeInvalidError, LocationInvalidError,
SessionPasswordNeededError, FilePartMissingError) SessionPasswordNeededError, FileMigrateError
)
from .network import ConnectionMode from .network import ConnectionMode
from .tl import TLObject from .tl import TLObject
from .tl.custom import Draft, Dialog from .tl.custom import Draft, Dialog
@ -55,11 +66,14 @@ from .tl.types import (
UpdateNewChannelMessage, UpdateNewMessage, UpdateShortSentMessage, UpdateNewChannelMessage, UpdateNewMessage, UpdateShortSentMessage,
PeerUser, InputPeerUser, InputPeerChat, InputPeerChannel, MessageEmpty, PeerUser, InputPeerUser, InputPeerChat, InputPeerChannel, MessageEmpty,
ChatInvite, ChatInviteAlready, PeerChannel, Photo, InputPeerSelf, ChatInvite, ChatInviteAlready, PeerChannel, Photo, InputPeerSelf,
InputSingleMedia, InputMediaPhoto, InputPhoto InputSingleMedia, InputMediaPhoto, InputPhoto, InputFile, InputFileBig,
InputDocument, InputMediaDocument
) )
from .tl.types.messages import DialogsSlice from .tl.types.messages import DialogsSlice
from .extensions import markdown from .extensions import markdown
__log__ = logging.getLogger(__name__)
class TelegramClient(TelegramBareClient): class TelegramClient(TelegramBareClient):
""" """
@ -862,7 +876,9 @@ class TelegramClient(TelegramBareClient):
allow_cache (:obj:`bool`, optional): allow_cache (:obj:`bool`, optional):
Whether to allow using the cached version stored in the Whether to allow using the cached version stored in the
database or not. Defaults to ``True`` to avoid reuploads. database or not. Defaults to ``True`` to avoid re-uploads.
Must be ``False`` if you wish to use different attributes
or thumb than those that were used when the file was cached.
Kwargs: Kwargs:
If "is_voice_note" in kwargs, despite its value, and the file is If "is_voice_note" in kwargs, despite its value, and the file is
@ -879,8 +895,7 @@ class TelegramClient(TelegramBareClient):
if all(utils.is_image(x) for x in file): if all(utils.is_image(x) for x in file):
return self._send_album( return self._send_album(
entity, file, caption=caption, entity, file, caption=caption,
progress_callback=progress_callback, reply_to=reply_to, progress_callback=progress_callback, reply_to=reply_to
allow_cache=allow_cache
) )
# Not all are images, so send all the files one by one # Not all are images, so send all the files one by one
return [ return [
@ -892,10 +907,20 @@ class TelegramClient(TelegramBareClient):
) for x in file ) for x in file
] ]
as_image = utils.is_image(file) and not force_document
use_cache = InputPhoto if as_image else InputDocument
file_handle = self.upload_file( file_handle = self.upload_file(
file, progress_callback=progress_callback, allow_cache=allow_cache) file, progress_callback=progress_callback,
use_cache=use_cache if allow_cache else None
)
if utils.is_image(file) and not force_document: if isinstance(file_handle, use_cache):
# File was cached, so an instance of use_cache was returned
if as_image:
media = InputMediaPhoto(file_handle, caption)
else:
media = InputMediaDocument(file_handle, caption)
elif as_image:
media = InputMediaUploadedPhoto(file_handle, caption) media = InputMediaUploadedPhoto(file_handle, caption)
else: else:
mime_type = None mime_type = None
@ -951,19 +976,18 @@ class TelegramClient(TelegramBareClient):
media=media, media=media,
reply_to_msg_id=self._get_reply_to(reply_to) reply_to_msg_id=self._get_reply_to(reply_to)
) )
try: msg = self._get_response_message(request, self(request))
return self._get_response_message(request, self(request)) if msg and isinstance(file_handle, InputSizedFile):
except FilePartMissingError: # There was a response message and we didn't use cached
# After a while, cached files are invalidated and this # version, so cache whatever we just sent to the database.
# error is raised. The file needs to be uploaded again. md5, size = file_handle.md5, file_handle.size
if not allow_cache: if as_image:
raise to_cache = utils.get_input_photo(msg.media.photo)
return self.send_file( else:
entity, file, allow_cache=False, to_cache = utils.get_input_document(msg.media.document)
caption=caption, force_document=force_document, self.session.cache_file(md5, size, to_cache)
progress_callback=progress_callback, reply_to=reply_to,
attributes=attributes, thumb=thumb, **kwargs return msg
)
def send_voice_note(self, entity, file, caption='', progress_callback=None, def send_voice_note(self, entity, file, caption='', progress_callback=None,
reply_to=None): reply_to=None):
@ -974,41 +998,172 @@ class TelegramClient(TelegramBareClient):
is_voice_note=()) # empty tuple is enough is_voice_note=()) # empty tuple is enough
def _send_album(self, entity, files, caption='', def _send_album(self, entity, files, caption='',
progress_callback=None, reply_to=None, progress_callback=None, reply_to=None):
allow_cache=True):
"""Specialized version of .send_file for albums""" """Specialized version of .send_file for albums"""
# We don't care if the user wants to avoid cache, we will use it
# anyway. Why? The cached version will be exactly the same thing
# we need to produce right now to send albums (uploadMedia), and
# cache only makes a difference for documents where the user may
# want the attributes used on them to change. Caption's ignored.
entity = self.get_input_entity(entity) entity = self.get_input_entity(entity)
reply_to = self._get_reply_to(reply_to) reply_to = self._get_reply_to(reply_to)
try:
# Need to upload the media first # Need to upload the media first, but only if they're not cached yet
media = [ media = []
self(UploadMediaRequest(entity, InputMediaUploadedPhoto( for file in files:
self.upload_file(file, allow_cache=allow_cache), # fh will either be InputPhoto or a modified InputFile
caption=caption fh = self.upload_file(file, use_cache=InputPhoto)
))) if not isinstance(fh, InputPhoto):
for file in files input_photo = utils.get_input_photo(self(UploadMediaRequest(
] entity, media=InputMediaUploadedPhoto(fh, caption)
)).photo)
self.session.cache_file(fh.md5, fh.size, input_photo)
fh = input_photo
media.append(InputSingleMedia(InputMediaPhoto(fh, caption)))
# Now we can construct the multi-media request # Now we can construct the multi-media request
result = self(SendMultiMediaRequest( result = self(SendMultiMediaRequest(
entity, reply_to_msg_id=reply_to, multi_media=[ entity, reply_to_msg_id=reply_to, multi_media=media
InputSingleMedia(InputMediaPhoto(
InputPhoto(m.photo.id, m.photo.access_hash),
caption=caption
))
for m in media
]
)) ))
return [ return [
self._get_response_message(update.id, result) self._get_response_message(update.id, result)
for update in result.updates for update in result.updates
if isinstance(update, UpdateMessageID) if isinstance(update, UpdateMessageID)
] ]
except FilePartMissingError:
if not allow_cache: def upload_file(self,
raise file,
return self._send_album( part_size_kb=None,
entity, files, allow_cache=False, caption=caption, file_name=None,
progress_callback=progress_callback, reply_to=reply_to use_cache=None,
progress_callback=None):
"""
Uploads the specified file and returns a handle (an instance of
InputFile or InputFileBig, as required) which can be later used
before it expires (they are usable during less than a day).
Uploading a file will simply return a "handle" to the file stored
remotely in the Telegram servers, which can be later used on. This
will **not** upload the file to your own chat or any chat at all.
Args:
file (:obj:`str` | :obj:`bytes` | :obj:`file`):
The path of the file, byte array, or stream that will be sent.
Note that if a byte array or a stream is given, a filename
or its type won't be inferred, and it will be sent as an
"unnamed application/octet-stream".
Subsequent calls with the very same file will result in
immediate uploads, unless ``.clear_file_cache()`` is called.
part_size_kb (:obj:`int`, optional):
Chunk size when uploading files. The larger, the less
requests will be made (up to 512KB maximum).
file_name (:obj:`str`, optional):
The file name which will be used on the resulting InputFile.
If not specified, the name will be taken from the ``file``
and if this is not a ``str``, it will be ``"unnamed"``.
use_cache (:obj:`type`, optional):
The type of cache to use (currently either ``InputDocument``
or ``InputPhoto``). If present and the file is small enough
to need the MD5, it will be checked against the database,
and if a match is found, the upload won't be made. Instead,
an instance of type ``use_cache`` will be returned.
progress_callback (:obj:`callable`, optional):
A callback function accepting two parameters:
``(sent bytes, total)``.
Returns:
``InputFileBig`` if the file size is larger than 10MB,
``InputSizedFile`` (subclass of ``InputFile``) otherwise.
"""
if isinstance(file, (InputFile, InputFileBig)):
return file # Already uploaded
if isinstance(file, str):
file_size = os.path.getsize(file)
elif isinstance(file, bytes):
file_size = len(file)
else:
file = file.read()
file_size = len(file)
# File will now either be a string or bytes
if not part_size_kb:
part_size_kb = utils.get_appropriated_part_size(file_size)
if part_size_kb > 512:
raise ValueError('The part size must be less or equal to 512KB')
part_size = int(part_size_kb * 1024)
if part_size % 1024 != 0:
raise ValueError(
'The part size must be evenly divisible by 1024')
# Set a default file name if None was specified
file_id = helpers.generate_random_long()
if not file_name:
if isinstance(file, str):
file_name = os.path.basename(file)
else:
file_name = str(file_id)
# Determine whether the file is too big (over 10MB) or not
# Telegram does make a distinction between smaller or larger files
is_large = file_size > 10 * 1024 * 1024
hash_md5 = hashlib.md5()
if not is_large:
# Calculate the MD5 hash before anything else.
# As this needs to be done always for small files,
# might as well do it before anything else and
# check the cache.
if isinstance(file, str):
with open(file, 'rb') as stream:
file = stream.read()
hash_md5.update(file)
if use_cache:
cached = self.session.get_file(
hash_md5.digest(), file_size, cls=use_cache
)
if cached:
return cached
part_count = (file_size + part_size - 1) // part_size
__log__.info('Uploading file of %d bytes in %d chunks of %d',
file_size, part_count, part_size)
with open(file, 'rb') if isinstance(file, str) else BytesIO(file) \
as stream:
for part_index in range(part_count):
# Read the file by in chunks of size part_size
part = stream.read(part_size)
# The SavePartRequest is different depending on whether
# the file is too large or not (over or less than 10MB)
if is_large:
request = SaveBigFilePartRequest(file_id, part_index,
part_count, part)
else:
request = SaveFilePartRequest(file_id, part_index, part)
result = self(request)
if result:
__log__.debug('Uploaded %d/%d', part_index + 1,
part_count)
if progress_callback:
progress_callback(stream.tell(), file_size)
else:
raise RuntimeError(
'Failed to upload file part {}.'.format(part_index))
if is_large:
return InputFileBig(file_id, part_count, file_name)
else:
return InputSizedFile(
file_id, part_count, file_name, md5=hash_md5, size=file_size
) )
# endregion # endregion
@ -1292,6 +1447,113 @@ class TelegramClient(TelegramBareClient):
return result return result
i += 1 i += 1
def download_file(self,
input_location,
file,
part_size_kb=None,
file_size=None,
progress_callback=None):
"""
Downloads the given input location to a file.
Args:
input_location (:obj:`InputFileLocation`):
The file location from which the file will be downloaded.
file (:obj:`str` | :obj:`file`, optional):
The output file path, directory, or stream-like object.
If the path exists and is a file, it will be overwritten.
part_size_kb (:obj:`int`, optional):
Chunk size when downloading files. The larger, the less
requests will be made (up to 512KB maximum).
file_size (:obj:`int`, optional):
The file size that is about to be downloaded, if known.
Only used if ``progress_callback`` is specified.
progress_callback (:obj:`callable`, optional):
A callback function accepting two parameters:
``(downloaded bytes, total)``. Note that the
``total`` is the provided ``file_size``.
"""
if not part_size_kb:
if not file_size:
part_size_kb = 64 # Reasonable default
else:
part_size_kb = utils.get_appropriated_part_size(file_size)
part_size = int(part_size_kb * 1024)
# https://core.telegram.org/api/files says:
# > part_size % 1024 = 0 (divisible by 1KB)
#
# But https://core.telegram.org/cdn (more recent) says:
# > limit must be divisible by 4096 bytes
# So we just stick to the 4096 limit.
if part_size % 4096 != 0:
raise ValueError(
'The part size must be evenly divisible by 4096.')
if isinstance(file, str):
# Ensure that we'll be able to download the media
helpers.ensure_parent_dir_exists(file)
f = open(file, 'wb')
else:
f = file
# The used client will change if FileMigrateError occurs
client = self
cdn_decrypter = None
__log__.info('Downloading file in chunks of %d bytes', part_size)
try:
offset = 0
while True:
try:
if cdn_decrypter:
result = cdn_decrypter.get_file()
else:
result = client(GetFileRequest(
input_location, offset, part_size
))
if isinstance(result, FileCdnRedirect):
__log__.info('File lives in a CDN')
cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter(
client, self._get_cdn_client(result),
result
)
except FileMigrateError as e:
__log__.info('File lives in another DC')
client = self._get_exported_client(e.new_dc)
continue
offset += part_size
# If we have received no data (0 bytes), the file is over
# So there is nothing left to download and write
if not result.bytes:
# Return some extra information, unless it's a CDN file
return getattr(result, 'type', '')
f.write(result.bytes)
__log__.debug('Saved %d more bytes', len(result.bytes))
if progress_callback:
progress_callback(f.tell(), file_size)
finally:
if client != self:
client.disconnect()
if cdn_decrypter:
try:
cdn_decrypter.client.disconnect()
except:
pass
if isinstance(file, str):
f.close()
# endregion # endregion
# endregion # endregion

View File

@ -1,2 +1,3 @@
from .draft import Draft from .draft import Draft
from .dialog import Dialog from .dialog import Dialog
from .input_sized_file import InputSizedFile

View File

@ -0,0 +1,9 @@
from ..types import InputFile
class InputSizedFile(InputFile):
"""InputFile class with two extra parameters: md5 (digest) and size"""
def __init__(self, id_, parts, name, md5, size):
super().__init__(id_, parts, name, md5.hexdigest())
self.md5 = md5.digest()
self.size = size

View File

@ -5,6 +5,7 @@ import sqlite3
import struct import struct
import time import time
from base64 import b64decode from base64 import b64decode
from enum import Enum
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock from threading import Lock
@ -12,11 +13,26 @@ from .. import utils
from ..tl import TLObject from ..tl import TLObject
from ..tl.types import ( from ..tl.types import (
PeerUser, PeerChat, PeerChannel, PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel InputPeerUser, InputPeerChat, InputPeerChannel,
InputPhoto, InputDocument
) )
EXTENSION = '.session' EXTENSION = '.session'
CURRENT_VERSION = 2 # database version CURRENT_VERSION = 3 # database version
class _SentFileType(Enum):
DOCUMENT = 0
PHOTO = 1
@staticmethod
def from_type(cls):
if cls == InputDocument:
return _SentFileType.DOCUMENT
elif cls == InputPhoto:
return _SentFileType.PHOTO
else:
raise ValueError('The cls must be either InputDocument/InputPhoto')
class Session: class Session:
@ -130,9 +146,10 @@ class Session:
"""sent_files ( """sent_files (
md5_digest blob, md5_digest blob,
file_size integer, file_size integer,
file_id integer, type integer,
part_count integer, id integer,
primary key(md5_digest, file_size) hash integer,
primary key(md5_digest, file_size, type)
)""" )"""
) )
c.execute("insert into version values (?)", (CURRENT_VERSION,)) c.execute("insert into version values (?)", (CURRENT_VERSION,))
@ -171,18 +188,22 @@ class Session:
def _upgrade_database(self, old): def _upgrade_database(self, old):
c = self._conn.cursor() c = self._conn.cursor()
if old == 1: # old == 1 doesn't have the old sent_files so no need to drop
self._create_table(c,"""sent_files ( if old == 2:
# Old cache from old sent_files lasts then a day anyway, drop
c.execute('drop table sent_files')
self._create_table(c, """sent_files (
md5_digest blob, md5_digest blob,
file_size integer, file_size integer,
file_id integer, type integer,
part_count integer, id integer,
primary key(md5_digest, file_size) hash integer,
primary key(md5_digest, file_size, type)
)""") )""")
old = 2
c.close() c.close()
def _create_table(self, c, *definitions): @staticmethod
def _create_table(c, *definitions):
""" """
Creates a table given its definition 'name (columns). Creates a table given its definition 'name (columns).
If the sqlite version is >= 3.8.2, it will use "without rowid". If the sqlite version is >= 3.8.2, it will use "without rowid".
@ -420,24 +441,25 @@ class Session:
# File processing # File processing
def get_file(self, md5_digest, file_size): def get_file(self, md5_digest, file_size, cls):
return self._conn.execute( tuple_ = self._conn.execute(
'select * from sent_files ' 'select id, hash from sent_files '
'where md5_digest = ? and file_size = ?', (md5_digest, file_size) 'where md5_digest = ? and file_size = ? and type = ?',
(md5_digest, file_size, _SentFileType.from_type(cls).value)
).fetchone() ).fetchone()
if tuple_:
# Both allowed classes have (id, access_hash) as parameters
return cls(tuple_[0], tuple_[1])
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
def cache_file(self, md5_digest, file_size, file_id, part_count):
with self._db_lock: with self._db_lock:
self._conn.execute( self._conn.execute(
'insert into sent_files values (?,?,?,?)', 'insert or replace into sent_files values (?,?,?,?,?)', (
(md5_digest, file_size, file_id, part_count) md5_digest, file_size,
) _SentFileType.from_type(type(instance)).value,
self.save() instance.id, instance.access_hash
))
def clear_file(self, md5_digest, file_size):
with self._db_lock:
self._conn.execute(
'delete from sent_files where '
'md5_digest = ? and file_size = ?', (md5_digest, file_size)
)
self.save() self.save()