Use sqlite3 instead JSON for the session files

This commit is contained in:
Lonami Exo 2017-12-26 16:45:47 +01:00
parent b11c2e885b
commit 664417b409
2 changed files with 162 additions and 72 deletions

View File

@ -92,7 +92,7 @@ class TelegramBareClient:
# Determine what session object we have
if isinstance(session, str) or session is None:
session = Session.try_load_or_create_new(session)
session = Session(session)
elif not isinstance(session, Session):
raise ValueError(
'The given session must be a str or a Session instance.'

View File

@ -1,15 +1,19 @@
import json
import os
import platform
import sqlite3
import struct
import time
from base64 import b64encode, b64decode
from base64 import b64decode
from os.path import isfile as file_exists
from threading import Lock
from .entity_database import EntityDatabase
from .. import helpers
EXTENSION = '.session'
CURRENT_VERSION = 1 # database version
class Session:
"""This session contains the required information to login into your
@ -25,6 +29,7 @@ class Session:
those required to init a connection will be copied.
"""
# These values will NOT be saved
self.filename = ':memory:'
if isinstance(session_user_id, Session):
self.session_user_id = None
@ -41,7 +46,10 @@ class Session:
self.flood_sleep_threshold = session.flood_sleep_threshold
else: # str / None
self.session_user_id = session_user_id
if session_user_id:
self.filename = session_user_id
if not self.filename.endswith(EXTENSION):
self.filename += EXTENSION
system = platform.uname()
self.device_model = system.system if system.system else 'Unknown'
@ -54,49 +62,172 @@ class Session:
self.save_entities = True
self.flood_sleep_threshold = 60
# These values will be saved
self._server_address = None
self._port = None
self._auth_key = None
self._layer = 0
self._salt = 0 # Signed long
self.entities = EntityDatabase() # Known and cached entities
# Cross-thread safety
self._seq_no_lock = Lock()
self._msg_id_lock = Lock()
self._save_lock = Lock()
self._db_lock = Lock()
# Migrating from .json -> SQL
self._check_migrate_json()
self._conn = sqlite3.connect(self.filename, check_same_thread=False)
c = self._conn.cursor()
c.execute("select name from sqlite_master "
"where type='table' and name='version'")
if c.fetchone():
# Tables already exist, check for the version
c.execute("select version from version")
version = c.fetchone()[0]
if version != CURRENT_VERSION:
self._upgrade_database(old=version)
self.save()
# These values will be saved
c.execute('select * from sessions')
self._server_address, self._port, key, \
self._layer, self._salt = c.fetchone()
from ..crypto import AuthKey
self._auth_key = AuthKey(data=key)
c.close()
else:
# Tables don't exist, create new ones
c.execute("create table version (version integer)")
c.execute(
"""create table sessions (
server_address text,
port integer,
auth_key blob,
layer integer,
salt integer
)"""
)
c.execute(
"""create table entities (
id integer,
hash integer,
username text,
phone integer,
name text
)"""
)
c.execute("insert into version values (1)")
c.close()
self.save()
self.id = helpers.generate_random_long(signed=True)
self._sequence = 0
self.time_offset = 0
self._last_msg_id = 0 # Long
# These values will be saved
self.server_address = None
self.port = None
self.auth_key = None
self.layer = 0
self.salt = 0 # Signed long
self.entities = EntityDatabase() # Known and cached entities
def _check_migrate_json(self):
if file_exists(self.filename):
try:
with open(self.filename, encoding='utf-8') as f:
data = json.load(f)
self._port = data.get('port', self._port)
self._salt = data.get('salt', self._salt)
# Keep while migrating from unsigned to signed salt
if self._salt > 0:
self._salt = struct.unpack(
'q', struct.pack('Q', self._salt))[0]
self._layer = data.get('layer', self._layer)
self._server_address = \
data.get('server_address', self._server_address)
from ..crypto import AuthKey
if data.get('auth_key_data', None) is not None:
key = b64decode(data['auth_key_data'])
self._auth_key = AuthKey(data=key)
self.entities = EntityDatabase(data.get('entities', []))
self.delete() # Delete JSON file to create database
except (UnicodeDecodeError, json.decoder.JSONDecodeError):
pass
def _upgrade_database(self, old):
pass
# Data from sessions should be kept as properties
# not to fetch the database every time we need it
@property
def server_address(self):
return self._server_address
@server_address.setter
def server_address(self, value):
self._server_address = value
self._update_session_table()
@property
def port(self):
return self._port
@port.setter
def port(self, value):
self._port = value
self._update_session_table()
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
@property
def layer(self):
return self._layer
@layer.setter
def layer(self, value):
self._layer = value
self._update_session_table()
@property
def salt(self):
return self._salt
@salt.setter
def salt(self, value):
self._salt = value
self._update_session_table()
def _update_session_table(self):
with self._db_lock:
c = self._conn.cursor()
c.execute('delete from sessions')
c.execute('insert into sessions values (?,?,?,?,?)', (
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b'',
self._layer,
self._salt
))
c.close()
def save(self):
"""Saves the current session object as session_user_id.session"""
if not self.session_user_id or self._save_lock.locked():
return
with self._save_lock:
with open('{}.session'.format(self.session_user_id), 'w') as file:
out_dict = {
'port': self.port,
'salt': self.salt,
'layer': self.layer,
'server_address': self.server_address,
'auth_key_data':
b64encode(self.auth_key.key).decode('ascii')
if self.auth_key else None
}
if self.save_entities:
out_dict['entities'] = self.entities.get_input_list()
json.dump(out_dict, file)
with self._db_lock:
self._conn.commit()
def delete(self):
"""Deletes the current session file"""
if self.filename == ':memory:':
return True
try:
os.remove('{}.session'.format(self.session_user_id))
os.remove(self.filename)
return True
except OSError:
return False
@ -107,48 +238,7 @@ class Session:
using this client and never logged out
"""
return [os.path.splitext(os.path.basename(f))[0]
for f in os.listdir('.') if f.endswith('.session')]
@staticmethod
def try_load_or_create_new(session_user_id):
"""Loads a saved session_user_id.session or creates a new one.
If session_user_id=None, later .save()'s will have no effect.
"""
if session_user_id is None:
return Session(None)
else:
path = '{}.session'.format(session_user_id)
result = Session(session_user_id)
if not file_exists(path):
return result
try:
with open(path, 'r') as file:
data = json.load(file)
result.port = data.get('port', result.port)
result.salt = data.get('salt', result.salt)
# Keep while migrating from unsigned to signed salt
if result.salt > 0:
result.salt = struct.unpack(
'q', struct.pack('Q', result.salt))[0]
result.layer = data.get('layer', result.layer)
result.server_address = \
data.get('server_address', result.server_address)
# FIXME We need to import the AuthKey here or otherwise
# we get cyclic dependencies.
from ..crypto import AuthKey
if data.get('auth_key_data', None) is not None:
key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key)
result.entities = EntityDatabase(data.get('entities', []))
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
pass
return result
for f in os.listdir('.') if f.endswith(EXTENSION)]
def generate_sequence(self, content_related):
"""Thread safe method to generates the next sequence number,