mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-03 05:04:33 +03:00
Use sqlite3 instead JSON for the session files
This commit is contained in:
parent
b11c2e885b
commit
664417b409
|
@ -92,7 +92,7 @@ class TelegramBareClient:
|
|||
|
||||
# Determine what session object we have
|
||||
if isinstance(session, str) or session is None:
|
||||
session = Session.try_load_or_create_new(session)
|
||||
session = Session(session)
|
||||
elif not isinstance(session, Session):
|
||||
raise ValueError(
|
||||
'The given session must be a str or a Session instance.'
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
import json
|
||||
import os
|
||||
import platform
|
||||
import sqlite3
|
||||
import struct
|
||||
import time
|
||||
from base64 import b64encode, b64decode
|
||||
from base64 import b64decode
|
||||
from os.path import isfile as file_exists
|
||||
from threading import Lock
|
||||
|
||||
from .entity_database import EntityDatabase
|
||||
from .. import helpers
|
||||
|
||||
EXTENSION = '.session'
|
||||
CURRENT_VERSION = 1 # database version
|
||||
|
||||
|
||||
class Session:
|
||||
"""This session contains the required information to login into your
|
||||
|
@ -25,6 +29,7 @@ class Session:
|
|||
those required to init a connection will be copied.
|
||||
"""
|
||||
# These values will NOT be saved
|
||||
self.filename = ':memory:'
|
||||
if isinstance(session_user_id, Session):
|
||||
self.session_user_id = None
|
||||
|
||||
|
@ -41,7 +46,10 @@ class Session:
|
|||
self.flood_sleep_threshold = session.flood_sleep_threshold
|
||||
|
||||
else: # str / None
|
||||
self.session_user_id = session_user_id
|
||||
if session_user_id:
|
||||
self.filename = session_user_id
|
||||
if not self.filename.endswith(EXTENSION):
|
||||
self.filename += EXTENSION
|
||||
|
||||
system = platform.uname()
|
||||
self.device_model = system.system if system.system else 'Unknown'
|
||||
|
@ -54,49 +62,172 @@ class Session:
|
|||
self.save_entities = True
|
||||
self.flood_sleep_threshold = 60
|
||||
|
||||
# These values will be saved
|
||||
self._server_address = None
|
||||
self._port = None
|
||||
self._auth_key = None
|
||||
self._layer = 0
|
||||
self._salt = 0 # Signed long
|
||||
self.entities = EntityDatabase() # Known and cached entities
|
||||
|
||||
# Cross-thread safety
|
||||
self._seq_no_lock = Lock()
|
||||
self._msg_id_lock = Lock()
|
||||
self._save_lock = Lock()
|
||||
self._db_lock = Lock()
|
||||
|
||||
# Migrating from .json -> SQL
|
||||
self._check_migrate_json()
|
||||
|
||||
self._conn = sqlite3.connect(self.filename, check_same_thread=False)
|
||||
c = self._conn.cursor()
|
||||
c.execute("select name from sqlite_master "
|
||||
"where type='table' and name='version'")
|
||||
if c.fetchone():
|
||||
# Tables already exist, check for the version
|
||||
c.execute("select version from version")
|
||||
version = c.fetchone()[0]
|
||||
if version != CURRENT_VERSION:
|
||||
self._upgrade_database(old=version)
|
||||
self.save()
|
||||
|
||||
# These values will be saved
|
||||
c.execute('select * from sessions')
|
||||
self._server_address, self._port, key, \
|
||||
self._layer, self._salt = c.fetchone()
|
||||
|
||||
from ..crypto import AuthKey
|
||||
self._auth_key = AuthKey(data=key)
|
||||
c.close()
|
||||
else:
|
||||
# Tables don't exist, create new ones
|
||||
c.execute("create table version (version integer)")
|
||||
c.execute(
|
||||
"""create table sessions (
|
||||
server_address text,
|
||||
port integer,
|
||||
auth_key blob,
|
||||
layer integer,
|
||||
salt integer
|
||||
)"""
|
||||
)
|
||||
c.execute(
|
||||
"""create table entities (
|
||||
id integer,
|
||||
hash integer,
|
||||
username text,
|
||||
phone integer,
|
||||
name text
|
||||
)"""
|
||||
)
|
||||
c.execute("insert into version values (1)")
|
||||
c.close()
|
||||
self.save()
|
||||
|
||||
self.id = helpers.generate_random_long(signed=True)
|
||||
self._sequence = 0
|
||||
self.time_offset = 0
|
||||
self._last_msg_id = 0 # Long
|
||||
|
||||
# These values will be saved
|
||||
self.server_address = None
|
||||
self.port = None
|
||||
self.auth_key = None
|
||||
self.layer = 0
|
||||
self.salt = 0 # Signed long
|
||||
self.entities = EntityDatabase() # Known and cached entities
|
||||
def _check_migrate_json(self):
|
||||
if file_exists(self.filename):
|
||||
try:
|
||||
with open(self.filename, encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
self._port = data.get('port', self._port)
|
||||
self._salt = data.get('salt', self._salt)
|
||||
# Keep while migrating from unsigned to signed salt
|
||||
if self._salt > 0:
|
||||
self._salt = struct.unpack(
|
||||
'q', struct.pack('Q', self._salt))[0]
|
||||
|
||||
self._layer = data.get('layer', self._layer)
|
||||
self._server_address = \
|
||||
data.get('server_address', self._server_address)
|
||||
|
||||
from ..crypto import AuthKey
|
||||
if data.get('auth_key_data', None) is not None:
|
||||
key = b64decode(data['auth_key_data'])
|
||||
self._auth_key = AuthKey(data=key)
|
||||
|
||||
self.entities = EntityDatabase(data.get('entities', []))
|
||||
self.delete() # Delete JSON file to create database
|
||||
except (UnicodeDecodeError, json.decoder.JSONDecodeError):
|
||||
pass
|
||||
|
||||
def _upgrade_database(self, old):
|
||||
pass
|
||||
|
||||
# Data from sessions should be kept as properties
|
||||
# not to fetch the database every time we need it
|
||||
@property
|
||||
def server_address(self):
|
||||
return self._server_address
|
||||
|
||||
@server_address.setter
|
||||
def server_address(self, value):
|
||||
self._server_address = value
|
||||
self._update_session_table()
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self._port
|
||||
|
||||
@port.setter
|
||||
def port(self, value):
|
||||
self._port = value
|
||||
self._update_session_table()
|
||||
|
||||
@property
|
||||
def auth_key(self):
|
||||
return self._auth_key
|
||||
|
||||
@auth_key.setter
|
||||
def auth_key(self, value):
|
||||
self._auth_key = value
|
||||
self._update_session_table()
|
||||
|
||||
@property
|
||||
def layer(self):
|
||||
return self._layer
|
||||
|
||||
@layer.setter
|
||||
def layer(self, value):
|
||||
self._layer = value
|
||||
self._update_session_table()
|
||||
|
||||
@property
|
||||
def salt(self):
|
||||
return self._salt
|
||||
|
||||
@salt.setter
|
||||
def salt(self, value):
|
||||
self._salt = value
|
||||
self._update_session_table()
|
||||
|
||||
def _update_session_table(self):
|
||||
with self._db_lock:
|
||||
c = self._conn.cursor()
|
||||
c.execute('delete from sessions')
|
||||
c.execute('insert into sessions values (?,?,?,?,?)', (
|
||||
self._server_address,
|
||||
self._port,
|
||||
self._auth_key.key if self._auth_key else b'',
|
||||
self._layer,
|
||||
self._salt
|
||||
))
|
||||
c.close()
|
||||
|
||||
def save(self):
|
||||
"""Saves the current session object as session_user_id.session"""
|
||||
if not self.session_user_id or self._save_lock.locked():
|
||||
return
|
||||
|
||||
with self._save_lock:
|
||||
with open('{}.session'.format(self.session_user_id), 'w') as file:
|
||||
out_dict = {
|
||||
'port': self.port,
|
||||
'salt': self.salt,
|
||||
'layer': self.layer,
|
||||
'server_address': self.server_address,
|
||||
'auth_key_data':
|
||||
b64encode(self.auth_key.key).decode('ascii')
|
||||
if self.auth_key else None
|
||||
}
|
||||
if self.save_entities:
|
||||
out_dict['entities'] = self.entities.get_input_list()
|
||||
|
||||
json.dump(out_dict, file)
|
||||
with self._db_lock:
|
||||
self._conn.commit()
|
||||
|
||||
def delete(self):
|
||||
"""Deletes the current session file"""
|
||||
if self.filename == ':memory:':
|
||||
return True
|
||||
try:
|
||||
os.remove('{}.session'.format(self.session_user_id))
|
||||
os.remove(self.filename)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
@ -107,48 +238,7 @@ class Session:
|
|||
using this client and never logged out
|
||||
"""
|
||||
return [os.path.splitext(os.path.basename(f))[0]
|
||||
for f in os.listdir('.') if f.endswith('.session')]
|
||||
|
||||
@staticmethod
|
||||
def try_load_or_create_new(session_user_id):
|
||||
"""Loads a saved session_user_id.session or creates a new one.
|
||||
If session_user_id=None, later .save()'s will have no effect.
|
||||
"""
|
||||
if session_user_id is None:
|
||||
return Session(None)
|
||||
else:
|
||||
path = '{}.session'.format(session_user_id)
|
||||
result = Session(session_user_id)
|
||||
if not file_exists(path):
|
||||
return result
|
||||
|
||||
try:
|
||||
with open(path, 'r') as file:
|
||||
data = json.load(file)
|
||||
result.port = data.get('port', result.port)
|
||||
result.salt = data.get('salt', result.salt)
|
||||
# Keep while migrating from unsigned to signed salt
|
||||
if result.salt > 0:
|
||||
result.salt = struct.unpack(
|
||||
'q', struct.pack('Q', result.salt))[0]
|
||||
|
||||
result.layer = data.get('layer', result.layer)
|
||||
result.server_address = \
|
||||
data.get('server_address', result.server_address)
|
||||
|
||||
# FIXME We need to import the AuthKey here or otherwise
|
||||
# we get cyclic dependencies.
|
||||
from ..crypto import AuthKey
|
||||
if data.get('auth_key_data', None) is not None:
|
||||
key = b64decode(data['auth_key_data'])
|
||||
result.auth_key = AuthKey(data=key)
|
||||
|
||||
result.entities = EntityDatabase(data.get('entities', []))
|
||||
|
||||
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
|
||||
pass
|
||||
|
||||
return result
|
||||
for f in os.listdir('.') if f.endswith(EXTENSION)]
|
||||
|
||||
def generate_sequence(self, content_related):
|
||||
"""Thread safe method to generates the next sequence number,
|
||||
|
|
Loading…
Reference in New Issue
Block a user