Use sqlite3 instead JSON for the session files

This commit is contained in:
Lonami Exo 2017-12-26 16:45:47 +01:00
parent b11c2e885b
commit 664417b409
2 changed files with 162 additions and 72 deletions

View File

@ -92,7 +92,7 @@ class TelegramBareClient:
# Determine what session object we have # Determine what session object we have
if isinstance(session, str) or session is None: if isinstance(session, str) or session is None:
session = Session.try_load_or_create_new(session) session = Session(session)
elif not isinstance(session, Session): elif not isinstance(session, Session):
raise ValueError( raise ValueError(
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'

View File

@ -1,15 +1,19 @@
import json import json
import os import os
import platform import platform
import sqlite3
import struct import struct
import time import time
from base64 import b64encode, b64decode from base64 import b64decode
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock from threading import Lock
from .entity_database import EntityDatabase from .entity_database import EntityDatabase
from .. import helpers from .. import helpers
EXTENSION = '.session'
CURRENT_VERSION = 1 # database version
class Session: class Session:
"""This session contains the required information to login into your """This session contains the required information to login into your
@ -25,6 +29,7 @@ class Session:
those required to init a connection will be copied. those required to init a connection will be copied.
""" """
# These values will NOT be saved # These values will NOT be saved
self.filename = ':memory:'
if isinstance(session_user_id, Session): if isinstance(session_user_id, Session):
self.session_user_id = None self.session_user_id = None
@ -41,7 +46,10 @@ class Session:
self.flood_sleep_threshold = session.flood_sleep_threshold self.flood_sleep_threshold = session.flood_sleep_threshold
else: # str / None else: # str / None
self.session_user_id = session_user_id if session_user_id:
self.filename = session_user_id
if not self.filename.endswith(EXTENSION):
self.filename += EXTENSION
system = platform.uname() system = platform.uname()
self.device_model = system.system if system.system else 'Unknown' self.device_model = system.system if system.system else 'Unknown'
@ -54,49 +62,172 @@ class Session:
self.save_entities = True self.save_entities = True
self.flood_sleep_threshold = 60 self.flood_sleep_threshold = 60
# These values will be saved
self._server_address = None
self._port = None
self._auth_key = None
self._layer = 0
self._salt = 0 # Signed long
self.entities = EntityDatabase() # Known and cached entities
# Cross-thread safety # Cross-thread safety
self._seq_no_lock = Lock() self._seq_no_lock = Lock()
self._msg_id_lock = Lock() self._msg_id_lock = Lock()
self._save_lock = Lock() self._db_lock = Lock()
# Migrating from .json -> SQL
self._check_migrate_json()
self._conn = sqlite3.connect(self.filename, check_same_thread=False)
c = self._conn.cursor()
c.execute("select name from sqlite_master "
"where type='table' and name='version'")
if c.fetchone():
# Tables already exist, check for the version
c.execute("select version from version")
version = c.fetchone()[0]
if version != CURRENT_VERSION:
self._upgrade_database(old=version)
self.save()
# These values will be saved
c.execute('select * from sessions')
self._server_address, self._port, key, \
self._layer, self._salt = c.fetchone()
from ..crypto import AuthKey
self._auth_key = AuthKey(data=key)
c.close()
else:
# Tables don't exist, create new ones
c.execute("create table version (version integer)")
c.execute(
"""create table sessions (
server_address text,
port integer,
auth_key blob,
layer integer,
salt integer
)"""
)
c.execute(
"""create table entities (
id integer,
hash integer,
username text,
phone integer,
name text
)"""
)
c.execute("insert into version values (1)")
c.close()
self.save()
self.id = helpers.generate_random_long(signed=True) self.id = helpers.generate_random_long(signed=True)
self._sequence = 0 self._sequence = 0
self.time_offset = 0 self.time_offset = 0
self._last_msg_id = 0 # Long self._last_msg_id = 0 # Long
# These values will be saved def _check_migrate_json(self):
self.server_address = None if file_exists(self.filename):
self.port = None try:
self.auth_key = None with open(self.filename, encoding='utf-8') as f:
self.layer = 0 data = json.load(f)
self.salt = 0 # Signed long self._port = data.get('port', self._port)
self.entities = EntityDatabase() # Known and cached entities self._salt = data.get('salt', self._salt)
# Keep while migrating from unsigned to signed salt
if self._salt > 0:
self._salt = struct.unpack(
'q', struct.pack('Q', self._salt))[0]
self._layer = data.get('layer', self._layer)
self._server_address = \
data.get('server_address', self._server_address)
from ..crypto import AuthKey
if data.get('auth_key_data', None) is not None:
key = b64decode(data['auth_key_data'])
self._auth_key = AuthKey(data=key)
self.entities = EntityDatabase(data.get('entities', []))
self.delete() # Delete JSON file to create database
except (UnicodeDecodeError, json.decoder.JSONDecodeError):
pass
def _upgrade_database(self, old):
pass
# Data from sessions should be kept as properties
# not to fetch the database every time we need it
@property
def server_address(self):
return self._server_address
@server_address.setter
def server_address(self, value):
self._server_address = value
self._update_session_table()
@property
def port(self):
return self._port
@port.setter
def port(self, value):
self._port = value
self._update_session_table()
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
@property
def layer(self):
return self._layer
@layer.setter
def layer(self, value):
self._layer = value
self._update_session_table()
@property
def salt(self):
return self._salt
@salt.setter
def salt(self, value):
self._salt = value
self._update_session_table()
def _update_session_table(self):
with self._db_lock:
c = self._conn.cursor()
c.execute('delete from sessions')
c.execute('insert into sessions values (?,?,?,?,?)', (
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b'',
self._layer,
self._salt
))
c.close()
def save(self): def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
if not self.session_user_id or self._save_lock.locked(): with self._db_lock:
return self._conn.commit()
with self._save_lock:
with open('{}.session'.format(self.session_user_id), 'w') as file:
out_dict = {
'port': self.port,
'salt': self.salt,
'layer': self.layer,
'server_address': self.server_address,
'auth_key_data':
b64encode(self.auth_key.key).decode('ascii')
if self.auth_key else None
}
if self.save_entities:
out_dict['entities'] = self.entities.get_input_list()
json.dump(out_dict, file)
def delete(self): def delete(self):
"""Deletes the current session file""" """Deletes the current session file"""
if self.filename == ':memory:':
return True
try: try:
os.remove('{}.session'.format(self.session_user_id)) os.remove(self.filename)
return True return True
except OSError: except OSError:
return False return False
@ -107,48 +238,7 @@ class Session:
using this client and never logged out using this client and never logged out
""" """
return [os.path.splitext(os.path.basename(f))[0] return [os.path.splitext(os.path.basename(f))[0]
for f in os.listdir('.') if f.endswith('.session')] for f in os.listdir('.') if f.endswith(EXTENSION)]
@staticmethod
def try_load_or_create_new(session_user_id):
"""Loads a saved session_user_id.session or creates a new one.
If session_user_id=None, later .save()'s will have no effect.
"""
if session_user_id is None:
return Session(None)
else:
path = '{}.session'.format(session_user_id)
result = Session(session_user_id)
if not file_exists(path):
return result
try:
with open(path, 'r') as file:
data = json.load(file)
result.port = data.get('port', result.port)
result.salt = data.get('salt', result.salt)
# Keep while migrating from unsigned to signed salt
if result.salt > 0:
result.salt = struct.unpack(
'q', struct.pack('Q', result.salt))[0]
result.layer = data.get('layer', result.layer)
result.server_address = \
data.get('server_address', result.server_address)
# FIXME We need to import the AuthKey here or otherwise
# we get cyclic dependencies.
from ..crypto import AuthKey
if data.get('auth_key_data', None) is not None:
key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key)
result.entities = EntityDatabase(data.get('entities', []))
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
pass
return result
def generate_sequence(self, content_related): def generate_sequence(self, content_related):
"""Thread safe method to generates the next sequence number, """Thread safe method to generates the next sequence number,