diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 6c258c9a..d4f19b8d 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -92,7 +92,7 @@ class TelegramBareClient: # Determine what session object we have if isinstance(session, str) or session is None: - session = Session.try_load_or_create_new(session) + session = Session(session) elif not isinstance(session, Session): raise ValueError( 'The given session must be a str or a Session instance.' diff --git a/telethon/tl/session.py b/telethon/tl/session.py index e530cc83..e9885a56 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -1,15 +1,19 @@ import json import os import platform +import sqlite3 import struct import time -from base64 import b64encode, b64decode +from base64 import b64decode from os.path import isfile as file_exists from threading import Lock from .entity_database import EntityDatabase from .. import helpers +EXTENSION = '.session' +CURRENT_VERSION = 1 # database version + class Session: """This session contains the required information to login into your @@ -25,6 +29,7 @@ class Session: those required to init a connection will be copied. """ # These values will NOT be saved + self.filename = ':memory:' if isinstance(session_user_id, Session): self.session_user_id = None @@ -41,7 +46,10 @@ class Session: self.flood_sleep_threshold = session.flood_sleep_threshold else: # str / None - self.session_user_id = session_user_id + if session_user_id: + self.filename = session_user_id + if not self.filename.endswith(EXTENSION): + self.filename += EXTENSION system = platform.uname() self.device_model = system.system if system.system else 'Unknown' @@ -54,49 +62,172 @@ class Session: self.save_entities = True self.flood_sleep_threshold = 60 + # These values will be saved + self._server_address = None + self._port = None + self._auth_key = None + self._layer = 0 + self._salt = 0 # Signed long + self.entities = EntityDatabase() # Known and cached entities + # Cross-thread safety self._seq_no_lock = Lock() self._msg_id_lock = Lock() - self._save_lock = Lock() + self._db_lock = Lock() + + # Migrating from .json -> SQL + self._check_migrate_json() + + self._conn = sqlite3.connect(self.filename, check_same_thread=False) + c = self._conn.cursor() + c.execute("select name from sqlite_master " + "where type='table' and name='version'") + if c.fetchone(): + # Tables already exist, check for the version + c.execute("select version from version") + version = c.fetchone()[0] + if version != CURRENT_VERSION: + self._upgrade_database(old=version) + self.save() + + # These values will be saved + c.execute('select * from sessions') + self._server_address, self._port, key, \ + self._layer, self._salt = c.fetchone() + + from ..crypto import AuthKey + self._auth_key = AuthKey(data=key) + c.close() + else: + # Tables don't exist, create new ones + c.execute("create table version (version integer)") + c.execute( + """create table sessions ( + server_address text, + port integer, + auth_key blob, + layer integer, + salt integer + )""" + ) + c.execute( + """create table entities ( + id integer, + hash integer, + username text, + phone integer, + name text + )""" + ) + c.execute("insert into version values (1)") + c.close() + self.save() self.id = helpers.generate_random_long(signed=True) self._sequence = 0 self.time_offset = 0 self._last_msg_id = 0 # Long - # These values will be saved - self.server_address = None - self.port = None - self.auth_key = None - self.layer = 0 - self.salt = 0 # Signed long - self.entities = EntityDatabase() # Known and cached entities + def _check_migrate_json(self): + if file_exists(self.filename): + try: + with open(self.filename, encoding='utf-8') as f: + data = json.load(f) + self._port = data.get('port', self._port) + self._salt = data.get('salt', self._salt) + # Keep while migrating from unsigned to signed salt + if self._salt > 0: + self._salt = struct.unpack( + 'q', struct.pack('Q', self._salt))[0] + + self._layer = data.get('layer', self._layer) + self._server_address = \ + data.get('server_address', self._server_address) + + from ..crypto import AuthKey + if data.get('auth_key_data', None) is not None: + key = b64decode(data['auth_key_data']) + self._auth_key = AuthKey(data=key) + + self.entities = EntityDatabase(data.get('entities', [])) + self.delete() # Delete JSON file to create database + except (UnicodeDecodeError, json.decoder.JSONDecodeError): + pass + + def _upgrade_database(self, old): + pass + + # Data from sessions should be kept as properties + # not to fetch the database every time we need it + @property + def server_address(self): + return self._server_address + + @server_address.setter + def server_address(self, value): + self._server_address = value + self._update_session_table() + + @property + def port(self): + return self._port + + @port.setter + def port(self, value): + self._port = value + self._update_session_table() + + @property + def auth_key(self): + return self._auth_key + + @auth_key.setter + def auth_key(self, value): + self._auth_key = value + self._update_session_table() + + @property + def layer(self): + return self._layer + + @layer.setter + def layer(self, value): + self._layer = value + self._update_session_table() + + @property + def salt(self): + return self._salt + + @salt.setter + def salt(self, value): + self._salt = value + self._update_session_table() + + def _update_session_table(self): + with self._db_lock: + c = self._conn.cursor() + c.execute('delete from sessions') + c.execute('insert into sessions values (?,?,?,?,?)', ( + self._server_address, + self._port, + self._auth_key.key if self._auth_key else b'', + self._layer, + self._salt + )) + c.close() def save(self): """Saves the current session object as session_user_id.session""" - if not self.session_user_id or self._save_lock.locked(): - return - - with self._save_lock: - with open('{}.session'.format(self.session_user_id), 'w') as file: - out_dict = { - 'port': self.port, - 'salt': self.salt, - 'layer': self.layer, - 'server_address': self.server_address, - 'auth_key_data': - b64encode(self.auth_key.key).decode('ascii') - if self.auth_key else None - } - if self.save_entities: - out_dict['entities'] = self.entities.get_input_list() - - json.dump(out_dict, file) + with self._db_lock: + self._conn.commit() def delete(self): """Deletes the current session file""" + if self.filename == ':memory:': + return True try: - os.remove('{}.session'.format(self.session_user_id)) + os.remove(self.filename) return True except OSError: return False @@ -107,48 +238,7 @@ class Session: using this client and never logged out """ return [os.path.splitext(os.path.basename(f))[0] - for f in os.listdir('.') if f.endswith('.session')] - - @staticmethod - def try_load_or_create_new(session_user_id): - """Loads a saved session_user_id.session or creates a new one. - If session_user_id=None, later .save()'s will have no effect. - """ - if session_user_id is None: - return Session(None) - else: - path = '{}.session'.format(session_user_id) - result = Session(session_user_id) - if not file_exists(path): - return result - - try: - with open(path, 'r') as file: - data = json.load(file) - result.port = data.get('port', result.port) - result.salt = data.get('salt', result.salt) - # Keep while migrating from unsigned to signed salt - if result.salt > 0: - result.salt = struct.unpack( - 'q', struct.pack('Q', result.salt))[0] - - result.layer = data.get('layer', result.layer) - result.server_address = \ - data.get('server_address', result.server_address) - - # FIXME We need to import the AuthKey here or otherwise - # we get cyclic dependencies. - from ..crypto import AuthKey - if data.get('auth_key_data', None) is not None: - key = b64decode(data['auth_key_data']) - result.auth_key = AuthKey(data=key) - - result.entities = EntityDatabase(data.get('entities', [])) - - except (json.decoder.JSONDecodeError, UnicodeDecodeError): - pass - - return result + for f in os.listdir('.') if f.endswith(EXTENSION)] def generate_sequence(self, content_related): """Thread safe method to generates the next sequence number,