mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-09 08:00:53 +03:00
Use sqlite3 instead JSON for the session files
This commit is contained in:
parent
b11c2e885b
commit
664417b409
|
@ -92,7 +92,7 @@ class TelegramBareClient:
|
||||||
|
|
||||||
# Determine what session object we have
|
# Determine what session object we have
|
||||||
if isinstance(session, str) or session is None:
|
if isinstance(session, str) or session is None:
|
||||||
session = Session.try_load_or_create_new(session)
|
session = Session(session)
|
||||||
elif not isinstance(session, Session):
|
elif not isinstance(session, Session):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The given session must be a str or a Session instance.'
|
'The given session must be a str or a Session instance.'
|
||||||
|
|
|
@ -1,15 +1,19 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from base64 import b64encode, b64decode
|
from base64 import b64decode
|
||||||
from os.path import isfile as file_exists
|
from os.path import isfile as file_exists
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
from .entity_database import EntityDatabase
|
from .entity_database import EntityDatabase
|
||||||
from .. import helpers
|
from .. import helpers
|
||||||
|
|
||||||
|
EXTENSION = '.session'
|
||||||
|
CURRENT_VERSION = 1 # database version
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
"""This session contains the required information to login into your
|
"""This session contains the required information to login into your
|
||||||
|
@ -25,6 +29,7 @@ class Session:
|
||||||
those required to init a connection will be copied.
|
those required to init a connection will be copied.
|
||||||
"""
|
"""
|
||||||
# These values will NOT be saved
|
# These values will NOT be saved
|
||||||
|
self.filename = ':memory:'
|
||||||
if isinstance(session_user_id, Session):
|
if isinstance(session_user_id, Session):
|
||||||
self.session_user_id = None
|
self.session_user_id = None
|
||||||
|
|
||||||
|
@ -41,7 +46,10 @@ class Session:
|
||||||
self.flood_sleep_threshold = session.flood_sleep_threshold
|
self.flood_sleep_threshold = session.flood_sleep_threshold
|
||||||
|
|
||||||
else: # str / None
|
else: # str / None
|
||||||
self.session_user_id = session_user_id
|
if session_user_id:
|
||||||
|
self.filename = session_user_id
|
||||||
|
if not self.filename.endswith(EXTENSION):
|
||||||
|
self.filename += EXTENSION
|
||||||
|
|
||||||
system = platform.uname()
|
system = platform.uname()
|
||||||
self.device_model = system.system if system.system else 'Unknown'
|
self.device_model = system.system if system.system else 'Unknown'
|
||||||
|
@ -54,49 +62,172 @@ class Session:
|
||||||
self.save_entities = True
|
self.save_entities = True
|
||||||
self.flood_sleep_threshold = 60
|
self.flood_sleep_threshold = 60
|
||||||
|
|
||||||
|
# These values will be saved
|
||||||
|
self._server_address = None
|
||||||
|
self._port = None
|
||||||
|
self._auth_key = None
|
||||||
|
self._layer = 0
|
||||||
|
self._salt = 0 # Signed long
|
||||||
|
self.entities = EntityDatabase() # Known and cached entities
|
||||||
|
|
||||||
# Cross-thread safety
|
# Cross-thread safety
|
||||||
self._seq_no_lock = Lock()
|
self._seq_no_lock = Lock()
|
||||||
self._msg_id_lock = Lock()
|
self._msg_id_lock = Lock()
|
||||||
self._save_lock = Lock()
|
self._db_lock = Lock()
|
||||||
|
|
||||||
|
# Migrating from .json -> SQL
|
||||||
|
self._check_migrate_json()
|
||||||
|
|
||||||
|
self._conn = sqlite3.connect(self.filename, check_same_thread=False)
|
||||||
|
c = self._conn.cursor()
|
||||||
|
c.execute("select name from sqlite_master "
|
||||||
|
"where type='table' and name='version'")
|
||||||
|
if c.fetchone():
|
||||||
|
# Tables already exist, check for the version
|
||||||
|
c.execute("select version from version")
|
||||||
|
version = c.fetchone()[0]
|
||||||
|
if version != CURRENT_VERSION:
|
||||||
|
self._upgrade_database(old=version)
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
# These values will be saved
|
||||||
|
c.execute('select * from sessions')
|
||||||
|
self._server_address, self._port, key, \
|
||||||
|
self._layer, self._salt = c.fetchone()
|
||||||
|
|
||||||
|
from ..crypto import AuthKey
|
||||||
|
self._auth_key = AuthKey(data=key)
|
||||||
|
c.close()
|
||||||
|
else:
|
||||||
|
# Tables don't exist, create new ones
|
||||||
|
c.execute("create table version (version integer)")
|
||||||
|
c.execute(
|
||||||
|
"""create table sessions (
|
||||||
|
server_address text,
|
||||||
|
port integer,
|
||||||
|
auth_key blob,
|
||||||
|
layer integer,
|
||||||
|
salt integer
|
||||||
|
)"""
|
||||||
|
)
|
||||||
|
c.execute(
|
||||||
|
"""create table entities (
|
||||||
|
id integer,
|
||||||
|
hash integer,
|
||||||
|
username text,
|
||||||
|
phone integer,
|
||||||
|
name text
|
||||||
|
)"""
|
||||||
|
)
|
||||||
|
c.execute("insert into version values (1)")
|
||||||
|
c.close()
|
||||||
|
self.save()
|
||||||
|
|
||||||
self.id = helpers.generate_random_long(signed=True)
|
self.id = helpers.generate_random_long(signed=True)
|
||||||
self._sequence = 0
|
self._sequence = 0
|
||||||
self.time_offset = 0
|
self.time_offset = 0
|
||||||
self._last_msg_id = 0 # Long
|
self._last_msg_id = 0 # Long
|
||||||
|
|
||||||
# These values will be saved
|
def _check_migrate_json(self):
|
||||||
self.server_address = None
|
if file_exists(self.filename):
|
||||||
self.port = None
|
try:
|
||||||
self.auth_key = None
|
with open(self.filename, encoding='utf-8') as f:
|
||||||
self.layer = 0
|
data = json.load(f)
|
||||||
self.salt = 0 # Signed long
|
self._port = data.get('port', self._port)
|
||||||
self.entities = EntityDatabase() # Known and cached entities
|
self._salt = data.get('salt', self._salt)
|
||||||
|
# Keep while migrating from unsigned to signed salt
|
||||||
|
if self._salt > 0:
|
||||||
|
self._salt = struct.unpack(
|
||||||
|
'q', struct.pack('Q', self._salt))[0]
|
||||||
|
|
||||||
|
self._layer = data.get('layer', self._layer)
|
||||||
|
self._server_address = \
|
||||||
|
data.get('server_address', self._server_address)
|
||||||
|
|
||||||
|
from ..crypto import AuthKey
|
||||||
|
if data.get('auth_key_data', None) is not None:
|
||||||
|
key = b64decode(data['auth_key_data'])
|
||||||
|
self._auth_key = AuthKey(data=key)
|
||||||
|
|
||||||
|
self.entities = EntityDatabase(data.get('entities', []))
|
||||||
|
self.delete() # Delete JSON file to create database
|
||||||
|
except (UnicodeDecodeError, json.decoder.JSONDecodeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _upgrade_database(self, old):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Data from sessions should be kept as properties
|
||||||
|
# not to fetch the database every time we need it
|
||||||
|
@property
|
||||||
|
def server_address(self):
|
||||||
|
return self._server_address
|
||||||
|
|
||||||
|
@server_address.setter
|
||||||
|
def server_address(self, value):
|
||||||
|
self._server_address = value
|
||||||
|
self._update_session_table()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def port(self):
|
||||||
|
return self._port
|
||||||
|
|
||||||
|
@port.setter
|
||||||
|
def port(self, value):
|
||||||
|
self._port = value
|
||||||
|
self._update_session_table()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_key(self):
|
||||||
|
return self._auth_key
|
||||||
|
|
||||||
|
@auth_key.setter
|
||||||
|
def auth_key(self, value):
|
||||||
|
self._auth_key = value
|
||||||
|
self._update_session_table()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layer(self):
|
||||||
|
return self._layer
|
||||||
|
|
||||||
|
@layer.setter
|
||||||
|
def layer(self, value):
|
||||||
|
self._layer = value
|
||||||
|
self._update_session_table()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def salt(self):
|
||||||
|
return self._salt
|
||||||
|
|
||||||
|
@salt.setter
|
||||||
|
def salt(self, value):
|
||||||
|
self._salt = value
|
||||||
|
self._update_session_table()
|
||||||
|
|
||||||
|
def _update_session_table(self):
|
||||||
|
with self._db_lock:
|
||||||
|
c = self._conn.cursor()
|
||||||
|
c.execute('delete from sessions')
|
||||||
|
c.execute('insert into sessions values (?,?,?,?,?)', (
|
||||||
|
self._server_address,
|
||||||
|
self._port,
|
||||||
|
self._auth_key.key if self._auth_key else b'',
|
||||||
|
self._layer,
|
||||||
|
self._salt
|
||||||
|
))
|
||||||
|
c.close()
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Saves the current session object as session_user_id.session"""
|
"""Saves the current session object as session_user_id.session"""
|
||||||
if not self.session_user_id or self._save_lock.locked():
|
with self._db_lock:
|
||||||
return
|
self._conn.commit()
|
||||||
|
|
||||||
with self._save_lock:
|
|
||||||
with open('{}.session'.format(self.session_user_id), 'w') as file:
|
|
||||||
out_dict = {
|
|
||||||
'port': self.port,
|
|
||||||
'salt': self.salt,
|
|
||||||
'layer': self.layer,
|
|
||||||
'server_address': self.server_address,
|
|
||||||
'auth_key_data':
|
|
||||||
b64encode(self.auth_key.key).decode('ascii')
|
|
||||||
if self.auth_key else None
|
|
||||||
}
|
|
||||||
if self.save_entities:
|
|
||||||
out_dict['entities'] = self.entities.get_input_list()
|
|
||||||
|
|
||||||
json.dump(out_dict, file)
|
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
"""Deletes the current session file"""
|
"""Deletes the current session file"""
|
||||||
|
if self.filename == ':memory:':
|
||||||
|
return True
|
||||||
try:
|
try:
|
||||||
os.remove('{}.session'.format(self.session_user_id))
|
os.remove(self.filename)
|
||||||
return True
|
return True
|
||||||
except OSError:
|
except OSError:
|
||||||
return False
|
return False
|
||||||
|
@ -107,48 +238,7 @@ class Session:
|
||||||
using this client and never logged out
|
using this client and never logged out
|
||||||
"""
|
"""
|
||||||
return [os.path.splitext(os.path.basename(f))[0]
|
return [os.path.splitext(os.path.basename(f))[0]
|
||||||
for f in os.listdir('.') if f.endswith('.session')]
|
for f in os.listdir('.') if f.endswith(EXTENSION)]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def try_load_or_create_new(session_user_id):
|
|
||||||
"""Loads a saved session_user_id.session or creates a new one.
|
|
||||||
If session_user_id=None, later .save()'s will have no effect.
|
|
||||||
"""
|
|
||||||
if session_user_id is None:
|
|
||||||
return Session(None)
|
|
||||||
else:
|
|
||||||
path = '{}.session'.format(session_user_id)
|
|
||||||
result = Session(session_user_id)
|
|
||||||
if not file_exists(path):
|
|
||||||
return result
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(path, 'r') as file:
|
|
||||||
data = json.load(file)
|
|
||||||
result.port = data.get('port', result.port)
|
|
||||||
result.salt = data.get('salt', result.salt)
|
|
||||||
# Keep while migrating from unsigned to signed salt
|
|
||||||
if result.salt > 0:
|
|
||||||
result.salt = struct.unpack(
|
|
||||||
'q', struct.pack('Q', result.salt))[0]
|
|
||||||
|
|
||||||
result.layer = data.get('layer', result.layer)
|
|
||||||
result.server_address = \
|
|
||||||
data.get('server_address', result.server_address)
|
|
||||||
|
|
||||||
# FIXME We need to import the AuthKey here or otherwise
|
|
||||||
# we get cyclic dependencies.
|
|
||||||
from ..crypto import AuthKey
|
|
||||||
if data.get('auth_key_data', None) is not None:
|
|
||||||
key = b64decode(data['auth_key_data'])
|
|
||||||
result.auth_key = AuthKey(data=key)
|
|
||||||
|
|
||||||
result.entities = EntityDatabase(data.get('entities', []))
|
|
||||||
|
|
||||||
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def generate_sequence(self, content_related):
|
def generate_sequence(self, content_related):
|
||||||
"""Thread safe method to generates the next sequence number,
|
"""Thread safe method to generates the next sequence number,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user