From c5e6f7e265227702ad39f3c9cf632524ff1274bf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Mar 2018 23:34:32 +0200 Subject: [PATCH 01/15] Split Session into three parts and make a module for sessions --- telethon/network/mtproto_sender.py | 4 +- telethon/sessions/__init__.py | 3 + telethon/sessions/abstract.py | 136 +++++++++ telethon/sessions/memory.py | 297 ++++++++++++++++++++ telethon/{session.py => sessions/sqlite.py} | 234 ++++----------- telethon/telegram_bare_client.py | 10 +- 6 files changed, 491 insertions(+), 193 deletions(-) create mode 100644 telethon/sessions/__init__.py create mode 100644 telethon/sessions/abstract.py create mode 100644 telethon/sessions/memory.py rename telethon/{session.py => sessions/sqlite.py} (59%) diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index 43b5e803..cbcdc76d 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -402,13 +402,13 @@ class MtProtoSender: elif bad_msg.error_code == 32: # msg_seqno too low, so just pump it up by some "large" amount # TODO A better fix would be to start with a new fresh session ID - self.session._sequence += 64 + self.session.sequence += 64 __log__.info('Attempting to set the right higher sequence') self._resend_request(bad_msg.bad_msg_id) return True elif bad_msg.error_code == 33: # msg_seqno too high never seems to happen but just in case - self.session._sequence -= 16 + self.session.sequence -= 16 __log__.info('Attempting to set the right lower sequence') self._resend_request(bad_msg.bad_msg_id) return True diff --git a/telethon/sessions/__init__.py b/telethon/sessions/__init__.py new file mode 100644 index 00000000..af3423f3 --- /dev/null +++ b/telethon/sessions/__init__.py @@ -0,0 +1,3 @@ +from .abstract import Session +from .memory import MemorySession +from .sqlite import SQLiteSession diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py new file mode 100644 index 00000000..c7392ffc --- /dev/null +++ b/telethon/sessions/abstract.py @@ -0,0 +1,136 @@ +from abc import ABC, abstractmethod + + +class Session(ABC): + @abstractmethod + def clone(self): + raise NotImplementedError + + @abstractmethod + def set_dc(self, dc_id, server_address, port): + raise NotImplementedError + + @property + @abstractmethod + def server_address(self): + raise NotImplementedError + + @property + @abstractmethod + def port(self): + raise NotImplementedError + + @property + @abstractmethod + def auth_key(self): + raise NotImplementedError + + @auth_key.setter + @abstractmethod + def auth_key(self, value): + raise NotImplementedError + + @property + @abstractmethod + def time_offset(self): + raise NotImplementedError + + @time_offset.setter + @abstractmethod + def time_offset(self, value): + raise NotImplementedError + + @property + @abstractmethod + def salt(self): + raise NotImplementedError + + @salt.setter + @abstractmethod + def salt(self, value): + raise NotImplementedError + + @property + @abstractmethod + def device_model(self): + raise NotImplementedError + + @property + @abstractmethod + def system_version(self): + raise NotImplementedError + + @property + @abstractmethod + def app_version(self): + raise NotImplementedError + + @property + @abstractmethod + def lang_code(self): + raise NotImplementedError + + @property + @abstractmethod + def system_lang_code(self): + raise NotImplementedError + + @property + @abstractmethod + def report_errors(self): + raise NotImplementedError + + @property + @abstractmethod + def sequence(self): + raise NotImplementedError + + @property + @abstractmethod + def flood_sleep_threshold(self): + raise NotImplementedError + + @abstractmethod + def close(self): + raise NotImplementedError + + @abstractmethod + def save(self): + raise NotImplementedError + + @abstractmethod + def delete(self): + raise NotImplementedError + + @classmethod + @abstractmethod + def list_sessions(cls): + raise NotImplementedError + + @abstractmethod + def get_new_msg_id(self): + raise NotImplementedError + + @abstractmethod + def update_time_offset(self, correct_msg_id): + raise NotImplementedError + + @abstractmethod + def generate_sequence(self, content_related): + raise NotImplementedError + + @abstractmethod + def process_entities(self, tlo): + raise NotImplementedError + + @abstractmethod + def get_input_entity(self, key): + raise NotImplementedError + + @abstractmethod + def cache_file(self, md5_digest, file_size, instance): + raise NotImplementedError + + @abstractmethod + def get_file(self, md5_digest, file_size, cls): + raise NotImplementedError diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py new file mode 100644 index 00000000..66558829 --- /dev/null +++ b/telethon/sessions/memory.py @@ -0,0 +1,297 @@ +from enum import Enum +import time +import platform + +from .. import utils +from .abstract import Session +from ..tl import TLObject + +from ..tl.types import ( + PeerUser, PeerChat, PeerChannel, + InputPeerUser, InputPeerChat, InputPeerChannel, + InputPhoto, InputDocument +) + + +class _SentFileType(Enum): + DOCUMENT = 0 + PHOTO = 1 + + @staticmethod + def from_type(cls): + if cls == InputDocument: + return _SentFileType.DOCUMENT + elif cls == InputPhoto: + return _SentFileType.PHOTO + else: + raise ValueError('The cls must be either InputDocument/InputPhoto') + + +class MemorySession(Session): + def __init__(self): + self._dc_id = None + self._server_address = None + self._port = None + self._salt = None + self._auth_key = None + self._sequence = 0 + self._last_msg_id = 0 + self._time_offset = 0 + self._flood_sleep_threshold = 60 + + system = platform.uname() + self._device_model = system.system or 'Unknown' + self._system_version = system.release or '1.0' + self._app_version = '1.0' + self._lang_code = 'en' + self._system_lang_code = self.lang_code + self._report_errors = True + self._flood_sleep_threshold = 60 + + self._files = {} + self._entities = set() + + def clone(self): + cloned = MemorySession() + cloned._device_model = self.device_model + cloned._system_version = self.system_version + cloned._app_version = self.app_version + cloned._lang_code = self.lang_code + cloned._system_lang_code = self.system_lang_code + cloned._report_errors = self.report_errors + cloned._flood_sleep_threshold = self.flood_sleep_threshold + + def set_dc(self, dc_id, server_address, port): + self._dc_id = dc_id + self._server_address = server_address + self._port = port + + @property + def server_address(self): + return self._server_address + + @property + def port(self): + return self._port + + @property + def auth_key(self): + return self._auth_key + + @auth_key.setter + def auth_key(self, value): + self._auth_key = value + + @property + def time_offset(self): + return self._time_offset + + @time_offset.setter + def time_offset(self, value): + self._time_offset = value + + @property + def salt(self): + return self._salt + + @salt.setter + def salt(self, value): + self._salt = value + + @property + def device_model(self): + return self._device_model + + @property + def system_version(self): + return self._system_version + + @property + def app_version(self): + return self._app_version + + @property + def lang_code(self): + return self._lang_code + + @property + def system_lang_code(self): + return self._system_lang_code + + @property + def report_errors(self): + return self._report_errors + + @property + def sequence(self): + return self._sequence + + @property + def flood_sleep_threshold(self): + return self._flood_sleep_threshold + + def close(self): + pass + + def save(self): + pass + + def delete(self): + pass + + @classmethod + def list_sessions(cls): + raise NotImplementedError + + def get_new_msg_id(self): + """Generates a new unique message ID based on the current + time (in ms) since epoch""" + now = time.time() + self._time_offset + nanoseconds = int((now - int(now)) * 1e+9) + new_msg_id = (int(now) << 32) | (nanoseconds << 2) + + if self._last_msg_id >= new_msg_id: + new_msg_id = self._last_msg_id + 4 + + self._last_msg_id = new_msg_id + + return new_msg_id + + def update_time_offset(self, correct_msg_id): + now = int(time.time()) + correct = correct_msg_id >> 32 + self._time_offset = correct - now + self._last_msg_id = 0 + + def generate_sequence(self, content_related): + if content_related: + result = self._sequence * 2 + 1 + self._sequence += 1 + return result + else: + return self._sequence * 2 + + @staticmethod + def _entities_to_rows(tlo): + if not isinstance(tlo, TLObject) and utils.is_list_like(tlo): + # This may be a list of users already for instance + entities = tlo + else: + entities = [] + if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats): + entities.extend(tlo.chats) + if hasattr(tlo, 'users') and utils.is_list_like(tlo.users): + entities.extend(tlo.users) + if not entities: + return + + rows = [] # Rows to add (id, hash, username, phone, name) + for e in entities: + if not isinstance(e, TLObject): + continue + try: + p = utils.get_input_peer(e, allow_self=False) + marked_id = utils.get_peer_id(p) + except ValueError: + continue + + if isinstance(p, (InputPeerUser, InputPeerChannel)): + if not p.access_hash: + # Some users and channels seem to be returned without + # an 'access_hash', meaning Telegram doesn't want you + # to access them. This is the reason behind ensuring + # that the 'access_hash' is non-zero. See issue #354. + # Note that this checks for zero or None, see #392. + continue + else: + p_hash = p.access_hash + elif isinstance(p, InputPeerChat): + p_hash = 0 + else: + continue + + username = getattr(e, 'username', None) or None + if username is not None: + username = username.lower() + phone = getattr(e, 'phone', None) + name = utils.get_display_name(e) or None + rows.append((marked_id, p_hash, username, phone, name)) + return rows + + def process_entities(self, tlo): + self._entities += set(self._entities_to_rows(tlo)) + + def get_entity_rows_by_phone(self, phone): + rows = [(id, hash) for id, hash, _, found_phone, _ + in self._entities if found_phone == phone] + return rows[0] if rows else None + + def get_entity_rows_by_username(self, username): + rows = [(id, hash) for id, hash, found_username, _, _ + in self._entities if found_username == username] + return rows[0] if rows else None + + def get_entity_rows_by_name(self, name): + rows = [(id, hash) for id, hash, _, _, found_name + in self._entities if found_name == name] + return rows[0] if rows else None + + def get_entity_rows_by_id(self, id): + rows = [(id, hash) for found_id, hash, _, _, _ + in self._entities if found_id == id] + return rows[0] if rows else None + + def get_input_entity(self, key): + try: + if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): + # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) + # We already have an Input version, so nothing else required + return key + # Try to early return if this key can be casted as input peer + return utils.get_input_peer(key) + except (AttributeError, TypeError): + # Not a TLObject or can't be cast into InputPeer + if isinstance(key, TLObject): + key = utils.get_peer_id(key) + + result = None + if isinstance(key, str): + phone = utils.parse_phone(key) + if phone: + result = self.get_entity_rows_by_phone(phone) + else: + username, _ = utils.parse_username(key) + if username: + result = self.get_entity_rows_by_username(username) + + if isinstance(key, int): + result = self.get_entity_rows_by_id(key) + + if not result and isinstance(key, str): + result = self.get_entity_rows_by_name(key) + + if result: + i, h = result # unpack resulting tuple + i, k = utils.resolve_id(i) # removes the mark and returns kind + if k == PeerUser: + return InputPeerUser(i, h) + elif k == PeerChat: + return InputPeerChat(i) + elif k == PeerChannel: + return InputPeerChannel(i, h) + else: + raise ValueError('Could not find input entity with key ', key) + + def cache_file(self, md5_digest, file_size, instance): + if not isinstance(instance, (InputDocument, InputPhoto)): + raise TypeError('Cannot cache %s instance' % type(instance)) + key = (md5_digest, file_size, _SentFileType.from_type(instance)) + value = (instance.id, instance.access_hash) + self._files[key] = value + + def get_file(self, md5_digest, file_size, cls): + key = (md5_digest, file_size, _SentFileType.from_type(cls)) + try: + return self._files[key] + except KeyError: + return None diff --git a/telethon/session.py b/telethon/sessions/sqlite.py similarity index 59% rename from telethon/session.py rename to telethon/sessions/sqlite.py index 6b374c39..66a0c887 100644 --- a/telethon/session.py +++ b/telethon/sessions/sqlite.py @@ -5,14 +5,15 @@ import sqlite3 import struct import time from base64 import b64decode -from enum import Enum from os.path import isfile as file_exists from threading import Lock, RLock -from . import utils -from .crypto import AuthKey -from .tl import TLObject -from .tl.types import ( +from .. import utils +from .abstract import Session +from .memory import MemorySession, _SentFileType +from ..crypto import AuthKey +from ..tl import TLObject +from ..tl.types import ( PeerUser, PeerChat, PeerChannel, InputPeerUser, InputPeerChat, InputPeerChannel, InputPhoto, InputDocument @@ -22,21 +23,7 @@ EXTENSION = '.session' CURRENT_VERSION = 3 # database version -class _SentFileType(Enum): - DOCUMENT = 0 - PHOTO = 1 - - @staticmethod - def from_type(cls): - if cls == InputDocument: - return _SentFileType.DOCUMENT - elif cls == InputPhoto: - return _SentFileType.PHOTO - else: - raise ValueError('The cls must be either InputDocument/InputPhoto') - - -class Session: +class SQLiteSession(MemorySession): """This session contains the required information to login into your Telegram account. NEVER give the saved JSON file to anyone, since they would gain instant access to all your messages and contacts. @@ -44,7 +31,9 @@ class Session: If you think the session has been compromised, close all the sessions through an official Telegram client to revoke the authorization. """ + def __init__(self, session_id): + super().__init__() """session_user_id should either be a string or another Session. Note that if another session is given, only parameters like those required to init a connection will be copied. @@ -54,15 +43,15 @@ class Session: # For connection purposes if isinstance(session_id, Session): - self.device_model = session_id.device_model - self.system_version = session_id.system_version - self.app_version = session_id.app_version - self.lang_code = session_id.lang_code - self.system_lang_code = session_id.system_lang_code - self.lang_pack = session_id.lang_pack - self.report_errors = session_id.report_errors - self.save_entities = session_id.save_entities - self.flood_sleep_threshold = session_id.flood_sleep_threshold + self._device_model = session_id.device_model + self._system_version = session_id.system_version + self._app_version = session_id.app_version + self._lang_code = session_id.lang_code + self._system_lang_code = session_id.system_lang_code + self._report_errors = session_id.report_errors + self._flood_sleep_threshold = session_id.flood_sleep_threshold + if isinstance(session_id, SQLiteSession): + self.save_entities = session_id.save_entities else: # str / None if session_id: self.filename = session_id @@ -70,15 +59,14 @@ class Session: self.filename += EXTENSION system = platform.uname() - self.device_model = system.system or 'Unknown' - self.system_version = system.release or '1.0' - self.app_version = '1.0' # '0' will provoke error - self.lang_code = 'en' - self.system_lang_code = self.lang_code - self.lang_pack = '' - self.report_errors = True + self._device_model = system.system or 'Unknown' + self._system_version = system.release or '1.0' + self._app_version = '1.0' # '0' will provoke error + self._lang_code = 'en' + self._system_lang_code = self.lang_code + self._report_errors = True self.save_entities = True - self.flood_sleep_threshold = 60 + self._flood_sleep_threshold = 60 self.id = struct.unpack('q', os.urandom(8))[0] self._sequence = 0 @@ -163,6 +151,9 @@ class Session: c.close() self.save() + def clone(self): + return SQLiteSession(self) + def _check_migrate_json(self): if file_exists(self.filename): try: @@ -233,19 +224,7 @@ class Session: self._auth_key = None c.close() - @property - def server_address(self): - return self._server_address - - @property - def port(self): - return self._port - - @property - def auth_key(self): - return self._auth_key - - @auth_key.setter + @Session.auth_key.setter def auth_key(self, value): self._auth_key = value self._update_session_table() @@ -298,53 +277,14 @@ class Session: except OSError: return False - @staticmethod - def list_sessions(): + @classmethod + def list_sessions(cls): """Lists all the sessions of the users who have ever connected using this client and never logged out """ return [os.path.splitext(os.path.basename(f))[0] for f in os.listdir('.') if f.endswith(EXTENSION)] - def generate_sequence(self, content_related): - """Thread safe method to generates the next sequence number, - based on whether it was confirmed yet or not. - - Note that if confirmed=True, the sequence number - will be increased by one too - """ - with self._seq_no_lock: - if content_related: - result = self._sequence * 2 + 1 - self._sequence += 1 - return result - else: - return self._sequence * 2 - - def get_new_msg_id(self): - """Generates a new unique message ID based on the current - time (in ms) since epoch""" - # Refer to mtproto_plain_sender.py for the original method - now = time.time() + self.time_offset - nanoseconds = int((now - int(now)) * 1e+9) - # "message identifiers are divisible by 4" - new_msg_id = (int(now) << 32) | (nanoseconds << 2) - - with self._msg_id_lock: - if self._last_msg_id >= new_msg_id: - new_msg_id = self._last_msg_id + 4 - - self._last_msg_id = new_msg_id - - return new_msg_id - - def update_time_offset(self, correct_msg_id): - """Updates the time offset based on a known correct message ID""" - now = int(time.time()) - correct = correct_msg_id >> 32 - self.time_offset = correct - now - self._last_msg_id = 0 - # Entity processing def process_entities(self, tlo): @@ -356,49 +296,7 @@ class Session: if not self.save_entities: return - if not isinstance(tlo, TLObject) and utils.is_list_like(tlo): - # This may be a list of users already for instance - entities = tlo - else: - entities = [] - if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats): - entities.extend(tlo.chats) - if hasattr(tlo, 'users') and utils.is_list_like(tlo.users): - entities.extend(tlo.users) - if not entities: - return - - rows = [] # Rows to add (id, hash, username, phone, name) - for e in entities: - if not isinstance(e, TLObject): - continue - try: - p = utils.get_input_peer(e, allow_self=False) - marked_id = utils.get_peer_id(p) - except ValueError: - continue - - if isinstance(p, (InputPeerUser, InputPeerChannel)): - if not p.access_hash: - # Some users and channels seem to be returned without - # an 'access_hash', meaning Telegram doesn't want you - # to access them. This is the reason behind ensuring - # that the 'access_hash' is non-zero. See issue #354. - # Note that this checks for zero or None, see #392. - continue - else: - p_hash = p.access_hash - elif isinstance(p, InputPeerChat): - p_hash = 0 - else: - continue - - username = getattr(e, 'username', None) or None - if username is not None: - username = username.lower() - phone = getattr(e, 'phone', None) - name = utils.get_display_name(e) or None - rows.append((marked_id, p_hash, username, phone, name)) + rows = self._entities_to_rows(tlo) if not rows: return @@ -408,62 +306,26 @@ class Session: ) self.save() - def get_input_entity(self, key): - """Parses the given string, integer or TLObject key into a - marked entity ID, which is then used to fetch the hash - from the database. - - If a callable key is given, every row will be fetched, - and passed as a tuple to a function, that should return - a true-like value when the desired row is found. - - Raises ValueError if it cannot be found. - """ - try: - if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): - # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) - # We already have an Input version, so nothing else required - return key - # Try to early return if this key can be casted as input peer - return utils.get_input_peer(key) - except (AttributeError, TypeError): - # Not a TLObject or can't be cast into InputPeer - if isinstance(key, TLObject): - key = utils.get_peer_id(key) - + def _fetchone_entity(self, query, args): c = self._cursor() - if isinstance(key, str): - phone = utils.parse_phone(key) - if phone: - c.execute('select id, hash from entities where phone=?', - (phone,)) - else: - username, _ = utils.parse_username(key) - if username: - c.execute('select id, hash from entities where username=?', + c.execute(query, args) + return c.fetchone() + + def get_entity_rows_by_phone(self, phone): + return self._fetchone_entity( + 'select id, hash from entities where phone=?', (phone,)) + + def get_entity_rows_by_username(self, username): + self._fetchone_entity('select id, hash from entities where username=?', (username,)) - if isinstance(key, int): - c.execute('select id, hash from entities where id=?', (key,)) + def get_entity_rows_by_name(self, name): + self._fetchone_entity('select id, hash from entities where name=?', + (name,)) - result = c.fetchone() - if not result and isinstance(key, str): - # Try exact match by name if phone/username failed - c.execute('select id, hash from entities where name=?', (key,)) - result = c.fetchone() - - c.close() - if result: - i, h = result # unpack resulting tuple - i, k = utils.resolve_id(i) # removes the mark and returns kind - if k == PeerUser: - return InputPeerUser(i, h) - elif k == PeerChat: - return InputPeerChat(i) - elif k == PeerChannel: - return InputPeerChannel(i, h) - else: - raise ValueError('Could not find input entity with key ', key) + def get_entity_rows_by_id(self, id): + self._fetchone_entity('select id, hash from entities where id=?', + (id,)) # File processing diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 8a15476e..3a5b2bd0 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -14,7 +14,7 @@ from .errors import ( PhoneMigrateError, NetworkMigrateError, UserMigrateError ) from .network import authenticator, MtProtoSender, Connection, ConnectionMode -from .session import Session +from .sessions import Session, SQLiteSession from .tl import TLObject from .tl.all_tlobjects import LAYER from .tl.functions import ( @@ -81,10 +81,10 @@ class TelegramBareClient: "Refer to telethon.rtfd.io for more information.") self._use_ipv6 = use_ipv6 - + # Determine what session object we have if isinstance(session, str) or session is None: - session = Session(session) + session = SQLiteSession(session) elif not isinstance(session, Session): raise TypeError( 'The given session must be a str or a Session instance.' @@ -361,7 +361,7 @@ class TelegramBareClient: # # Construct this session with the connection parameters # (system version, device model...) from the current one. - session = Session(self.session) + session = self.session.clone() session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[dc_id] = session @@ -387,7 +387,7 @@ class TelegramBareClient: session = self._exported_sessions.get(cdn_redirect.dc_id) if not session: dc = self._get_dc(cdn_redirect.dc_id, cdn=True) - session = Session(self.session) + session = self.session.clone() session.set_dc(dc.id, dc.ip_address, dc.port) self._exported_sessions[cdn_redirect.dc_id] = session From 4c64d53e7178b3555b9c9346b93c7b1f4ab3366f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 11:10:11 +0200 Subject: [PATCH 02/15] Move non-persistent stuff to base Session class --- telethon/sessions/abstract.py | 144 ++++++++++++++++++++-------------- telethon/sessions/memory.py | 62 +-------------- 2 files changed, 87 insertions(+), 119 deletions(-) diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index c7392ffc..ff0fd16d 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -1,10 +1,32 @@ from abc import ABC, abstractmethod +import time +import platform class Session(ABC): - @abstractmethod + def __init__(self): + self._sequence = 0 + self._last_msg_id = 0 + self._time_offset = 0 + + system = platform.uname() + self._device_model = system.system or 'Unknown' + self._system_version = system.release or '1.0' + self._app_version = '1.0' + self._lang_code = 'en' + self._system_lang_code = self.lang_code + self._report_errors = True + self._flood_sleep_threshold = 60 + def clone(self): - raise NotImplementedError + cloned = self.__class__() + cloned._device_model = self.device_model + cloned._system_version = self.system_version + cloned._app_version = self.app_version + cloned._lang_code = self.lang_code + cloned._system_lang_code = self.system_lang_code + cloned._report_errors = self.report_errors + cloned._flood_sleep_threshold = self.flood_sleep_threshold @abstractmethod def set_dc(self, dc_id, server_address, port): @@ -31,14 +53,12 @@ class Session(ABC): raise NotImplementedError @property - @abstractmethod def time_offset(self): - raise NotImplementedError + return self._time_offset @time_offset.setter - @abstractmethod def time_offset(self, value): - raise NotImplementedError + self._time_offset = value @property @abstractmethod @@ -50,46 +70,6 @@ class Session(ABC): def salt(self, value): raise NotImplementedError - @property - @abstractmethod - def device_model(self): - raise NotImplementedError - - @property - @abstractmethod - def system_version(self): - raise NotImplementedError - - @property - @abstractmethod - def app_version(self): - raise NotImplementedError - - @property - @abstractmethod - def lang_code(self): - raise NotImplementedError - - @property - @abstractmethod - def system_lang_code(self): - raise NotImplementedError - - @property - @abstractmethod - def report_errors(self): - raise NotImplementedError - - @property - @abstractmethod - def sequence(self): - raise NotImplementedError - - @property - @abstractmethod - def flood_sleep_threshold(self): - raise NotImplementedError - @abstractmethod def close(self): raise NotImplementedError @@ -107,18 +87,6 @@ class Session(ABC): def list_sessions(cls): raise NotImplementedError - @abstractmethod - def get_new_msg_id(self): - raise NotImplementedError - - @abstractmethod - def update_time_offset(self, correct_msg_id): - raise NotImplementedError - - @abstractmethod - def generate_sequence(self, content_related): - raise NotImplementedError - @abstractmethod def process_entities(self, tlo): raise NotImplementedError @@ -134,3 +102,63 @@ class Session(ABC): @abstractmethod def get_file(self, md5_digest, file_size, cls): raise NotImplementedError + + @property + def device_model(self): + return self._device_model + + @property + def system_version(self): + return self._system_version + + @property + def app_version(self): + return self._app_version + + @property + def lang_code(self): + return self._lang_code + + @property + def system_lang_code(self): + return self._system_lang_code + + @property + def report_errors(self): + return self._report_errors + + @property + def flood_sleep_threshold(self): + return self._flood_sleep_threshold + + @property + def sequence(self): + return self._sequence + + def get_new_msg_id(self): + """Generates a new unique message ID based on the current + time (in ms) since epoch""" + now = time.time() + self._time_offset + nanoseconds = int((now - int(now)) * 1e+9) + new_msg_id = (int(now) << 32) | (nanoseconds << 2) + + if self._last_msg_id >= new_msg_id: + new_msg_id = self._last_msg_id + 4 + + self._last_msg_id = new_msg_id + + return new_msg_id + + def update_time_offset(self, correct_msg_id): + now = int(time.time()) + correct = correct_msg_id >> 32 + self._time_offset = correct - now + self._last_msg_id = 0 + + def generate_sequence(self, content_related): + if content_related: + result = self._sequence * 2 + 1 + self._sequence += 1 + return result + else: + return self._sequence * 2 diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 66558829..09aa1fa0 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -1,6 +1,4 @@ from enum import Enum -import time -import platform from .. import utils from .abstract import Session @@ -29,38 +27,16 @@ class _SentFileType(Enum): class MemorySession(Session): def __init__(self): + super().__init__() self._dc_id = None self._server_address = None self._port = None self._salt = None self._auth_key = None - self._sequence = 0 - self._last_msg_id = 0 - self._time_offset = 0 - self._flood_sleep_threshold = 60 - - system = platform.uname() - self._device_model = system.system or 'Unknown' - self._system_version = system.release or '1.0' - self._app_version = '1.0' - self._lang_code = 'en' - self._system_lang_code = self.lang_code - self._report_errors = True - self._flood_sleep_threshold = 60 self._files = {} self._entities = set() - def clone(self): - cloned = MemorySession() - cloned._device_model = self.device_model - cloned._system_version = self.system_version - cloned._app_version = self.app_version - cloned._lang_code = self.lang_code - cloned._system_lang_code = self.system_lang_code - cloned._report_errors = self.report_errors - cloned._flood_sleep_threshold = self.flood_sleep_threshold - def set_dc(self, dc_id, server_address, port): self._dc_id = dc_id self._server_address = server_address @@ -82,14 +58,6 @@ class MemorySession(Session): def auth_key(self, value): self._auth_key = value - @property - def time_offset(self): - return self._time_offset - - @time_offset.setter - def time_offset(self, value): - self._time_offset = value - @property def salt(self): return self._salt @@ -143,34 +111,6 @@ class MemorySession(Session): def list_sessions(cls): raise NotImplementedError - def get_new_msg_id(self): - """Generates a new unique message ID based on the current - time (in ms) since epoch""" - now = time.time() + self._time_offset - nanoseconds = int((now - int(now)) * 1e+9) - new_msg_id = (int(now) << 32) | (nanoseconds << 2) - - if self._last_msg_id >= new_msg_id: - new_msg_id = self._last_msg_id + 4 - - self._last_msg_id = new_msg_id - - return new_msg_id - - def update_time_offset(self, correct_msg_id): - now = int(time.time()) - correct = correct_msg_id >> 32 - self._time_offset = correct - now - self._last_msg_id = 0 - - def generate_sequence(self, content_related): - if content_related: - result = self._sequence * 2 + 1 - self._sequence += 1 - return result - else: - return self._sequence * 2 - @staticmethod def _entities_to_rows(tlo): if not isinstance(tlo, TLObject) and utils.is_list_like(tlo): From df3faaeb7f4f2694d9e68074fbb91314dad21af7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 11:11:59 +0200 Subject: [PATCH 03/15] Fix abstract Session method ordering --- telethon/sessions/abstract.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index ff0fd16d..ec4f649f 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -52,14 +52,6 @@ class Session(ABC): def auth_key(self, value): raise NotImplementedError - @property - def time_offset(self): - return self._time_offset - - @time_offset.setter - def time_offset(self, value): - self._time_offset = value - @property @abstractmethod def salt(self): @@ -127,6 +119,14 @@ class Session(ABC): def report_errors(self): return self._report_errors + @property + def time_offset(self): + return self._time_offset + + @time_offset.setter + def time_offset(self, value): + self._time_offset = value + @property def flood_sleep_threshold(self): return self._flood_sleep_threshold From d9a73744a49031d823fdd1da671476ddfc56c9d8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 12:36:39 +0200 Subject: [PATCH 04/15] Remove old sqlite session variables and clone code --- telethon/sessions/sqlite.py | 47 +++++++------------------------------ 1 file changed, 9 insertions(+), 38 deletions(-) diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index 66a0c887..64f3cbf6 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -32,7 +32,7 @@ class SQLiteSession(MemorySession): through an official Telegram client to revoke the authorization. """ - def __init__(self, session_id): + def __init__(self, session_id=None): super().__init__() """session_user_id should either be a string or another Session. Note that if another session is given, only parameters like @@ -40,51 +40,20 @@ class SQLiteSession(MemorySession): """ # These values will NOT be saved self.filename = ':memory:' + self.save_entities = True - # For connection purposes - if isinstance(session_id, Session): - self._device_model = session_id.device_model - self._system_version = session_id.system_version - self._app_version = session_id.app_version - self._lang_code = session_id.lang_code - self._system_lang_code = session_id.system_lang_code - self._report_errors = session_id.report_errors - self._flood_sleep_threshold = session_id.flood_sleep_threshold - if isinstance(session_id, SQLiteSession): - self.save_entities = session_id.save_entities - else: # str / None - if session_id: - self.filename = session_id - if not self.filename.endswith(EXTENSION): - self.filename += EXTENSION - - system = platform.uname() - self._device_model = system.system or 'Unknown' - self._system_version = system.release or '1.0' - self._app_version = '1.0' # '0' will provoke error - self._lang_code = 'en' - self._system_lang_code = self.lang_code - self._report_errors = True - self.save_entities = True - self._flood_sleep_threshold = 60 + if session_id: + self.filename = session_id + if not self.filename.endswith(EXTENSION): + self.filename += EXTENSION self.id = struct.unpack('q', os.urandom(8))[0] - self._sequence = 0 - self.time_offset = 0 - self._last_msg_id = 0 # Long - self.salt = 0 # Long # Cross-thread safety self._seq_no_lock = Lock() self._msg_id_lock = Lock() self._db_lock = RLock() - # These values will be saved - self._dc_id = 0 - self._server_address = None - self._port = None - self._auth_key = None - # Migrating from .json -> SQL entities = self._check_migrate_json() @@ -152,7 +121,9 @@ class SQLiteSession(MemorySession): self.save() def clone(self): - return SQLiteSession(self) + cloned = super().clone() + cloned.save_entities = self.save_entities + return cloned def _check_migrate_json(self): if file_exists(self.filename): From 118d9b10e869f04c818541b1cea41ccd801fd800 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 13:20:11 +0200 Subject: [PATCH 05/15] Add more abstraction --- telethon/sessions/abstract.py | 5 ++- telethon/sessions/memory.py | 69 +++++++++++++++++++---------------- telethon/sessions/sqlite.py | 23 ++++++------ 3 files changed, 53 insertions(+), 44 deletions(-) diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index ec4f649f..89f80b7a 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -18,8 +18,8 @@ class Session(ABC): self._report_errors = True self._flood_sleep_threshold = 60 - def clone(self): - cloned = self.__class__() + def clone(self, to_instance=None): + cloned = to_instance or self.__class__() cloned._device_model = self.device_model cloned._system_version = self.system_version cloned._app_version = self.app_version @@ -27,6 +27,7 @@ class Session(ABC): cloned._system_lang_code = self.system_lang_code cloned._report_errors = self.report_errors cloned._flood_sleep_threshold = self.flood_sleep_threshold + return cloned @abstractmethod def set_dc(self, dc_id, server_address, port): diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 09aa1fa0..92674fa6 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -111,8 +111,41 @@ class MemorySession(Session): def list_sessions(cls): raise NotImplementedError - @staticmethod - def _entities_to_rows(tlo): + def _entity_values_to_row(self, id, hash, username, phone, name): + return id, hash, username, phone, name + + def _entity_to_row(self, e): + if not isinstance(e, TLObject): + return + try: + p = utils.get_input_peer(e, allow_self=False) + marked_id = utils.get_peer_id(p) + except ValueError: + return + + if isinstance(p, (InputPeerUser, InputPeerChannel)): + if not p.access_hash: + # Some users and channels seem to be returned without + # an 'access_hash', meaning Telegram doesn't want you + # to access them. This is the reason behind ensuring + # that the 'access_hash' is non-zero. See issue #354. + # Note that this checks for zero or None, see #392. + return + else: + p_hash = p.access_hash + elif isinstance(p, InputPeerChat): + p_hash = 0 + else: + return + + username = getattr(e, 'username', None) or None + if username is not None: + username = username.lower() + phone = getattr(e, 'phone', None) + name = utils.get_display_name(e) or None + return self._entity_values_to_row(marked_id, p_hash, username, phone, name) + + def _entities_to_rows(self, tlo): if not isinstance(tlo, TLObject) and utils.is_list_like(tlo): # This may be a list of users already for instance entities = tlo @@ -127,35 +160,9 @@ class MemorySession(Session): rows = [] # Rows to add (id, hash, username, phone, name) for e in entities: - if not isinstance(e, TLObject): - continue - try: - p = utils.get_input_peer(e, allow_self=False) - marked_id = utils.get_peer_id(p) - except ValueError: - continue - - if isinstance(p, (InputPeerUser, InputPeerChannel)): - if not p.access_hash: - # Some users and channels seem to be returned without - # an 'access_hash', meaning Telegram doesn't want you - # to access them. This is the reason behind ensuring - # that the 'access_hash' is non-zero. See issue #354. - # Note that this checks for zero or None, see #392. - continue - else: - p_hash = p.access_hash - elif isinstance(p, InputPeerChat): - p_hash = 0 - else: - continue - - username = getattr(e, 'username', None) or None - if username is not None: - username = username.lower() - phone = getattr(e, 'phone', None) - name = utils.get_display_name(e) or None - rows.append((marked_id, p_hash, username, phone, name)) + row = self._entity_to_row(e) + if row: + rows.append(row) return rows def process_entities(self, tlo): diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index 64f3cbf6..0ea26ae5 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -120,8 +120,8 @@ class SQLiteSession(MemorySession): c.close() self.save() - def clone(self): - cloned = super().clone() + def clone(self, to_instance=None): + cloned = super().clone(to_instance) cloned.save_entities = self.save_entities return cloned @@ -180,9 +180,7 @@ class SQLiteSession(MemorySession): # Data from sessions should be kept as properties # not to fetch the database every time we need it def set_dc(self, dc_id, server_address, port): - self._dc_id = dc_id - self._server_address = server_address - self._port = port + super().set_dc(dc_id, server_address, port) self._update_session_table() # Fetch the auth_key corresponding to this data center @@ -287,16 +285,19 @@ class SQLiteSession(MemorySession): 'select id, hash from entities where phone=?', (phone,)) def get_entity_rows_by_username(self, username): - self._fetchone_entity('select id, hash from entities where username=?', - (username,)) + return self._fetchone_entity( + 'select id, hash from entities where username=?', + (username,)) def get_entity_rows_by_name(self, name): - self._fetchone_entity('select id, hash from entities where name=?', - (name,)) + return self._fetchone_entity( + 'select id, hash from entities where name=?', + (name,)) def get_entity_rows_by_id(self, id): - self._fetchone_entity('select id, hash from entities where id=?', - (id,)) + return self._fetchone_entity( + 'select id, hash from entities where id=?', + (id,)) # File processing From 07c2fc50ec3f841b9366777bbd0bb63164875c98 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 13:22:30 +0200 Subject: [PATCH 06/15] Add SQLAlchemy-based session --- optional-requirements.txt | 1 + setup.py | 3 +- telethon/sessions/sqlalchemy.py | 177 ++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 telethon/sessions/sqlalchemy.py diff --git a/optional-requirements.txt b/optional-requirements.txt index 55bfc014..fb83c1ab 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -1,3 +1,4 @@ cryptg pysocks hachoir3 +sqlalchemy diff --git a/setup.py b/setup.py index 00dd7446..0e052d31 100755 --- a/setup.py +++ b/setup.py @@ -151,7 +151,8 @@ def main(): ]), install_requires=['pyaes', 'rsa'], extras_require={ - 'cryptg': ['cryptg'] + 'cryptg': ['cryptg'], + 'sqlalchemy': ['sqlalchemy'] } ) diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py new file mode 100644 index 00000000..94f24a1e --- /dev/null +++ b/telethon/sessions/sqlalchemy.py @@ -0,0 +1,177 @@ +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, String, Integer, Blob, orm +import sqlalchemy as sql + +from ..tl.types import InputPhoto, InputDocument + +from .memory import MemorySession, _SentFileType + +Base = declarative_base() +LATEST_VERSION = 1 + + +class DBVersion(Base): + __tablename__ = "version" + version = Column(Integer, primary_key=True) + + +class DBSession(Base): + __tablename__ = "sessions" + + session_id = Column(String, primary_key=True) + dc_id = Column(Integer, primary_key=True) + server_address = Column(String) + port = Column(Integer) + auth_key = Column(Blob) + + +class DBEntity(Base): + __tablename__ = "entities" + + session_id = Column(String, primary_key=True) + id = Column(Integer, primary_key=True) + hash = Column(Integer, nullable=False) + username = Column(String) + phone = Column(Integer) + name = Column(String) + + +class DBSentFile(Base): + __tablename__ = "sent_files" + + session_id = Column(String, primary_key=True) + md5_digest = Column(Blob, primary_key=True) + file_size = Column(Integer, primary_key=True) + type = Column(Integer, primary_key=True) + id = Column(Integer) + hash = Column(Integer) + + +class AlchemySessionContainer: + def __init__(self, database): + if not isinstance(database, sql.Engine): + database = sql.create_engine(database) + + self.db_engine = database + db_factory = orm.sessionmaker(bind=self.db_engine) + self.db = orm.scoping.scoped_session(db_factory) + + if not self.db_engine.dialect.has_table(self.db_engine, + DBVersion.__tablename__): + Base.metadata.create_all(bind=self.db_engine) + self.db.add(DBVersion(version=LATEST_VERSION)) + self.db.commit() + else: + self.check_and_upgrade_database() + + DBVersion.query = self.db.query_property() + DBSession.query = self.db.query_property() + DBEntity.query = self.db.query_property() + DBSentFile.query = self.db.query_property() + + def check_and_upgrade_database(self): + row = DBVersion.query.get() + version = row.version if row else 1 + if version == LATEST_VERSION: + return + + DBVersion.query.delete() + + # Implement table schema updates here and increase version + + self.db.add(DBVersion(version=version)) + self.db.commit() + + def new_session(self, session_id): + return AlchemySession(self, session_id) + + def list_sessions(self): + return + + def save(self): + self.db.commit() + + +class AlchemySession(MemorySession): + def __init__(self, container, session_id): + super().__init__() + self.container = container + self.db = container.db + self.session_id = session_id + + def clone(self, to_instance=None): + cloned = to_instance or self.__class__(self.container, self.session_id) + return super().clone(cloned) + + def set_dc(self, dc_id, server_address, port): + super().set_dc(dc_id, server_address, port) + + def _update_session_table(self): + self.db.query(DBSession).filter( + DBSession.session_id == self.session_id).delete() + new = DBSession(session_id=self.session_id, dc_id=self._dc_id, + server_address=self._server_address, port=self._port, + auth_key=self._auth_key.key if self._auth_key else b'') + self.db.merge(new) + + def _db_query(self, dbclass, *args): + return self.db.query(dbclass).filter( + dbclass.session_id == self.session_id, + *args) + + def save(self): + self.container.save() + + def close(self): + # Nothing to do here, connection is managed by AlchemySessionContainer. + pass + + def delete(self): + self._db_query(DBSession).delete() + self._db_query(DBEntity).delete() + self._db_query(DBSentFile).delete() + + def _entity_values_to_row(self, id, hash, username, phone, name): + return DBEntity(session_id=self.session_id, id=id, hash=hash, + username=username, phone=phone, name=name) + + def process_entities(self, tlo): + rows = self._entities_to_rows(tlo) + if not rows: + return + + self.db.add_all(rows) + self.save() + + def get_entity_rows_by_phone(self, key): + row = self._db_query(DBEntity, DBEntity.phone == key).one_or_none() + return row.id, row.hash if row else None + + def get_entity_rows_by_username(self, key): + row = self._db_query(DBEntity, DBEntity.username == key).one_or_none() + return row.id, row.hash if row else None + + def get_entity_rows_by_name(self, key): + row = self._db_query(DBEntity, DBEntity.name == key).one_or_none() + return row.id, row.hash if row else None + + def get_entity_rows_by_id(self, key): + row = self._db_query(DBEntity, DBEntity.id == key).one_or_none() + return row.id, row.hash if row else None + + def get_file(self, md5_digest, file_size, cls): + row = self._db_query(DBSentFile, DBSentFile.md5_digest == md5_digest, + DBSentFile.file_size == file_size, + DBSentFile.type == _SentFileType.from_type( + cls).value).one_or_none() + return row.id, row.hash if row else None + + def cache_file(self, md5_digest, file_size, instance): + if not isinstance(instance, (InputDocument, InputPhoto)): + raise TypeError('Cannot cache %s instance' % type(instance)) + + self.db.merge( + DBSentFile(session_id=self.session_id, md5_digest=md5_digest, + type=_SentFileType.from_type(type(instance)).value, + id=instance.id, hash=instance.access_hash)) + self.save() From 03d4ab37657123e4ba3258ccf1867e645630b111 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 13:25:40 +0200 Subject: [PATCH 07/15] Fix create_engine check --- telethon/sessions/sqlalchemy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py index 94f24a1e..782f811f 100644 --- a/telethon/sessions/sqlalchemy.py +++ b/telethon/sessions/sqlalchemy.py @@ -49,7 +49,7 @@ class DBSentFile(Base): class AlchemySessionContainer: def __init__(self, database): - if not isinstance(database, sql.Engine): + if isinstance(database, str): database = sql.create_engine(database) self.db_engine = database From e1d7cc541f878a22a46dbd3f8cc2d0dd5115f16d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 18:23:12 +0200 Subject: [PATCH 08/15] Add setters for non-persistent values that apps might change --- telethon/sessions/abstract.py | 32 ++++++++++++++++++++++++++++++++ telethon/sessions/memory.py | 32 -------------------------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index 89f80b7a..dd1541ab 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -100,26 +100,50 @@ class Session(ABC): def device_model(self): return self._device_model + @device_model.setter + def device_model(self, value): + self._device_model = value + @property def system_version(self): return self._system_version + @system_version.setter + def system_version(self, value): + self._system_version = value + @property def app_version(self): return self._app_version + @app_version.setter + def app_version(self, value): + self._app_version = value + @property def lang_code(self): return self._lang_code + @lang_code.setter + def lang_code(self, value): + self._lang_code = value + @property def system_lang_code(self): return self._system_lang_code + @system_lang_code.setter + def system_lang_code(self, value): + self._system_lang_code = value + @property def report_errors(self): return self._report_errors + @report_errors.setter + def report_errors(self, value): + self._report_errors = value + @property def time_offset(self): return self._time_offset @@ -132,10 +156,18 @@ class Session(ABC): def flood_sleep_threshold(self): return self._flood_sleep_threshold + @flood_sleep_threshold.setter + def flood_sleep_threshold(self, value): + self._flood_sleep_threshold = value + @property def sequence(self): return self._sequence + @sequence.setter + def sequence(self, value): + self._sequence = value + def get_new_msg_id(self): """Generates a new unique message ID based on the current time (in ms) since epoch""" diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 92674fa6..71d6e551 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -66,38 +66,6 @@ class MemorySession(Session): def salt(self, value): self._salt = value - @property - def device_model(self): - return self._device_model - - @property - def system_version(self): - return self._system_version - - @property - def app_version(self): - return self._app_version - - @property - def lang_code(self): - return self._lang_code - - @property - def system_lang_code(self): - return self._system_lang_code - - @property - def report_errors(self): - return self._report_errors - - @property - def sequence(self): - return self._sequence - - @property - def flood_sleep_threshold(self): - return self._flood_sleep_threshold - def close(self): pass From dc2229fdba10f0140a7e159737091d82852cdef3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 18:39:04 +0200 Subject: [PATCH 09/15] Move salt and ID to base session and remove unused imports --- telethon/sessions/abstract.py | 23 +++++++++++++---------- telethon/sessions/memory.py | 10 +--------- telethon/sessions/sqlite.py | 12 +----------- 3 files changed, 15 insertions(+), 30 deletions(-) diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index dd1541ab..d92e0754 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -1,13 +1,18 @@ from abc import ABC, abstractmethod import time import platform +import struct +import os class Session(ABC): def __init__(self): + self.id = struct.unpack('q', os.urandom(8))[0] + self._sequence = 0 self._last_msg_id = 0 self._time_offset = 0 + self._salt = 0 system = platform.uname() self._device_model = system.system or 'Unknown' @@ -53,16 +58,6 @@ class Session(ABC): def auth_key(self, value): raise NotImplementedError - @property - @abstractmethod - def salt(self): - raise NotImplementedError - - @salt.setter - @abstractmethod - def salt(self, value): - raise NotImplementedError - @abstractmethod def close(self): raise NotImplementedError @@ -96,6 +91,14 @@ class Session(ABC): def get_file(self, md5_digest, file_size, cls): raise NotImplementedError + @property + def salt(self): + return self._salt + + @salt.setter + def salt(self, value): + self._salt = value + @property def device_model(self): return self._device_model diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 71d6e551..7ab31b21 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -28,10 +28,10 @@ class _SentFileType(Enum): class MemorySession(Session): def __init__(self): super().__init__() + self._dc_id = None self._server_address = None self._port = None - self._salt = None self._auth_key = None self._files = {} @@ -58,14 +58,6 @@ class MemorySession(Session): def auth_key(self, value): self._auth_key = value - @property - def salt(self): - return self._salt - - @salt.setter - def salt(self, value): - self._salt = value - def close(self): pass diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index 0ea26ae5..0423d88a 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -1,21 +1,13 @@ import json import os -import platform import sqlite3 -import struct -import time from base64 import b64decode from os.path import isfile as file_exists from threading import Lock, RLock -from .. import utils -from .abstract import Session from .memory import MemorySession, _SentFileType from ..crypto import AuthKey -from ..tl import TLObject from ..tl.types import ( - PeerUser, PeerChat, PeerChannel, - InputPeerUser, InputPeerChat, InputPeerChannel, InputPhoto, InputDocument ) @@ -47,8 +39,6 @@ class SQLiteSession(MemorySession): if not self.filename.endswith(EXTENSION): self.filename += EXTENSION - self.id = struct.unpack('q', os.urandom(8))[0] - # Cross-thread safety self._seq_no_lock = Lock() self._msg_id_lock = Lock() @@ -193,7 +183,7 @@ class SQLiteSession(MemorySession): self._auth_key = None c.close() - @Session.auth_key.setter + @MemorySession.auth_key.setter def auth_key(self, value): self._auth_key = value self._update_session_table() From c1a8896faa43dfa798fc9223283eae42dbf5afd6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 20:14:11 +0200 Subject: [PATCH 10/15] Fix SQLAlchemy implementation --- telethon/sessions/__init__.py | 1 + telethon/sessions/sqlalchemy.py | 197 ++++++++++++++++++-------------- 2 files changed, 114 insertions(+), 84 deletions(-) diff --git a/telethon/sessions/__init__.py b/telethon/sessions/__init__.py index af3423f3..a487a4bd 100644 --- a/telethon/sessions/__init__.py +++ b/telethon/sessions/__init__.py @@ -1,3 +1,4 @@ from .abstract import Session from .memory import MemorySession from .sqlite import SQLiteSession +from .sqlalchemy import AlchemySessionContainer, AlchemySession diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py index 782f811f..aa618e4c 100644 --- a/telethon/sessions/sqlalchemy.py +++ b/telethon/sessions/sqlalchemy.py @@ -1,85 +1,95 @@ from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, String, Integer, Blob, orm +from sqlalchemy import Column, String, Integer, BLOB, orm import sqlalchemy as sql +from ..crypto import AuthKey from ..tl.types import InputPhoto, InputDocument from .memory import MemorySession, _SentFileType -Base = declarative_base() LATEST_VERSION = 1 -class DBVersion(Base): - __tablename__ = "version" - version = Column(Integer, primary_key=True) - - -class DBSession(Base): - __tablename__ = "sessions" - - session_id = Column(String, primary_key=True) - dc_id = Column(Integer, primary_key=True) - server_address = Column(String) - port = Column(Integer) - auth_key = Column(Blob) - - -class DBEntity(Base): - __tablename__ = "entities" - - session_id = Column(String, primary_key=True) - id = Column(Integer, primary_key=True) - hash = Column(Integer, nullable=False) - username = Column(String) - phone = Column(Integer) - name = Column(String) - - -class DBSentFile(Base): - __tablename__ = "sent_files" - - session_id = Column(String, primary_key=True) - md5_digest = Column(Blob, primary_key=True) - file_size = Column(Integer, primary_key=True) - type = Column(Integer, primary_key=True) - id = Column(Integer) - hash = Column(Integer) - - class AlchemySessionContainer: - def __init__(self, database): - if isinstance(database, str): - database = sql.create_engine(database) + def __init__(self, engine=None, session=None, table_prefix="", + table_base=None, manage_tables=True): + if isinstance(engine, str): + engine = sql.create_engine(engine) - self.db_engine = database - db_factory = orm.sessionmaker(bind=self.db_engine) - self.db = orm.scoping.scoped_session(db_factory) - - if not self.db_engine.dialect.has_table(self.db_engine, - DBVersion.__tablename__): - Base.metadata.create_all(bind=self.db_engine) - self.db.add(DBVersion(version=LATEST_VERSION)) - self.db.commit() + self.db_engine = engine + if not session: + db_factory = orm.sessionmaker(bind=self.db_engine) + self.db = orm.scoping.scoped_session(db_factory) else: - self.check_and_upgrade_database() + self.db = session - DBVersion.query = self.db.query_property() - DBSession.query = self.db.query_property() - DBEntity.query = self.db.query_property() - DBSentFile.query = self.db.query_property() + table_base = table_base or declarative_base() + (self.Version, self.Session, self.Entity, + self.SentFile) = self.create_table_classes(self.db, table_prefix, + table_base) + + if manage_tables: + table_base.metadata.bind = self.db_engine + if not self.db_engine.dialect.has_table(self.db_engine, + self.Version.__tablename__): + table_base.metadata.create_all() + self.db.add(self.Version(version=LATEST_VERSION)) + self.db.commit() + else: + self.check_and_upgrade_database() + + @staticmethod + def create_table_classes(db, prefix, Base): + class Version(Base): + query = db.query_property() + __tablename__ = "{prefix}version".format(prefix=prefix) + version = Column(Integer, primary_key=True) + + class Session(Base): + query = db.query_property() + __tablename__ = "{prefix}sessions".format(prefix=prefix) + + session_id = Column(String, primary_key=True) + dc_id = Column(Integer, primary_key=True) + server_address = Column(String) + port = Column(Integer) + auth_key = Column(BLOB) + + class Entity(Base): + query = db.query_property() + __tablename__ = "{prefix}entities".format(prefix=prefix) + + session_id = Column(String, primary_key=True) + id = Column(Integer, primary_key=True) + hash = Column(Integer, nullable=False) + username = Column(String) + phone = Column(Integer) + name = Column(String) + + class SentFile(Base): + query = db.query_property() + __tablename__ = "{prefix}sent_files".format(prefix=prefix) + + session_id = Column(String, primary_key=True) + md5_digest = Column(BLOB, primary_key=True) + file_size = Column(Integer, primary_key=True) + type = Column(Integer, primary_key=True) + id = Column(Integer) + hash = Column(Integer) + + return Version, Session, Entity, SentFile def check_and_upgrade_database(self): - row = DBVersion.query.get() - version = row.version if row else 1 + row = self.Version.query.all() + version = row[0].version if row else 1 if version == LATEST_VERSION: return - DBVersion.query.delete() + self.Version.query.delete() # Implement table schema updates here and increase version - self.db.add(DBVersion(version=version)) + self.db.add(self.Version(version=version)) self.db.commit() def new_session(self, session_id): @@ -97,7 +107,20 @@ class AlchemySession(MemorySession): super().__init__() self.container = container self.db = container.db + self.Version, self.Session, self.Entity, self.SentFile = ( + container.Version, container.Session, container.Entity, + container.SentFile) self.session_id = session_id + self.load_session() + + def load_session(self): + sessions = self._db_query(self.Session).all() + session = sessions[0] if sessions else None + if session: + self._dc_id = session.dc_id + self._server_address = session.server_address + self._port = session.port + self._auth_key = AuthKey(data=session.auth_key) def clone(self, to_instance=None): cloned = to_instance or self.__class__(self.container, self.session_id) @@ -107,17 +130,18 @@ class AlchemySession(MemorySession): super().set_dc(dc_id, server_address, port) def _update_session_table(self): - self.db.query(DBSession).filter( - DBSession.session_id == self.session_id).delete() - new = DBSession(session_id=self.session_id, dc_id=self._dc_id, - server_address=self._server_address, port=self._port, - auth_key=self._auth_key.key if self._auth_key else b'') + self.Session.query.filter( + self.Session.session_id == self.session_id).delete() + new = self.Session(session_id=self.session_id, dc_id=self._dc_id, + server_address=self._server_address, + port=self._port, + auth_key=(self._auth_key.key + if self._auth_key else b'')) self.db.merge(new) def _db_query(self, dbclass, *args): - return self.db.query(dbclass).filter( - dbclass.session_id == self.session_id, - *args) + return dbclass.query.filter(dbclass.session_id == self.session_id, + *args) def save(self): self.container.save() @@ -127,42 +151,47 @@ class AlchemySession(MemorySession): pass def delete(self): - self._db_query(DBSession).delete() - self._db_query(DBEntity).delete() - self._db_query(DBSentFile).delete() + self._db_query(self.Session).delete() + self._db_query(self.Entity).delete() + self._db_query(self.SentFile).delete() def _entity_values_to_row(self, id, hash, username, phone, name): - return DBEntity(session_id=self.session_id, id=id, hash=hash, - username=username, phone=phone, name=name) + return self.Entity(session_id=self.session_id, id=id, hash=hash, + username=username, phone=phone, name=name) def process_entities(self, tlo): rows = self._entities_to_rows(tlo) if not rows: return - self.db.add_all(rows) + for row in rows: + self.db.merge(row) self.save() def get_entity_rows_by_phone(self, key): - row = self._db_query(DBEntity, DBEntity.phone == key).one_or_none() + row = self._db_query(self.Entity, + self.Entity.phone == key).one_or_none() return row.id, row.hash if row else None def get_entity_rows_by_username(self, key): - row = self._db_query(DBEntity, DBEntity.username == key).one_or_none() + row = self._db_query(self.Entity, + self.Entity.username == key).one_or_none() return row.id, row.hash if row else None def get_entity_rows_by_name(self, key): - row = self._db_query(DBEntity, DBEntity.name == key).one_or_none() + row = self._db_query(self.Entity, + self.Entity.name == key).one_or_none() return row.id, row.hash if row else None def get_entity_rows_by_id(self, key): - row = self._db_query(DBEntity, DBEntity.id == key).one_or_none() + row = self._db_query(self.Entity, self.Entity.id == key).one_or_none() return row.id, row.hash if row else None def get_file(self, md5_digest, file_size, cls): - row = self._db_query(DBSentFile, DBSentFile.md5_digest == md5_digest, - DBSentFile.file_size == file_size, - DBSentFile.type == _SentFileType.from_type( + row = self._db_query(self.SentFile, + self.SentFile.md5_digest == md5_digest, + self.SentFile.file_size == file_size, + self.SentFile.type == _SentFileType.from_type( cls).value).one_or_none() return row.id, row.hash if row else None @@ -171,7 +200,7 @@ class AlchemySession(MemorySession): raise TypeError('Cannot cache %s instance' % type(instance)) self.db.merge( - DBSentFile(session_id=self.session_id, md5_digest=md5_digest, - type=_SentFileType.from_type(type(instance)).value, - id=instance.id, hash=instance.access_hash)) + self.SentFile(session_id=self.session_id, md5_digest=md5_digest, + type=_SentFileType.from_type(type(instance)).value, + id=instance.id, hash=instance.access_hash)) self.save() From f805914c80e034058789ec6d7fef43dd4762e051 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 20:40:03 +0200 Subject: [PATCH 11/15] Handle SQLAlchemy import errors --- telethon/sessions/sqlalchemy.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py index aa618e4c..0b028c02 100644 --- a/telethon/sessions/sqlalchemy.py +++ b/telethon/sessions/sqlalchemy.py @@ -1,6 +1,10 @@ -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, String, Integer, BLOB, orm -import sqlalchemy as sql +try: + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import Column, String, Integer, BLOB, orm + import sqlalchemy as sql +except ImportError: + sql = None + pass from ..crypto import AuthKey from ..tl.types import InputPhoto, InputDocument @@ -13,6 +17,8 @@ LATEST_VERSION = 1 class AlchemySessionContainer: def __init__(self, engine=None, session=None, table_prefix="", table_base=None, manage_tables=True): + if not sql: + raise ImportError("SQLAlchemy not imported") if isinstance(engine, str): engine = sql.create_engine(engine) From 5e88b21aa9ce5bd51ecd7fcd3627797c58d996ee Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 20:42:51 +0200 Subject: [PATCH 12/15] Use single quotes --- telethon/sessions/sqlalchemy.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py index 0b028c02..0fd76fe3 100644 --- a/telethon/sessions/sqlalchemy.py +++ b/telethon/sessions/sqlalchemy.py @@ -15,10 +15,10 @@ LATEST_VERSION = 1 class AlchemySessionContainer: - def __init__(self, engine=None, session=None, table_prefix="", + def __init__(self, engine=None, session=None, table_prefix='', table_base=None, manage_tables=True): if not sql: - raise ImportError("SQLAlchemy not imported") + raise ImportError('SQLAlchemy not imported') if isinstance(engine, str): engine = sql.create_engine(engine) @@ -48,12 +48,12 @@ class AlchemySessionContainer: def create_table_classes(db, prefix, Base): class Version(Base): query = db.query_property() - __tablename__ = "{prefix}version".format(prefix=prefix) + __tablename__ = '{prefix}version'.format(prefix=prefix) version = Column(Integer, primary_key=True) class Session(Base): query = db.query_property() - __tablename__ = "{prefix}sessions".format(prefix=prefix) + __tablename__ = '{prefix}sessions'.format(prefix=prefix) session_id = Column(String, primary_key=True) dc_id = Column(Integer, primary_key=True) @@ -63,7 +63,7 @@ class AlchemySessionContainer: class Entity(Base): query = db.query_property() - __tablename__ = "{prefix}entities".format(prefix=prefix) + __tablename__ = '{prefix}entities'.format(prefix=prefix) session_id = Column(String, primary_key=True) id = Column(Integer, primary_key=True) @@ -74,7 +74,7 @@ class AlchemySessionContainer: class SentFile(Base): query = db.query_property() - __tablename__ = "{prefix}sent_files".format(prefix=prefix) + __tablename__ = '{prefix}sent_files'.format(prefix=prefix) session_id = Column(String, primary_key=True) md5_digest = Column(BLOB, primary_key=True) From 47cdcda9e2b8c17d3f4e842509a67fcf76b2548c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 21:05:09 +0200 Subject: [PATCH 13/15] Move device info out of Session --- telethon/sessions/abstract.py | 53 -------------------------------- telethon/telegram_bare_client.py | 30 ++++++++++-------- 2 files changed, 18 insertions(+), 65 deletions(-) diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index d92e0754..647a87c1 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod import time -import platform import struct import os @@ -13,23 +12,11 @@ class Session(ABC): self._last_msg_id = 0 self._time_offset = 0 self._salt = 0 - - system = platform.uname() - self._device_model = system.system or 'Unknown' - self._system_version = system.release or '1.0' - self._app_version = '1.0' - self._lang_code = 'en' - self._system_lang_code = self.lang_code self._report_errors = True self._flood_sleep_threshold = 60 def clone(self, to_instance=None): cloned = to_instance or self.__class__() - cloned._device_model = self.device_model - cloned._system_version = self.system_version - cloned._app_version = self.app_version - cloned._lang_code = self.lang_code - cloned._system_lang_code = self.system_lang_code cloned._report_errors = self.report_errors cloned._flood_sleep_threshold = self.flood_sleep_threshold return cloned @@ -99,46 +86,6 @@ class Session(ABC): def salt(self, value): self._salt = value - @property - def device_model(self): - return self._device_model - - @device_model.setter - def device_model(self, value): - self._device_model = value - - @property - def system_version(self): - return self._system_version - - @system_version.setter - def system_version(self, value): - self._system_version = value - - @property - def app_version(self): - return self._app_version - - @app_version.setter - def app_version(self, value): - self._app_version = value - - @property - def lang_code(self): - return self._lang_code - - @lang_code.setter - def lang_code(self, value): - self._lang_code = value - - @property - def system_lang_code(self): - return self._system_lang_code - - @system_lang_code.setter - def system_lang_code(self, value): - self._system_lang_code = value - @property def report_errors(self): return self._report_errors diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 3a5b2bd0..bf33a7dc 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -1,11 +1,11 @@ import logging import os +import platform import threading from datetime import timedelta, datetime from signal import signal, SIGINT, SIGTERM, SIGABRT from threading import Lock from time import sleep - from . import version, utils from .crypto import rsa from .errors import ( @@ -73,7 +73,12 @@ class TelegramBareClient: update_workers=None, spawn_read_thread=False, timeout=timedelta(seconds=5), - **kwargs): + loop=None, + device_model=None, + system_version=None, + app_version=None, + lang_code='en', + system_lang_code='en'): """Refer to TelegramClient.__init__ for docs on this method""" if not api_id or not api_hash: raise ValueError( @@ -125,11 +130,12 @@ class TelegramBareClient: self.updates = UpdateState(workers=update_workers) # Used on connection - the user may modify these and reconnect - kwargs['app_version'] = kwargs.get('app_version', self.__version__) - for name, value in kwargs.items(): - if not hasattr(self.session, name): - raise ValueError('Unknown named parameter', name) - setattr(self.session, name, value) + system = platform.uname() + self.device_model = device_model or system.system or 'Unknown' + self.system_version = system_version or system.release or '1.0' + self.app_version = app_version or self.__version__ + self.lang_code = lang_code + self.system_lang_code = system_lang_code # Despite the state of the real connection, keep track of whether # the user has explicitly called .connect() or .disconnect() here. @@ -233,11 +239,11 @@ class TelegramBareClient: """Wraps query around InvokeWithLayerRequest(InitConnectionRequest())""" return InvokeWithLayerRequest(LAYER, InitConnectionRequest( api_id=self.api_id, - device_model=self.session.device_model, - system_version=self.session.system_version, - app_version=self.session.app_version, - lang_code=self.session.lang_code, - system_lang_code=self.session.system_lang_code, + device_model=self.device_model, + system_version=self.system_version, + app_version=self.app_version, + lang_code=self.lang_code, + system_lang_code=self.system_lang_code, lang_pack='', # "langPacks are for official apps only" query=query )) From 290afd85fc5f91688e4e18355ac207e8a2c2f1ec Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 21:58:16 +0200 Subject: [PATCH 14/15] Fix AlchemySession session table updating --- telethon/sessions/sqlalchemy.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py index 0fd76fe3..933f44c2 100644 --- a/telethon/sessions/sqlalchemy.py +++ b/telethon/sessions/sqlalchemy.py @@ -117,9 +117,9 @@ class AlchemySession(MemorySession): container.Version, container.Session, container.Entity, container.SentFile) self.session_id = session_id - self.load_session() + self._load_session() - def load_session(self): + def _load_session(self): sessions = self._db_query(self.Session).all() session = sessions[0] if sessions else None if session: @@ -134,6 +134,19 @@ class AlchemySession(MemorySession): def set_dc(self, dc_id, server_address, port): super().set_dc(dc_id, server_address, port) + self._update_session_table() + + sessions = self._db_query(self.Session).all() + session = sessions[0] if sessions else None + if session and session.auth_key: + self._auth_key = AuthKey(data=session.auth_key) + else: + self._auth_key = None + + @MemorySession.auth_key.setter + def auth_key(self, value): + self._auth_key = value + self._update_session_table() def _update_session_table(self): self.Session.query.filter( From 9bf5cb7ed8d15021bd73e5051a7dd01e7af8a33b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Mar 2018 12:28:18 +0200 Subject: [PATCH 15/15] Add new sessions docs --- readthedocs/extra/advanced-usage/sessions.rst | 96 +++++++++++++++---- 1 file changed, 78 insertions(+), 18 deletions(-) diff --git a/readthedocs/extra/advanced-usage/sessions.rst b/readthedocs/extra/advanced-usage/sessions.rst index fca7828e..ad824837 100644 --- a/readthedocs/extra/advanced-usage/sessions.rst +++ b/readthedocs/extra/advanced-usage/sessions.rst @@ -25,29 +25,89 @@ file, so that you can quickly access them by username or phone number. If you're not going to work with updates, or don't need to cache the ``access_hash`` associated with the entities' ID, you can disable this -by setting ``client.session.save_entities = False``, or pass it as a -parameter to the ``TelegramClient``. +by setting ``client.session.save_entities = False``. -If you don't want to save the files as a database, you can also create -your custom ``Session`` subclass and override the ``.save()`` and ``.load()`` -methods. For example, you could save it on a database: +Custom Session Storage +---------------------- + +If you don't want to use the default SQLite session storage, you can also use +one of the other implementations or implement your own storage. + +To use a custom session storage, simply pass the custom session instance to +``TelegramClient`` instead of the session name. + +Currently, there are three implementations of the abstract ``Session`` class: +* MemorySession. Stores session data in Python variables. +* SQLiteSession, the default. Stores sessions in their own SQLite databases. +* AlchemySession. Stores all sessions in a single database via SQLAlchemy. + +Using AlchemySession +~~~~~~~~~~~~~~~~~~~~ +The AlchemySession implementation can store multiple Sessions in the same +database, but to do this, each session instance needs to have access to the +same models and database session. + +To get started, you need to create an ``AlchemySessionContainer`` which will +contain that shared data. The simplest way to use ``AlchemySessionContainer`` +is to simply pass it the database URL: .. code-block:: python - class DatabaseSession(Session): - def save(): - # serialize relevant data to the database + container = AlchemySessionContainer('mysql://user:pass@localhost/telethon') - def load(): - # load relevant data to the database +If you already have SQLAlchemy set up for your own project, you can also pass +the engine separately: + + .. code-block:: python + + my_sqlalchemy_engine = sqlalchemy.create_engine('...') + container = AlchemySessionContainer(engine=my_sqlalchemy_engine) + +By default, the session container will manage table creation/schema updates/etc +automatically. If you want to manage everything yourself, you can pass your +SQLAlchemy Session and ``declarative_base`` instances and set ``manage_tables`` +to ``False``: + + .. code-block:: python + + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import orm + import sqlalchemy + + ... + + session_factory = orm.sessionmaker(bind=my_sqlalchemy_engine) + session = session_factory() + my_base = declarative_base() + + ... + + container = AlchemySessionContainer(session=session, table_base=my_base, manage_tables=False) + +You always need to provide either ``engine`` or ``session`` to the container. +If you set ``manage_tables=False`` and provide a ``session``, ``engine`` is not +needed. In any other case, ``engine`` is always required. + +After you have your ``AlchemySessionContainer`` instance created, you can +create new sessions by calling ``new_session``: + + .. code-block:: python + + session = container.new_session('some session id') + client = TelegramClient(session) + +where ``some session id`` is an unique identifier for the session. + +Creating your own storage +~~~~~~~~~~~~~~~~~~~~~~~~~ + +The easiest way to create your own implementation is to use MemorySession as +the base and check out how ``SQLiteSession`` or ``AlchemySession`` work. You +can find the relevant Python files under the ``sessions`` directory. -You should read the ````session.py```` source file to know what "relevant -data" you need to keep track of. - - -Sessions and Heroku -------------------- +SQLite Sessions and Heroku +-------------------------- You probably have a newer version of SQLite installed (>= 3.8.2). Heroku uses SQLite 3.7.9 which does not support ``WITHOUT ROWID``. So, if you generated @@ -59,8 +119,8 @@ session file on your Heroku dyno itself. The most complicated is creating a custom buildpack to install SQLite >= 3.8.2. -Generating a Session File on a Heroku Dyno -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Generating a SQLite Session File on a Heroku Dyno +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. note:: Due to Heroku's ephemeral filesystem all dynamically generated