Abstract Session class (merge #657 from tulir/sessions)

This commit is contained in:
Lonami 2018-03-03 11:33:47 +01:00 committed by GitHub
commit 30f7a49263
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 727 additions and 255 deletions

View File

@ -1,3 +1,4 @@
cryptg cryptg
pysocks pysocks
hachoir3 hachoir3
sqlalchemy

View File

@ -25,29 +25,89 @@ file, so that you can quickly access them by username or phone number.
If you're not going to work with updates, or don't need to cache the If you're not going to work with updates, or don't need to cache the
``access_hash`` associated with the entities' ID, you can disable this ``access_hash`` associated with the entities' ID, you can disable this
by setting ``client.session.save_entities = False``, or pass it as a by setting ``client.session.save_entities = False``.
parameter to the ``TelegramClient``.
If you don't want to save the files as a database, you can also create Custom Session Storage
your custom ``Session`` subclass and override the ``.save()`` and ``.load()`` ----------------------
methods. For example, you could save it on a database:
If you don't want to use the default SQLite session storage, you can also use
one of the other implementations or implement your own storage.
To use a custom session storage, simply pass the custom session instance to
``TelegramClient`` instead of the session name.
Currently, there are three implementations of the abstract ``Session`` class:
* MemorySession. Stores session data in Python variables.
* SQLiteSession, the default. Stores sessions in their own SQLite databases.
* AlchemySession. Stores all sessions in a single database via SQLAlchemy.
Using AlchemySession
~~~~~~~~~~~~~~~~~~~~
The AlchemySession implementation can store multiple Sessions in the same
database, but to do this, each session instance needs to have access to the
same models and database session.
To get started, you need to create an ``AlchemySessionContainer`` which will
contain that shared data. The simplest way to use ``AlchemySessionContainer``
is to simply pass it the database URL:
.. code-block:: python .. code-block:: python
class DatabaseSession(Session): container = AlchemySessionContainer('mysql://user:pass@localhost/telethon')
def save():
# serialize relevant data to the database
def load(): If you already have SQLAlchemy set up for your own project, you can also pass
# load relevant data to the database the engine separately:
.. code-block:: python
my_sqlalchemy_engine = sqlalchemy.create_engine('...')
container = AlchemySessionContainer(engine=my_sqlalchemy_engine)
By default, the session container will manage table creation/schema updates/etc
automatically. If you want to manage everything yourself, you can pass your
SQLAlchemy Session and ``declarative_base`` instances and set ``manage_tables``
to ``False``:
.. code-block:: python
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import orm
import sqlalchemy
...
session_factory = orm.sessionmaker(bind=my_sqlalchemy_engine)
session = session_factory()
my_base = declarative_base()
...
container = AlchemySessionContainer(session=session, table_base=my_base, manage_tables=False)
You always need to provide either ``engine`` or ``session`` to the container.
If you set ``manage_tables=False`` and provide a ``session``, ``engine`` is not
needed. In any other case, ``engine`` is always required.
After you have your ``AlchemySessionContainer`` instance created, you can
create new sessions by calling ``new_session``:
.. code-block:: python
session = container.new_session('some session id')
client = TelegramClient(session)
where ``some session id`` is an unique identifier for the session.
Creating your own storage
~~~~~~~~~~~~~~~~~~~~~~~~~
The easiest way to create your own implementation is to use MemorySession as
the base and check out how ``SQLiteSession`` or ``AlchemySession`` work. You
can find the relevant Python files under the ``sessions`` directory.
You should read the ````session.py```` source file to know what "relevant SQLite Sessions and Heroku
data" you need to keep track of. --------------------------
Sessions and Heroku
-------------------
You probably have a newer version of SQLite installed (>= 3.8.2). Heroku uses You probably have a newer version of SQLite installed (>= 3.8.2). Heroku uses
SQLite 3.7.9 which does not support ``WITHOUT ROWID``. So, if you generated SQLite 3.7.9 which does not support ``WITHOUT ROWID``. So, if you generated
@ -59,8 +119,8 @@ session file on your Heroku dyno itself. The most complicated is creating
a custom buildpack to install SQLite >= 3.8.2. a custom buildpack to install SQLite >= 3.8.2.
Generating a Session File on a Heroku Dyno Generating a SQLite Session File on a Heroku Dyno
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. note:: .. note::
Due to Heroku's ephemeral filesystem all dynamically generated Due to Heroku's ephemeral filesystem all dynamically generated

View File

@ -151,7 +151,8 @@ def main():
]), ]),
install_requires=['pyaes', 'rsa'], install_requires=['pyaes', 'rsa'],
extras_require={ extras_require={
'cryptg': ['cryptg'] 'cryptg': ['cryptg'],
'sqlalchemy': ['sqlalchemy']
} }
) )

View File

@ -402,13 +402,13 @@ class MtProtoSender:
elif bad_msg.error_code == 32: elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount # msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID # TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64 self.session.sequence += 64
__log__.info('Attempting to set the right higher sequence') __log__.info('Attempting to set the right higher sequence')
self._resend_request(bad_msg.bad_msg_id) self._resend_request(bad_msg.bad_msg_id)
return True return True
elif bad_msg.error_code == 33: elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case # msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16 self.session.sequence -= 16
__log__.info('Attempting to set the right lower sequence') __log__.info('Attempting to set the right lower sequence')
self._resend_request(bad_msg.bad_msg_id) self._resend_request(bad_msg.bad_msg_id)
return True return True

View File

@ -0,0 +1,4 @@
from .abstract import Session
from .memory import MemorySession
from .sqlite import SQLiteSession
from .sqlalchemy import AlchemySessionContainer, AlchemySession

View File

@ -0,0 +1,147 @@
from abc import ABC, abstractmethod
import time
import struct
import os
class Session(ABC):
def __init__(self):
self.id = struct.unpack('q', os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0
self._time_offset = 0
self._salt = 0
self._report_errors = True
self._flood_sleep_threshold = 60
def clone(self, to_instance=None):
cloned = to_instance or self.__class__()
cloned._report_errors = self.report_errors
cloned._flood_sleep_threshold = self.flood_sleep_threshold
return cloned
@abstractmethod
def set_dc(self, dc_id, server_address, port):
raise NotImplementedError
@property
@abstractmethod
def server_address(self):
raise NotImplementedError
@property
@abstractmethod
def port(self):
raise NotImplementedError
@property
@abstractmethod
def auth_key(self):
raise NotImplementedError
@auth_key.setter
@abstractmethod
def auth_key(self, value):
raise NotImplementedError
@abstractmethod
def close(self):
raise NotImplementedError
@abstractmethod
def save(self):
raise NotImplementedError
@abstractmethod
def delete(self):
raise NotImplementedError
@classmethod
@abstractmethod
def list_sessions(cls):
raise NotImplementedError
@abstractmethod
def process_entities(self, tlo):
raise NotImplementedError
@abstractmethod
def get_input_entity(self, key):
raise NotImplementedError
@abstractmethod
def cache_file(self, md5_digest, file_size, instance):
raise NotImplementedError
@abstractmethod
def get_file(self, md5_digest, file_size, cls):
raise NotImplementedError
@property
def salt(self):
return self._salt
@salt.setter
def salt(self, value):
self._salt = value
@property
def report_errors(self):
return self._report_errors
@report_errors.setter
def report_errors(self, value):
self._report_errors = value
@property
def time_offset(self):
return self._time_offset
@time_offset.setter
def time_offset(self, value):
self._time_offset = value
@property
def flood_sleep_threshold(self):
return self._flood_sleep_threshold
@flood_sleep_threshold.setter
def flood_sleep_threshold(self, value):
self._flood_sleep_threshold = value
@property
def sequence(self):
return self._sequence
@sequence.setter
def sequence(self, value):
self._sequence = value
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""
now = time.time() + self._time_offset
nanoseconds = int((now - int(now)) * 1e+9)
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
now = int(time.time())
correct = correct_msg_id >> 32
self._time_offset = correct - now
self._last_msg_id = 0
def generate_sequence(self, content_related):
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2

204
telethon/sessions/memory.py Normal file
View File

@ -0,0 +1,204 @@
from enum import Enum
from .. import utils
from .abstract import Session
from ..tl import TLObject
from ..tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel,
InputPhoto, InputDocument
)
class _SentFileType(Enum):
DOCUMENT = 0
PHOTO = 1
@staticmethod
def from_type(cls):
if cls == InputDocument:
return _SentFileType.DOCUMENT
elif cls == InputPhoto:
return _SentFileType.PHOTO
else:
raise ValueError('The cls must be either InputDocument/InputPhoto')
class MemorySession(Session):
def __init__(self):
super().__init__()
self._dc_id = None
self._server_address = None
self._port = None
self._auth_key = None
self._files = {}
self._entities = set()
def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id
self._server_address = server_address
self._port = port
@property
def server_address(self):
return self._server_address
@property
def port(self):
return self._port
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, value):
self._auth_key = value
def close(self):
pass
def save(self):
pass
def delete(self):
pass
@classmethod
def list_sessions(cls):
raise NotImplementedError
def _entity_values_to_row(self, id, hash, username, phone, name):
return id, hash, username, phone, name
def _entity_to_row(self, e):
if not isinstance(e, TLObject):
return
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p)
except ValueError:
return
if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
return
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat):
p_hash = 0
else:
return
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
return self._entity_values_to_row(marked_id, p_hash, username, phone, name)
def _entities_to_rows(self, tlo):
if not isinstance(tlo, TLObject) and utils.is_list_like(tlo):
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
entities.extend(tlo.users)
if not entities:
return
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
row = self._entity_to_row(e)
if row:
rows.append(row)
return rows
def process_entities(self, tlo):
self._entities += set(self._entities_to_rows(tlo))
def get_entity_rows_by_phone(self, phone):
rows = [(id, hash) for id, hash, _, found_phone, _
in self._entities if found_phone == phone]
return rows[0] if rows else None
def get_entity_rows_by_username(self, username):
rows = [(id, hash) for id, hash, found_username, _, _
in self._entities if found_username == username]
return rows[0] if rows else None
def get_entity_rows_by_name(self, name):
rows = [(id, hash) for id, hash, _, _, found_name
in self._entities if found_name == name]
return rows[0] if rows else None
def get_entity_rows_by_id(self, id):
rows = [(id, hash) for found_id, hash, _, _, _
in self._entities if found_id == id]
return rows[0] if rows else None
def get_input_entity(self, key):
try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
# We already have an Input version, so nothing else required
return key
# Try to early return if this key can be casted as input peer
return utils.get_input_peer(key)
except (AttributeError, TypeError):
# Not a TLObject or can't be cast into InputPeer
if isinstance(key, TLObject):
key = utils.get_peer_id(key)
result = None
if isinstance(key, str):
phone = utils.parse_phone(key)
if phone:
result = self.get_entity_rows_by_phone(phone)
else:
username, _ = utils.parse_username(key)
if username:
result = self.get_entity_rows_by_username(username)
if isinstance(key, int):
result = self.get_entity_rows_by_id(key)
if not result and isinstance(key, str):
result = self.get_entity_rows_by_name(key)
if result:
i, h = result # unpack resulting tuple
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
else:
raise ValueError('Could not find input entity with key ', key)
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
key = (md5_digest, file_size, _SentFileType.from_type(instance))
value = (instance.id, instance.access_hash)
self._files[key] = value
def get_file(self, md5_digest, file_size, cls):
key = (md5_digest, file_size, _SentFileType.from_type(cls))
try:
return self._files[key]
except KeyError:
return None

View File

@ -0,0 +1,225 @@
try:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Integer, BLOB, orm
import sqlalchemy as sql
except ImportError:
sql = None
pass
from ..crypto import AuthKey
from ..tl.types import InputPhoto, InputDocument
from .memory import MemorySession, _SentFileType
LATEST_VERSION = 1
class AlchemySessionContainer:
def __init__(self, engine=None, session=None, table_prefix='',
table_base=None, manage_tables=True):
if not sql:
raise ImportError('SQLAlchemy not imported')
if isinstance(engine, str):
engine = sql.create_engine(engine)
self.db_engine = engine
if not session:
db_factory = orm.sessionmaker(bind=self.db_engine)
self.db = orm.scoping.scoped_session(db_factory)
else:
self.db = session
table_base = table_base or declarative_base()
(self.Version, self.Session, self.Entity,
self.SentFile) = self.create_table_classes(self.db, table_prefix,
table_base)
if manage_tables:
table_base.metadata.bind = self.db_engine
if not self.db_engine.dialect.has_table(self.db_engine,
self.Version.__tablename__):
table_base.metadata.create_all()
self.db.add(self.Version(version=LATEST_VERSION))
self.db.commit()
else:
self.check_and_upgrade_database()
@staticmethod
def create_table_classes(db, prefix, Base):
class Version(Base):
query = db.query_property()
__tablename__ = '{prefix}version'.format(prefix=prefix)
version = Column(Integer, primary_key=True)
class Session(Base):
query = db.query_property()
__tablename__ = '{prefix}sessions'.format(prefix=prefix)
session_id = Column(String, primary_key=True)
dc_id = Column(Integer, primary_key=True)
server_address = Column(String)
port = Column(Integer)
auth_key = Column(BLOB)
class Entity(Base):
query = db.query_property()
__tablename__ = '{prefix}entities'.format(prefix=prefix)
session_id = Column(String, primary_key=True)
id = Column(Integer, primary_key=True)
hash = Column(Integer, nullable=False)
username = Column(String)
phone = Column(Integer)
name = Column(String)
class SentFile(Base):
query = db.query_property()
__tablename__ = '{prefix}sent_files'.format(prefix=prefix)
session_id = Column(String, primary_key=True)
md5_digest = Column(BLOB, primary_key=True)
file_size = Column(Integer, primary_key=True)
type = Column(Integer, primary_key=True)
id = Column(Integer)
hash = Column(Integer)
return Version, Session, Entity, SentFile
def check_and_upgrade_database(self):
row = self.Version.query.all()
version = row[0].version if row else 1
if version == LATEST_VERSION:
return
self.Version.query.delete()
# Implement table schema updates here and increase version
self.db.add(self.Version(version=version))
self.db.commit()
def new_session(self, session_id):
return AlchemySession(self, session_id)
def list_sessions(self):
return
def save(self):
self.db.commit()
class AlchemySession(MemorySession):
def __init__(self, container, session_id):
super().__init__()
self.container = container
self.db = container.db
self.Version, self.Session, self.Entity, self.SentFile = (
container.Version, container.Session, container.Entity,
container.SentFile)
self.session_id = session_id
self._load_session()
def _load_session(self):
sessions = self._db_query(self.Session).all()
session = sessions[0] if sessions else None
if session:
self._dc_id = session.dc_id
self._server_address = session.server_address
self._port = session.port
self._auth_key = AuthKey(data=session.auth_key)
def clone(self, to_instance=None):
cloned = to_instance or self.__class__(self.container, self.session_id)
return super().clone(cloned)
def set_dc(self, dc_id, server_address, port):
super().set_dc(dc_id, server_address, port)
self._update_session_table()
sessions = self._db_query(self.Session).all()
session = sessions[0] if sessions else None
if session and session.auth_key:
self._auth_key = AuthKey(data=session.auth_key)
else:
self._auth_key = None
@MemorySession.auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
def _update_session_table(self):
self.Session.query.filter(
self.Session.session_id == self.session_id).delete()
new = self.Session(session_id=self.session_id, dc_id=self._dc_id,
server_address=self._server_address,
port=self._port,
auth_key=(self._auth_key.key
if self._auth_key else b''))
self.db.merge(new)
def _db_query(self, dbclass, *args):
return dbclass.query.filter(dbclass.session_id == self.session_id,
*args)
def save(self):
self.container.save()
def close(self):
# Nothing to do here, connection is managed by AlchemySessionContainer.
pass
def delete(self):
self._db_query(self.Session).delete()
self._db_query(self.Entity).delete()
self._db_query(self.SentFile).delete()
def _entity_values_to_row(self, id, hash, username, phone, name):
return self.Entity(session_id=self.session_id, id=id, hash=hash,
username=username, phone=phone, name=name)
def process_entities(self, tlo):
rows = self._entities_to_rows(tlo)
if not rows:
return
for row in rows:
self.db.merge(row)
self.save()
def get_entity_rows_by_phone(self, key):
row = self._db_query(self.Entity,
self.Entity.phone == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_username(self, key):
row = self._db_query(self.Entity,
self.Entity.username == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_name(self, key):
row = self._db_query(self.Entity,
self.Entity.name == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_id(self, key):
row = self._db_query(self.Entity, self.Entity.id == key).one_or_none()
return row.id, row.hash if row else None
def get_file(self, md5_digest, file_size, cls):
row = self._db_query(self.SentFile,
self.SentFile.md5_digest == md5_digest,
self.SentFile.file_size == file_size,
self.SentFile.type == _SentFileType.from_type(
cls).value).one_or_none()
return row.id, row.hash if row else None
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
self.db.merge(
self.SentFile(session_id=self.session_id, md5_digest=md5_digest,
type=_SentFileType.from_type(type(instance)).value,
id=instance.id, hash=instance.access_hash))
self.save()

View File

@ -1,20 +1,13 @@
import json import json
import os import os
import platform
import sqlite3 import sqlite3
import struct
import time
from base64 import b64decode from base64 import b64decode
from enum import Enum
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock, RLock from threading import Lock, RLock
from . import utils from .memory import MemorySession, _SentFileType
from .crypto import AuthKey from ..crypto import AuthKey
from .tl import TLObject from ..tl.types import (
from .tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel,
InputPhoto, InputDocument InputPhoto, InputDocument
) )
@ -22,21 +15,7 @@ EXTENSION = '.session'
CURRENT_VERSION = 3 # database version CURRENT_VERSION = 3 # database version
class _SentFileType(Enum): class SQLiteSession(MemorySession):
DOCUMENT = 0
PHOTO = 1
@staticmethod
def from_type(cls):
if cls == InputDocument:
return _SentFileType.DOCUMENT
elif cls == InputPhoto:
return _SentFileType.PHOTO
else:
raise ValueError('The cls must be either InputDocument/InputPhoto')
class Session:
"""This session contains the required information to login into your """This session contains the required information to login into your
Telegram account. NEVER give the saved JSON file to anyone, since Telegram account. NEVER give the saved JSON file to anyone, since
they would gain instant access to all your messages and contacts. they would gain instant access to all your messages and contacts.
@ -44,59 +23,27 @@ class Session:
If you think the session has been compromised, close all the sessions If you think the session has been compromised, close all the sessions
through an official Telegram client to revoke the authorization. through an official Telegram client to revoke the authorization.
""" """
def __init__(self, session_id):
def __init__(self, session_id=None):
super().__init__()
"""session_user_id should either be a string or another Session. """session_user_id should either be a string or another Session.
Note that if another session is given, only parameters like Note that if another session is given, only parameters like
those required to init a connection will be copied. those required to init a connection will be copied.
""" """
# These values will NOT be saved # These values will NOT be saved
self.filename = ':memory:' self.filename = ':memory:'
self.save_entities = True
# For connection purposes if session_id:
if isinstance(session_id, Session): self.filename = session_id
self.device_model = session_id.device_model if not self.filename.endswith(EXTENSION):
self.system_version = session_id.system_version self.filename += EXTENSION
self.app_version = session_id.app_version
self.lang_code = session_id.lang_code
self.system_lang_code = session_id.system_lang_code
self.lang_pack = session_id.lang_pack
self.report_errors = session_id.report_errors
self.save_entities = session_id.save_entities
self.flood_sleep_threshold = session_id.flood_sleep_threshold
else: # str / None
if session_id:
self.filename = session_id
if not self.filename.endswith(EXTENSION):
self.filename += EXTENSION
system = platform.uname()
self.device_model = system.system or 'Unknown'
self.system_version = system.release or '1.0'
self.app_version = '1.0' # '0' will provoke error
self.lang_code = 'en'
self.system_lang_code = self.lang_code
self.lang_pack = ''
self.report_errors = True
self.save_entities = True
self.flood_sleep_threshold = 60
self.id = struct.unpack('q', os.urandom(8))[0]
self._sequence = 0
self.time_offset = 0
self._last_msg_id = 0 # Long
self.salt = 0 # Long
# Cross-thread safety # Cross-thread safety
self._seq_no_lock = Lock() self._seq_no_lock = Lock()
self._msg_id_lock = Lock() self._msg_id_lock = Lock()
self._db_lock = RLock() self._db_lock = RLock()
# These values will be saved
self._dc_id = 0
self._server_address = None
self._port = None
self._auth_key = None
# Migrating from .json -> SQL # Migrating from .json -> SQL
entities = self._check_migrate_json() entities = self._check_migrate_json()
@ -163,6 +110,11 @@ class Session:
c.close() c.close()
self.save() self.save()
def clone(self, to_instance=None):
cloned = super().clone(to_instance)
cloned.save_entities = self.save_entities
return cloned
def _check_migrate_json(self): def _check_migrate_json(self):
if file_exists(self.filename): if file_exists(self.filename):
try: try:
@ -218,9 +170,7 @@ class Session:
# Data from sessions should be kept as properties # Data from sessions should be kept as properties
# not to fetch the database every time we need it # not to fetch the database every time we need it
def set_dc(self, dc_id, server_address, port): def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id super().set_dc(dc_id, server_address, port)
self._server_address = server_address
self._port = port
self._update_session_table() self._update_session_table()
# Fetch the auth_key corresponding to this data center # Fetch the auth_key corresponding to this data center
@ -233,19 +183,7 @@ class Session:
self._auth_key = None self._auth_key = None
c.close() c.close()
@property @MemorySession.auth_key.setter
def server_address(self):
return self._server_address
@property
def port(self):
return self._port
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, value): def auth_key(self, value):
self._auth_key = value self._auth_key = value
self._update_session_table() self._update_session_table()
@ -298,53 +236,14 @@ class Session:
except OSError: except OSError:
return False return False
@staticmethod @classmethod
def list_sessions(): def list_sessions(cls):
"""Lists all the sessions of the users who have ever connected """Lists all the sessions of the users who have ever connected
using this client and never logged out using this client and never logged out
""" """
return [os.path.splitext(os.path.basename(f))[0] return [os.path.splitext(os.path.basename(f))[0]
for f in os.listdir('.') if f.endswith(EXTENSION)] for f in os.listdir('.') if f.endswith(EXTENSION)]
def generate_sequence(self, content_related):
"""Thread safe method to generates the next sequence number,
based on whether it was confirmed yet or not.
Note that if confirmed=True, the sequence number
will be increased by one too
"""
with self._seq_no_lock:
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""
# Refer to mtproto_plain_sender.py for the original method
now = time.time() + self.time_offset
nanoseconds = int((now - int(now)) * 1e+9)
# "message identifiers are divisible by 4"
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
with self._msg_id_lock:
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
"""Updates the time offset based on a known correct message ID"""
now = int(time.time())
correct = correct_msg_id >> 32
self.time_offset = correct - now
self._last_msg_id = 0
# Entity processing # Entity processing
def process_entities(self, tlo): def process_entities(self, tlo):
@ -356,49 +255,7 @@ class Session:
if not self.save_entities: if not self.save_entities:
return return
if not isinstance(tlo, TLObject) and utils.is_list_like(tlo): rows = self._entities_to_rows(tlo)
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
entities.extend(tlo.users)
if not entities:
return
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p)
except ValueError:
continue
if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
continue
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat):
p_hash = 0
else:
continue
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
rows.append((marked_id, p_hash, username, phone, name))
if not rows: if not rows:
return return
@ -408,62 +265,29 @@ class Session:
) )
self.save() self.save()
def get_input_entity(self, key): def _fetchone_entity(self, query, args):
"""Parses the given string, integer or TLObject key into a
marked entity ID, which is then used to fetch the hash
from the database.
If a callable key is given, every row will be fetched,
and passed as a tuple to a function, that should return
a true-like value when the desired row is found.
Raises ValueError if it cannot be found.
"""
try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
# We already have an Input version, so nothing else required
return key
# Try to early return if this key can be casted as input peer
return utils.get_input_peer(key)
except (AttributeError, TypeError):
# Not a TLObject or can't be cast into InputPeer
if isinstance(key, TLObject):
key = utils.get_peer_id(key)
c = self._cursor() c = self._cursor()
if isinstance(key, str): c.execute(query, args)
phone = utils.parse_phone(key) return c.fetchone()
if phone:
c.execute('select id, hash from entities where phone=?',
(phone,))
else:
username, _ = utils.parse_username(key)
if username:
c.execute('select id, hash from entities where username=?',
(username,))
if isinstance(key, int): def get_entity_rows_by_phone(self, phone):
c.execute('select id, hash from entities where id=?', (key,)) return self._fetchone_entity(
'select id, hash from entities where phone=?', (phone,))
result = c.fetchone() def get_entity_rows_by_username(self, username):
if not result and isinstance(key, str): return self._fetchone_entity(
# Try exact match by name if phone/username failed 'select id, hash from entities where username=?',
c.execute('select id, hash from entities where name=?', (key,)) (username,))
result = c.fetchone()
c.close() def get_entity_rows_by_name(self, name):
if result: return self._fetchone_entity(
i, h = result # unpack resulting tuple 'select id, hash from entities where name=?',
i, k = utils.resolve_id(i) # removes the mark and returns kind (name,))
if k == PeerUser:
return InputPeerUser(i, h) def get_entity_rows_by_id(self, id):
elif k == PeerChat: return self._fetchone_entity(
return InputPeerChat(i) 'select id, hash from entities where id=?',
elif k == PeerChannel: (id,))
return InputPeerChannel(i, h)
else:
raise ValueError('Could not find input entity with key ', key)
# File processing # File processing

View File

@ -1,11 +1,11 @@
import logging import logging
import os import os
import platform
import threading import threading
from datetime import timedelta, datetime from datetime import timedelta, datetime
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Lock from threading import Lock
from time import sleep from time import sleep
from . import version, utils from . import version, utils
from .crypto import rsa from .crypto import rsa
from .errors import ( from .errors import (
@ -14,7 +14,7 @@ from .errors import (
PhoneMigrateError, NetworkMigrateError, UserMigrateError PhoneMigrateError, NetworkMigrateError, UserMigrateError
) )
from .network import authenticator, MtProtoSender, Connection, ConnectionMode from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .session import Session from .sessions import Session, SQLiteSession
from .tl import TLObject from .tl import TLObject
from .tl.all_tlobjects import LAYER from .tl.all_tlobjects import LAYER
from .tl.functions import ( from .tl.functions import (
@ -73,7 +73,12 @@ class TelegramBareClient:
update_workers=None, update_workers=None,
spawn_read_thread=False, spawn_read_thread=False,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
**kwargs): loop=None,
device_model=None,
system_version=None,
app_version=None,
lang_code='en',
system_lang_code='en'):
"""Refer to TelegramClient.__init__ for docs on this method""" """Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash: if not api_id or not api_hash:
raise ValueError( raise ValueError(
@ -81,10 +86,10 @@ class TelegramBareClient:
"Refer to telethon.rtfd.io for more information.") "Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6 self._use_ipv6 = use_ipv6
# Determine what session object we have # Determine what session object we have
if isinstance(session, str) or session is None: if isinstance(session, str) or session is None:
session = Session(session) session = SQLiteSession(session)
elif not isinstance(session, Session): elif not isinstance(session, Session):
raise TypeError( raise TypeError(
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'
@ -125,11 +130,12 @@ class TelegramBareClient:
self.updates = UpdateState(workers=update_workers) self.updates = UpdateState(workers=update_workers)
# Used on connection - the user may modify these and reconnect # Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__) system = platform.uname()
for name, value in kwargs.items(): self.device_model = device_model or system.system or 'Unknown'
if not hasattr(self.session, name): self.system_version = system_version or system.release or '1.0'
raise ValueError('Unknown named parameter', name) self.app_version = app_version or self.__version__
setattr(self.session, name, value) self.lang_code = lang_code
self.system_lang_code = system_lang_code
# Despite the state of the real connection, keep track of whether # Despite the state of the real connection, keep track of whether
# the user has explicitly called .connect() or .disconnect() here. # the user has explicitly called .connect() or .disconnect() here.
@ -233,11 +239,11 @@ class TelegramBareClient:
"""Wraps query around InvokeWithLayerRequest(InitConnectionRequest())""" """Wraps query around InvokeWithLayerRequest(InitConnectionRequest())"""
return InvokeWithLayerRequest(LAYER, InitConnectionRequest( return InvokeWithLayerRequest(LAYER, InitConnectionRequest(
api_id=self.api_id, api_id=self.api_id,
device_model=self.session.device_model, device_model=self.device_model,
system_version=self.session.system_version, system_version=self.system_version,
app_version=self.session.app_version, app_version=self.app_version,
lang_code=self.session.lang_code, lang_code=self.lang_code,
system_lang_code=self.session.system_lang_code, system_lang_code=self.system_lang_code,
lang_pack='', # "langPacks are for official apps only" lang_pack='', # "langPacks are for official apps only"
query=query query=query
)) ))
@ -361,7 +367,7 @@ class TelegramBareClient:
# #
# Construct this session with the connection parameters # Construct this session with the connection parameters
# (system version, device model...) from the current one. # (system version, device model...) from the current one.
session = Session(self.session) session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port) session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[dc_id] = session self._exported_sessions[dc_id] = session
@ -387,7 +393,7 @@ class TelegramBareClient:
session = self._exported_sessions.get(cdn_redirect.dc_id) session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session: if not session:
dc = self._get_dc(cdn_redirect.dc_id, cdn=True) dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
session = Session(self.session) session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port) session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session self._exported_sessions[cdn_redirect.dc_id] = session