mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-23 09:53:45 +03:00
491 lines
17 KiB
Python
491 lines
17 KiB
Python
import json
|
|
import os
|
|
import platform
|
|
import sqlite3
|
|
import struct
|
|
import time
|
|
from base64 import b64decode
|
|
from enum import Enum
|
|
from os.path import isfile as file_exists
|
|
from threading import Lock, RLock
|
|
|
|
from . import utils
|
|
from .crypto import AuthKey
|
|
from .tl import TLObject
|
|
from .tl.types import (
|
|
PeerUser, PeerChat, PeerChannel,
|
|
InputPeerUser, InputPeerChat, InputPeerChannel,
|
|
InputPhoto, InputDocument
|
|
)
|
|
|
|
# Suffix appended to named session files when it is missing.
EXTENSION = '.session'

# Schema version of the session database; compared against the value
# stored in the `version` table to decide whether an upgrade is needed.
CURRENT_VERSION = 3
|
|
|
|
|
|
class _SentFileType(Enum):
    """Kind of file stored in the ``sent_files`` cache table.

    The member value is what is written to the table's ``type`` column.
    """
    DOCUMENT = 0
    PHOTO = 1

    @staticmethod
    def from_type(cls):
        """Map an input media class to its cache kind.

        :param cls: the class to look up; only InputDocument and
            InputPhoto are accepted.
        :return: the matching ``_SentFileType`` member.
        :raises ValueError: for any other class.
        """
        known = (
            (InputDocument, _SentFileType.DOCUMENT),
            (InputPhoto, _SentFileType.PHOTO),
        )
        for input_cls, member in known:
            if cls == input_cls:
                return member
        raise ValueError('The cls must be either InputDocument/InputPhoto')
|
|
|
|
|
|
class Session:
    """This session contains the required information to login into your
    Telegram account. NEVER give the saved session file to anyone, since
    they would gain instant access to all your messages and contacts.

    The state is stored in an SQLite database named ``<name>.session``
    (or kept in memory when no name is given). A legacy JSON session file
    found under that name is migrated into the database when opened.

    If you think the session has been compromised, close all the sessions
    through an official Telegram client to revoke the authorization.
    """
    def __init__(self, session_id):
        """Creates the session, loading any previously saved state.

        :param session_id:
            Either the session file name as a string (``.session`` is
            appended if missing), a falsy value such as ``None`` for an
            in-memory session, or another ``Session``. If another session
            is given, only parameters like those required to init a
            connection are copied, and the new session lives in memory.
        """
        # These values will NOT be saved
        self.filename = ':memory:'

        # For connection purposes
        if isinstance(session_id, Session):
            self.device_model = session_id.device_model
            self.system_version = session_id.system_version
            self.app_version = session_id.app_version
            self.lang_code = session_id.lang_code
            self.system_lang_code = session_id.system_lang_code
            self.lang_pack = session_id.lang_pack
            self.report_errors = session_id.report_errors
            self.save_entities = session_id.save_entities
            self.flood_sleep_threshold = session_id.flood_sleep_threshold
        else:  # str / None
            if session_id:
                self.filename = session_id
                if not self.filename.endswith(EXTENSION):
                    self.filename += EXTENSION

            system = platform.uname()
            self.device_model = system.system or 'Unknown'
            self.system_version = system.release or '1.0'
            self.app_version = '1.0'  # '0' will provoke error
            self.lang_code = 'en'
            self.system_lang_code = self.lang_code
            self.lang_pack = ''
            self.report_errors = True
            self.save_entities = True
            self.flood_sleep_threshold = 60

        self.id = struct.unpack('q', os.urandom(8))[0]

        self._sequence = 0
        self.time_offset = 0
        self._last_msg_id = 0  # Long
        self.salt = 0  # Long

        # Cross-thread safety
        self._seq_no_lock = Lock()
        self._msg_id_lock = Lock()
        self._db_lock = RLock()

        # These values will be saved
        self._dc_id = 0
        self._server_address = None
        self._port = None
        self._auth_key = None

        # Migrating from .json -> SQL (may fill in the saved values above)
        entities = self._check_migrate_json()

        self._conn = None
        c = self._cursor()
        c.execute("select name from sqlite_master "
                  "where type='table' and name='version'")
        if c.fetchone():
            # Tables already exist, check for the version
            c.execute("select version from version")
            version = c.fetchone()[0]
            if version != CURRENT_VERSION:
                self._upgrade_database(old=version)
                c.execute("delete from version")
                c.execute("insert into version values (?)", (CURRENT_VERSION,))
                self.save()

            # These values will be saved
            c.execute('select * from sessions')
            tuple_ = c.fetchone()
            if tuple_:
                self._dc_id, self._server_address, self._port, key, = tuple_
                # _update_session_table stores b'' when there is no key;
                # map that back to None (as set_dc does) rather than
                # building an AuthKey out of empty data.
                self._auth_key = AuthKey(data=key) if key else None

            c.close()
        else:
            # Tables don't exist, create new ones
            self._create_table(
                c,
                "version (version integer primary key)"
                ,
                """sessions (
                    dc_id integer primary key,
                    server_address text,
                    port integer,
                    auth_key blob
                )"""
                ,
                """entities (
                    id integer primary key,
                    hash integer not null,
                    username text,
                    phone integer,
                    name text
                )"""
                ,
                """sent_files (
                    md5_digest blob,
                    file_size integer,
                    type integer,
                    id integer,
                    hash integer,
                    primary key(md5_digest, file_size, type)
                )"""
            )
            c.execute("insert into version values (?)", (CURRENT_VERSION,))
            # Migrating from JSON -> new table and may have entities
            if entities:
                c.executemany(
                    'insert or replace into entities values (?,?,?,?,?)',
                    entities
                )
            # Persist the (possibly migrated) connection data even when the
            # JSON file had no entities, so a migrated auth key is not lost.
            self._update_session_table()
            c.close()
            self.save()

    def _check_migrate_json(self):
        """Migrates a legacy JSON session stored under ``self.filename``.

        If the file parses as JSON it is deleted (so SQLite can create the
        database under the same name), and its port, server address and
        auth key are copied into this instance.

        :return: a list of ``(id, hash, username, phone, name)`` rows for
            the JSON entities that have a hash; empty if the file does not
            exist or is not JSON (e.g. it already is an SQLite database).
        """
        if not file_exists(self.filename):
            return []

        try:
            with open(self.filename, encoding='utf-8') as f:
                data = json.load(f)
        except ValueError:
            # Covers both UnicodeDecodeError (binary file) and
            # json.JSONDecodeError (valid text, but not JSON). Either way
            # this is not a JSON session, so leave the file untouched.
            return []

        self.delete()  # Delete JSON file to create database

        self._port = data.get('port', self._port)
        self._server_address = \
            data.get('server_address', self._server_address)

        if data.get('auth_key_data', None) is not None:
            key = b64decode(data['auth_key_data'])
            self._auth_key = AuthKey(data=key)

        return [
            (p_id, p_hash, None, None, None)
            for p_id, p_hash in data.get('entities', [])
            if p_hash is not None
        ]

    def _upgrade_database(self, old):
        """Upgrades the tables from database version ``old``.

        The ``sent_files`` table is (re)created with its current layout.
        """
        c = self._cursor()
        # old == 1 doesn't have the old sent_files so no need to drop
        if old == 2:
            # The old sent_files cache only lasted a day anyway, so drop it
            c.execute('drop table sent_files')
        self._create_table(c, """sent_files (
            md5_digest blob,
            file_size integer,
            type integer,
            id integer,
            hash integer,
            primary key(md5_digest, file_size, type)
        )""")
        c.close()

    @staticmethod
    def _create_table(c, *definitions):
        """
        Creates one table per definition, each given as 'name (columns)'.

        If the sqlite version is >= 3.8.2, it will use "without rowid".
        See http://www.sqlite.org/releaselog/3_8_2.html.
        """
        required = (3, 8, 2)
        sqlite_v = tuple(int(x) for x in sqlite3.sqlite_version.split('.'))
        extra = ' without rowid' if sqlite_v >= required else ''
        for definition in definitions:
            c.execute('create table {}{}'.format(definition, extra))

    # Data from sessions should be kept as properties
    # not to fetch the database every time we need it
    def set_dc(self, dc_id, server_address, port):
        """Sets the data center to connect to and stores it.

        The auth key is then re-read from the sessions table. An empty
        stored key results in ``self.auth_key`` being None.
        """
        self._dc_id = dc_id
        self._server_address = server_address
        self._port = port
        self._update_session_table()

        # Fetch the auth_key corresponding to this data center. Note that
        # _update_session_table keeps a single row, so this reads back the
        # key that was just written.
        c = self._cursor()
        c.execute('select auth_key from sessions')
        tuple_ = c.fetchone()
        if tuple_ and tuple_[0]:
            self._auth_key = AuthKey(data=tuple_[0])
        else:
            self._auth_key = None
        c.close()

    @property
    def server_address(self):
        """Address of the data center set by ``set_dc`` (or loaded)."""
        return self._server_address

    @property
    def port(self):
        """Port of the data center set by ``set_dc`` (or loaded)."""
        return self._port

    @property
    def auth_key(self):
        """The current ``AuthKey``, or None if there is none."""
        return self._auth_key

    @auth_key.setter
    def auth_key(self, value):
        self._auth_key = value
        self._update_session_table()

    def _update_session_table(self):
        """Replaces the single row of the sessions table with current data.

        The change is not committed; call ``save()`` for that.
        """
        with self._db_lock:
            c = self._cursor()
            # While we can save multiple rows into the sessions table
            # currently we only want to keep ONE as the tables don't
            # tell us which auth_key's are usable and will work. Needs
            # some more work before being able to save auth_key's for
            # multiple DCs. Probably done differently.
            c.execute('delete from sessions')
            c.execute('insert or replace into sessions values (?,?,?,?)', (
                self._dc_id,
                self._server_address,
                self._port,
                self._auth_key.key if self._auth_key else b''
            ))
            c.close()

    def save(self):
        """Commits all pending changes to the session database."""
        with self._db_lock:
            self._conn.commit()

    def _cursor(self):
        """Asserts that the connection is open and returns a cursor"""
        with self._db_lock:
            if self._conn is None:
                self._conn = sqlite3.connect(self.filename,
                                             check_same_thread=False)
            return self._conn.cursor()

    def close(self):
        """Closes the connection unless we're working in-memory"""
        # Closing an in-memory database would discard all of its data.
        if self.filename != ':memory:':
            with self._db_lock:
                if self._conn is not None:
                    self._conn.close()
                    self._conn = None

    def delete(self):
        """Deletes the current session file.

        :return: True on success (always for in-memory sessions), False if
            the file could not be removed.
        """
        if self.filename == ':memory:':
            return True
        try:
            os.remove(self.filename)
            return True
        except OSError:
            return False

    @staticmethod
    def list_sessions():
        """Lists the names (without extension) of the session files found
        in the current working directory.
        """
        return [os.path.splitext(os.path.basename(f))[0]
                for f in os.listdir('.') if f.endswith(EXTENSION)]

    def generate_sequence(self, content_related):
        """Thread safe method to generate the next sequence number.

        Content-related messages get an odd number (``2 * seq + 1``) and
        increase the internal counter by one; other messages get the even
        number ``2 * seq`` and leave the counter unchanged.
        """
        with self._seq_no_lock:
            if content_related:
                result = self._sequence * 2 + 1
                self._sequence += 1
                return result
            else:
                return self._sequence * 2

    def get_new_msg_id(self):
        """Generates a new unique message ID based on the current time.

        The upper 32 bits hold the Unix time in seconds (corrected by
        ``time_offset``) and the lower bits the fractional part in
        nanoseconds shifted left by two. IDs are divisible by 4 and
        strictly increasing within this session.
        """
        # Refer to mtproto_plain_sender.py for the original method
        now = time.time()
        nanoseconds = int((now - int(now)) * 1e+9)
        # "message identifiers are divisible by 4"
        new_msg_id = ((int(now) + self.time_offset) << 32) | (nanoseconds << 2)

        with self._msg_id_lock:
            if self._last_msg_id >= new_msg_id:
                new_msg_id = self._last_msg_id + 4

            self._last_msg_id = new_msg_id

        return new_msg_id

    def update_time_offset(self, correct_msg_id):
        """Updates the time offset based on a known correct message ID.

        The last generated message ID is also forgotten. Otherwise, after
        correcting a clock that was ahead, ``get_new_msg_id`` would keep
        returning ``last + 4`` IDs based on the old, wrong time.
        """
        now = int(time.time())
        correct = correct_msg_id >> 32
        self.time_offset = correct - now
        with self._msg_id_lock:
            self._last_msg_id = 0

    # Entity processing

    def process_entities(self, tlo):
        """Saves the input peers found in ``tlo`` into the entities table,
        unless ``self.save_entities`` is False.

        ``tlo`` may be an iterable of entities, or a TLObject whose
        ``chats`` and/or ``users`` attributes are iterables of entities.
        Entities that cannot be turned into an input peer, or users and
        channels without an access hash, are skipped.
        """
        if not self.save_entities:
            return

        if not isinstance(tlo, TLObject) and hasattr(tlo, '__iter__'):
            # This may be a list of users already for instance
            entities = tlo
        else:
            entities = []
            if hasattr(tlo, 'chats') and hasattr(tlo.chats, '__iter__'):
                entities.extend(tlo.chats)
            if hasattr(tlo, 'users') and hasattr(tlo.users, '__iter__'):
                entities.extend(tlo.users)
        if not entities:
            return

        rows = []  # Rows to add (id, hash, username, phone, name)
        for e in entities:
            if not isinstance(e, TLObject):
                continue
            try:
                p = utils.get_input_peer(e, allow_self=False)
                marked_id = utils.get_peer_id(p)
            except ValueError:
                continue

            if isinstance(p, (InputPeerUser, InputPeerChannel)):
                if not p.access_hash:
                    # Some users and channels seem to be returned without
                    # an 'access_hash', meaning Telegram doesn't want you
                    # to access them. This is the reason behind ensuring
                    # that the 'access_hash' is non-zero. See issue #354.
                    # Note that this checks for zero or None, see #392.
                    continue
                else:
                    p_hash = p.access_hash
            elif isinstance(p, InputPeerChat):
                p_hash = 0
            else:
                continue

            username = getattr(e, 'username', None) or None
            if username is not None:
                username = username.lower()
            phone = getattr(e, 'phone', None)
            name = utils.get_display_name(e) or None
            rows.append((marked_id, p_hash, username, phone, name))
        if not rows:
            return

        with self._db_lock:
            c = self._cursor()
            c.executemany(
                'insert or replace into entities values (?,?,?,?,?)', rows
            )
            c.close()
            self.save()

    def get_input_entity(self, key):
        """Parses the given string, integer or TLObject key into a
        marked entity ID, which is then used to fetch the hash
        from the database.

        Input peers, users and channels are returned as-is, and objects
        that ``utils.get_input_peer`` can convert are converted directly.
        Strings are looked up as phone, then username, then exact name.

        Raises ValueError if it cannot be found.
        """
        try:
            if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
                # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
                # We already have an Input version, so nothing else required
                return key
            # Try to early return if this key can be casted as input peer
            return utils.get_input_peer(key)
        except (AttributeError, TypeError):
            # Not a TLObject or can't be cast into InputPeer
            if isinstance(key, TLObject):
                key = utils.get_peer_id(key)

        c = self._cursor()
        if isinstance(key, str):
            phone = utils.parse_phone(key)
            if phone:
                c.execute('select id, hash from entities where phone=?',
                          (phone,))
            else:
                username, _ = utils.parse_username(key)
                if username:
                    c.execute('select id, hash from entities where username=?',
                              (username,))

        if isinstance(key, int):
            c.execute('select id, hash from entities where id=?', (key,))

        # fetchone() returns None if no query was executed above
        result = c.fetchone()
        if not result and isinstance(key, str):
            # Try exact match by name if phone/username failed
            c.execute('select id, hash from entities where name=?', (key,))
            result = c.fetchone()

        c.close()
        if result:
            i, h = result  # unpack resulting tuple
            i, k = utils.resolve_id(i)  # removes the mark and returns kind
            if k == PeerUser:
                return InputPeerUser(i, h)
            elif k == PeerChat:
                return InputPeerChat(i)
            elif k == PeerChannel:
                return InputPeerChannel(i, h)
        else:
            raise ValueError(
                'Could not find input entity with key {!r}'.format(key))

    # File processing

    def get_file(self, md5_digest, file_size, cls):
        """Returns a cached ``cls(id, access_hash)`` for a previously sent
        file, or None if it is not cached.

        :raises ValueError: if ``cls`` is not InputDocument/InputPhoto.
        """
        c = self._cursor()
        tuple_ = c.execute(
            'select id, hash from sent_files '
            'where md5_digest = ? and file_size = ? and type = ?',
            (md5_digest, file_size, _SentFileType.from_type(cls).value)
        ).fetchone()
        c.close()
        if tuple_:
            # Both allowed classes have (id, access_hash) as parameters
            return cls(tuple_[0], tuple_[1])

    def cache_file(self, md5_digest, file_size, instance):
        """Stores an InputDocument/InputPhoto so the same file (identified
        by its MD5 digest and size) can be reused later.

        :raises TypeError: if ``instance`` is of any other type.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError('Cannot cache %s instance' % type(instance))

        with self._db_lock:
            c = self._cursor()
            c.execute(
                'insert or replace into sent_files values (?,?,?,?,?)', (
                    md5_digest, file_size,
                    _SentFileType.from_type(type(instance)).value,
                    instance.id, instance.access_hash
                ))
            c.close()
            self.save()
|