Telethon/telethon/tl/session.py

284 lines
9.3 KiB
Python
Raw Normal View History

import json
import os
import platform
import sqlite3
2017-12-03 23:10:22 +03:00
import struct
2016-11-30 00:29:42 +03:00
import time
from base64 import b64decode
2016-11-30 00:29:42 +03:00
from os.path import isfile as file_exists
from threading import Lock
2016-11-30 00:29:42 +03:00
from .entity_database import EntityDatabase
from .. import helpers
2016-08-28 15:16:52 +03:00
# Suffix appended to a session name (when missing) to build its file name
EXTENSION = '.session'

# Schema version written to, and checked against, the ``version`` table
CURRENT_VERSION = 1
2016-08-28 15:16:52 +03:00
class Session:
    """This session contains the required information to login into your
    Telegram account. NEVER give the saved session file to anyone, since
    they would gain instant access to all your messages and contacts.

    If you think the session has been compromised, close all the sessions
    through an official Telegram client to revoke the authorization.
    """

    def __init__(self, session_id):
        """Create a session backed by an SQLite database.

        :param session_id:
            Either a string naming the session file (``EXTENSION`` is
            appended when missing), ``None``/empty for an in-memory
            session, or another ``Session``. When another session is
            given, only the parameters required to init a connection are
            copied, and the new session lives in memory.
        """
        # These values will NOT be saved
        self.filename = ':memory:'

        # For connection purposes
        if isinstance(session_id, Session):
            self.device_model = session_id.device_model
            self.system_version = session_id.system_version
            self.app_version = session_id.app_version
            self.lang_code = session_id.lang_code
            self.system_lang_code = session_id.system_lang_code
            self.lang_pack = session_id.lang_pack
            self.report_errors = session_id.report_errors
            self.save_entities = session_id.save_entities
            self.flood_sleep_threshold = session_id.flood_sleep_threshold
        else:  # str / None
            if session_id:
                self.filename = session_id
                if not self.filename.endswith(EXTENSION):
                    self.filename += EXTENSION

            system = platform.uname()
            self.device_model = system.system or 'Unknown'
            self.system_version = system.release or '1.0'
            self.app_version = '1.0'  # '0' will provoke error
            self.lang_code = 'en'
            self.system_lang_code = self.lang_code
            self.lang_pack = ''
            self.report_errors = True
            self.save_entities = True
            self.flood_sleep_threshold = 60

        self.id = helpers.generate_random_long(signed=True)
        self._sequence = 0
        self.time_offset = 0
        self._last_msg_id = 0  # Long

        # Cross-thread safety
        self._seq_no_lock = Lock()
        self._msg_id_lock = Lock()
        self._db_lock = Lock()

        # These values will be saved
        self._server_address = None
        self._port = None
        self._auth_key = None
        self._layer = 0
        self._salt = 0  # Signed long
        self.entities = EntityDatabase()  # Known and cached entities

        # Migrating from .json -> SQL
        self._check_migrate_json()

        self._conn = sqlite3.connect(self.filename, check_same_thread=False)
        c = self._conn.cursor()
        c.execute("select name from sqlite_master "
                  "where type='table' and name='version'")
        if c.fetchone():
            # Tables already exist, check for the version
            c.execute("select version from version")
            version = c.fetchone()[0]
            if version != CURRENT_VERSION:
                self._upgrade_database(old=version)
                self.save()

            # Restore the saved connection values. The table may be empty
            # for databases created before a row was ever written.
            c.execute('select * from sessions')
            row = c.fetchone()
            if row is not None:
                self._server_address, self._port, key, \
                    self._layer, self._salt = row
                # An empty blob is what gets stored when there is no key;
                # keep ``None`` rather than building a bogus empty AuthKey.
                if key:
                    from ..crypto import AuthKey
                    self._auth_key = AuthKey(data=key)
            c.close()
        else:
            # Tables don't exist, create new ones
            c.execute("create table version (version integer)")
            c.execute(
                """create table sessions (
                    server_address text,
                    port integer,
                    auth_key blob,
                    layer integer,
                    salt integer
                )"""
            )
            c.execute(
                """create table entities (
                    id integer,
                    hash integer,
                    username text,
                    phone integer,
                    name text
                )"""
            )
            c.execute("insert into version values (?)", (CURRENT_VERSION,))
            c.close()
            # Write the current values (possibly just migrated from JSON)
            # so they survive even if no property setter is called later.
            self._update_session_table()
            self.save()

    def _check_migrate_json(self):
        """Import a legacy JSON session file if ``self.filename`` is one.

        On success the loaded values replace the in-memory defaults and the
        JSON file is deleted, so an SQLite database can be created in its
        place. Files that are not valid UTF-8 JSON (for instance an existing
        SQLite database) are left untouched.
        """
        if not file_exists(self.filename):
            return

        try:
            with open(self.filename, encoding='utf-8') as f:
                data = json.load(f)
        except (UnicodeDecodeError, json.decoder.JSONDecodeError):
            return  # Not a JSON session file

        self._port = data.get('port', self._port)
        self._salt = data.get('salt', self._salt)
        # Keep while migrating from unsigned to signed salt
        if self._salt > 0:
            self._salt = struct.unpack(
                'q', struct.pack('Q', self._salt))[0]

        self._layer = data.get('layer', self._layer)
        self._server_address = \
            data.get('server_address', self._server_address)

        from ..crypto import AuthKey
        if data.get('auth_key_data', None) is not None:
            key = b64decode(data['auth_key_data'])
            self._auth_key = AuthKey(data=key)

        self.entities = EntityDatabase(data.get('entities', []))
        # The file is closed at this point, so it can be removed everywhere
        self.delete()  # Delete JSON file to create database

    def _upgrade_database(self, old):
        """Upgrade the tables from schema version ``old``.

        No upgrade steps are implemented yet, so this is a no-op.
        """

    # Data from sessions should be kept as properties
    # not to fetch the database every time we need it
    @property
    def server_address(self):
        return self._server_address

    @server_address.setter
    def server_address(self, value):
        self._server_address = value
        self._update_session_table()

    @property
    def port(self):
        return self._port

    @port.setter
    def port(self, value):
        self._port = value
        self._update_session_table()

    @property
    def auth_key(self):
        return self._auth_key

    @auth_key.setter
    def auth_key(self, value):
        self._auth_key = value
        self._update_session_table()

    @property
    def layer(self):
        return self._layer

    @layer.setter
    def layer(self, value):
        self._layer = value
        self._update_session_table()

    @property
    def salt(self):
        return self._salt

    @salt.setter
    def salt(self, value):
        self._salt = value
        self._update_session_table()

    def _update_session_table(self):
        """Replace the single ``sessions`` row with the in-memory values.

        A missing auth key is stored as an empty blob. Nothing is committed
        here; call ``save()`` to persist the change.
        """
        with self._db_lock:
            c = self._conn.cursor()
            c.execute('delete from sessions')
            c.execute('insert into sessions values (?,?,?,?,?)', (
                self._server_address,
                self._port,
                self._auth_key.key if self._auth_key else b'',
                self._layer,
                self._salt
            ))
            c.close()

    def save(self):
        """Commit pending changes to the session's SQLite database."""
        with self._db_lock:
            self._conn.commit()

    def delete(self):
        """Deletes the current session file.

        Returns ``True`` on success (always, for in-memory sessions) and
        ``False`` if the operating system refused to remove the file.
        The open database connection, if any, is not closed here.
        """
        if self.filename == ':memory:':
            return True
        try:
            os.remove(self.filename)
            return True
        except OSError:
            return False

    @staticmethod
    def list_sessions():
        """Lists all the sessions of the users who have ever connected
        using this client and never logged out, i.e. the names of the
        session files found in the current working directory.
        """
        return [os.path.splitext(os.path.basename(f))[0]
                for f in os.listdir('.') if f.endswith(EXTENSION)]

    def generate_sequence(self, content_related):
        """Thread safe method to generate the next sequence number.

        Content-related messages get an odd number (``2 * n + 1``) and
        advance the internal counter ``n``; other messages get the even
        number ``2 * n`` without advancing it.
        """
        with self._seq_no_lock:
            if content_related:
                result = self._sequence * 2 + 1
                self._sequence += 1
                return result
            else:
                return self._sequence * 2

    def get_new_msg_id(self):
        """Generates a new unique message ID based on the current time
        since epoch, corrected by ``time_offset``.

        The upper 32 bits hold the seconds. The lower bits hold the
        nanoseconds shifted left by two, so IDs are divisible by 4. With
        an unchanged offset, IDs strictly increase between calls.
        """
        # Refer to mtproto_plain_sender.py for the original method
        now = time.time() + self.time_offset
        nanoseconds = int((now - int(now)) * 1e+9)
        # "message identifiers are divisible by 4"
        new_msg_id = (int(now) << 32) | (nanoseconds << 2)

        with self._msg_id_lock:
            if self._last_msg_id >= new_msg_id:
                new_msg_id = self._last_msg_id + 4

            self._last_msg_id = new_msg_id

        return new_msg_id

    def update_time_offset(self, correct_msg_id):
        """Updates the time offset based on a known correct message ID.

        When the offset changes, the last generated ID is forgotten.
        Otherwise IDs produced with a clock that was too far ahead would
        keep pinning new IDs in the future through the monotonic clamp in
        ``get_new_msg_id``.
        """
        now = int(time.time())
        correct = correct_msg_id >> 32
        with self._msg_id_lock:
            old = self.time_offset
            self.time_offset = correct - now
            if self.time_offset != old:
                self._last_msg_id = 0

    def process_entities(self, tlobject):
        """Feed ``tlobject`` to the entity cache and save if any new
        entities were added. Errors are ignored; caching is best-effort.
        """
        try:
            if self.entities.process(tlobject):
                self.save()  # Save if any new entities got added
        except Exception:
            # Deliberately best-effort: never break the caller. Using
            # Exception rather than a bare except lets KeyboardInterrupt
            # and SystemExit propagate.
            pass