mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-29 12:53:44 +03:00
Completely overhaul sessions
This commit is contained in:
parent
0b54fa7a25
commit
684f640b60
|
@ -73,6 +73,72 @@ removed. This implies:
|
|||
// TODO provide standalone alternative for this?
|
||||
|
||||
|
||||
Complete overhaul of session files
|
||||
----------------------------------
|
||||
|
||||
If you were using third-party libraries to deal with sessions, you will need to wait for those to
|
||||
be updated. The library will automatically upgrade the SQLite session files to the new version,
|
||||
and the ``StringSession`` remains backward-compatible. The sessions can now be async.
|
||||
|
||||
In case you were relying on the tables used by SQLite (even though these should have been, and
|
||||
will still need to be, treated as an implementation detail), here are the changes:
|
||||
|
||||
* The ``sessions`` table is now correctly split into ``datacenter`` and ``session``.
|
||||
``datacenter`` contains information about a Telegram datacenter, along with its corresponding
|
||||
authorization key, and ``session`` contains information about the update state and user.
|
||||
* The ``entities`` table is now called ``entity`` and stores the ``type`` separatedly.
|
||||
* The ``update_state`` table is now split into ``session`` and ``channel``, which can contain
|
||||
a per-channel ``pts``.
|
||||
|
||||
Because **the new version does not cache usernames, phone numbers and display names**, using these
|
||||
in method calls is now quite expensive. You *should* migrate your code to do the Right Thing and
|
||||
start using identifiers rather than usernames, phone numbers or invite links. This is both simpler
|
||||
and more reliable, because while a user identifier won't change, their username could.
|
||||
|
||||
You can use the following snippet to make a JSON backup (alternatively, you could just copy the
|
||||
``.session`` file and keep it around) in case you want to preserve the cached usernames:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import sqlite, json
|
||||
with sqlite3.connect('your.session') as conn, open('entities.json', 'w', encoding='utf-8') as fp:
|
||||
json.dump([
|
||||
{'id': id, 'hash': hash, 'username': username, 'phone': phone, 'name': name, 'date': date}
|
||||
for (id, hash, username, phone, name, date)
|
||||
in conn.execute('select id, hash, username, phone, name, date from entities')
|
||||
], fp)
|
||||
|
||||
The following public methods or properties have also been removed from ``SQLiteSession`` because
|
||||
they no longer make sense:
|
||||
|
||||
* ``list_sessions``. You can ``glob.glob('*.session')`` instead.
|
||||
* ``clone``.
|
||||
|
||||
And the following, which were inherited from ``MemorySession``:
|
||||
|
||||
* ``delete``. You can ``os.remove`` the file instead (preferably after ``client.log_out()``).
|
||||
* ``set_dc``.
|
||||
* ``dc_id``.
|
||||
* ``server_address``.
|
||||
* ``port``.
|
||||
* ``auth_key``.
|
||||
* ``takeout_id``.
|
||||
* ``get_update_state``.
|
||||
* ``set_update_state``.
|
||||
* ``process_entities``.
|
||||
* ``get_entity_rows_by_phone``.
|
||||
* ``get_entity_rows_by_username``.
|
||||
* ``get_entity_rows_by_name``.
|
||||
* ``get_entity_rows_by_id``.
|
||||
* ``get_input_entity``.
|
||||
* ``cache_file``.
|
||||
* ``get_file``.
|
||||
|
||||
You also can no longer set ``client.session.save_entities = False``. The entities must be saved
|
||||
for the library to work properly. If you still don't want it, you should subclass the session and
|
||||
override the methods to do nothing.
|
||||
|
||||
|
||||
The "iter" variant of the client methods have been removed
|
||||
----------------------------------------------------------
|
||||
|
||||
|
|
|
@ -1,167 +1,90 @@
|
|||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class Session(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def clone(self, to_instance=None):
|
||||
"""
|
||||
Creates a clone of this session file.
|
||||
"""
|
||||
return to_instance or self.__class__()
|
||||
|
||||
@abstractmethod
|
||||
def set_dc(self, dc_id, server_address, port):
|
||||
async def insert_dc(self, dc: DataCenter):
|
||||
"""
|
||||
Sets the information of the data center address and port that
|
||||
the library should connect to, as well as the data center ID,
|
||||
which is currently unused.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dc_id(self):
|
||||
"""
|
||||
Returns the currently-used data center ID.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def server_address(self):
|
||||
"""
|
||||
Returns the server address where the library should connect to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def port(self):
|
||||
"""
|
||||
Returns the port to which the library should connect to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def auth_key(self):
|
||||
"""
|
||||
Returns an ``AuthKey`` instance associated with the saved
|
||||
data center, or `None` if a new one should be generated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@auth_key.setter
|
||||
@abstractmethod
|
||||
def auth_key(self, value):
|
||||
"""
|
||||
Sets the ``AuthKey`` to be used for the saved data center.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def takeout_id(self):
|
||||
"""
|
||||
Returns an ID of the takeout process initialized for this session,
|
||||
or `None` if there's no were any unfinished takeout requests.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@takeout_id.setter
|
||||
@abstractmethod
|
||||
def takeout_id(self, value):
|
||||
"""
|
||||
Sets the ID of the unfinished takeout process for this session.
|
||||
Store a new or update an existing `DataCenter` with matching ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_update_state(self, entity_id):
|
||||
async def get_all_dc(self) -> List[DataCenter]:
|
||||
"""
|
||||
Returns the ``UpdateState`` associated with the given `entity_id`.
|
||||
If the `entity_id` is 0, it should return the ``UpdateState`` for
|
||||
no specific channel (the "general" state). If no state is known
|
||||
it should ``return None``.
|
||||
Get a list of all currently-stored `DataCenter`. Should not contain duplicate ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set_update_state(self, entity_id, state):
|
||||
async def set_state(self, state: SessionState):
|
||||
"""
|
||||
Sets the given ``UpdateState`` for the specified `entity_id`, which
|
||||
should be 0 if the ``UpdateState`` is the "general" state (and not
|
||||
for any specific channel).
|
||||
Set the state about the current session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
async def get_state(self) -> Optional[SessionState]:
|
||||
"""
|
||||
Called on client disconnection. Should be used to
|
||||
free any used resources. Can be left empty if none.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
"""
|
||||
Called whenever important properties change. It should
|
||||
make persist the relevant session information to disk.
|
||||
Get the state about the current session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self):
|
||||
async def insert_channel_state(self, state: ChannelState):
|
||||
"""
|
||||
Called upon client.log_out(). Should delete the stored
|
||||
information from disk since it's not valid anymore.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def list_sessions(cls):
|
||||
"""
|
||||
Lists available sessions. Not used by the library itself.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def process_entities(self, tlo):
|
||||
"""
|
||||
Processes the input ``TLObject`` or ``list`` and saves
|
||||
whatever information is relevant (e.g., ID or access hash).
|
||||
Store a new or update an existing `ChannelState` with matching ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_input_entity(self, key):
|
||||
async def get_all_channel_states(self) -> List[ChannelState]:
|
||||
"""
|
||||
Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``).
|
||||
The library uses this method whenever an ``InputPeer`` is needed
|
||||
to suit several purposes (e.g. user only provided its ID or wishes
|
||||
to use a cached username to avoid extra RPC).
|
||||
Get a list of all currently-stored `ChannelState`. Should not contain duplicate ``id``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def cache_file(self, md5_digest, file_size, instance):
|
||||
async def insert_entities(self, entities: List[Entity]):
|
||||
"""
|
||||
Caches the given file information persistently, so that it
|
||||
doesn't need to be re-uploaded in case the file is used again.
|
||||
Store new or update existing `Entity` with matching ``id``.
|
||||
|
||||
The ``instance`` will be either an ``InputPhoto`` or ``InputDocument``,
|
||||
both with an ``.id`` and ``.access_hash`` attributes.
|
||||
Entities should be saved on a best-effort. It is okay to not save them, although the
|
||||
library may need to do extra work if a previously-saved entity is missing, or even be
|
||||
unable to continue without the entity.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_file(self, md5_digest, file_size, cls):
|
||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
||||
"""
|
||||
Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size``
|
||||
match an existing saved record. The class will either be an
|
||||
``InputPhoto`` or ``InputDocument``, both with two parameters
|
||||
``id`` and ``access_hash`` in that order.
|
||||
Get the `Entity` with matching ``ty`` and ``id``.
|
||||
|
||||
The following groups of ``ty`` should be treated to be equivalent, that is, for a given
|
||||
``ty`` and ``id``, if the ``ty`` is in a given group, a matching ``access_hash`` with
|
||||
that ``id`` from within any ``ty`` in that group should be returned.
|
||||
|
||||
* ``'U'`` and ``'B'`` (user and bot).
|
||||
* ``'G'`` (small group chat).
|
||||
* ``'C'``, ``'M'`` and ``'E'`` (broadcast channel, megagroup channel, and gigagroup channel).
|
||||
|
||||
For example, if a ``ty`` representing a bot is stored but the asking ``ty`` is a user,
|
||||
the corresponding ``access_hash`` should still be returned.
|
||||
|
||||
You may use `types.canonical_entity_type` to find out the canonical type.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def save(self):
|
||||
"""
|
||||
Save the session.
|
||||
|
||||
May do nothing if the other methods already saved when they were called.
|
||||
|
||||
May return custom data when manual saving is intended.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,230 +1,47 @@
|
|||
from enum import Enum
|
||||
|
||||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
from .abstract import Session
|
||||
from .._misc import utils, tlobject
|
||||
from .. import _tl
|
||||
|
||||
|
||||
class _SentFileType(Enum):
|
||||
DOCUMENT = 0
|
||||
PHOTO = 1
|
||||
|
||||
@staticmethod
|
||||
def from_type(cls):
|
||||
if cls == _tl.InputDocument:
|
||||
return _SentFileType.DOCUMENT
|
||||
elif cls == _tl.InputPhoto:
|
||||
return _SentFileType.PHOTO
|
||||
else:
|
||||
raise ValueError('The cls must be either InputDocument/InputPhoto')
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class MemorySession(Session):
|
||||
__slots__ = ('dcs', 'state', 'channel_states', 'entities')
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dcs = {}
|
||||
self.state = None
|
||||
self.channel_states = {}
|
||||
self.entities = {}
|
||||
|
||||
self._dc_id = 0
|
||||
self._server_address = None
|
||||
self._port = None
|
||||
self._auth_key = None
|
||||
self._takeout_id = None
|
||||
async def insert_dc(self, dc: DataCenter):
|
||||
self.dcs[dc.id] = dc
|
||||
|
||||
self._files = {}
|
||||
self._entities = set()
|
||||
self._update_states = {}
|
||||
async def get_all_dc(self) -> List[DataCenter]:
|
||||
return list(self.dcs.values())
|
||||
|
||||
def set_dc(self, dc_id, server_address, port):
|
||||
self._dc_id = dc_id or 0
|
||||
self._server_address = server_address
|
||||
self._port = port
|
||||
async def set_state(self, state: SessionState):
|
||||
self.state = state
|
||||
|
||||
@property
|
||||
def dc_id(self):
|
||||
return self._dc_id
|
||||
async def get_state(self) -> Optional[SessionState]:
|
||||
return self.state
|
||||
|
||||
@property
|
||||
def server_address(self):
|
||||
return self._server_address
|
||||
async def insert_channel_state(self, state: ChannelState):
|
||||
self.channel_states[state.channel_id] = state
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self._port
|
||||
async def get_all_channel_states(self) -> List[ChannelState]:
|
||||
return list(self.channel_states.values())
|
||||
|
||||
@property
|
||||
def auth_key(self):
|
||||
return self._auth_key
|
||||
async def insert_entities(self, entities: List[Entity]):
|
||||
self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
|
||||
|
||||
@auth_key.setter
|
||||
def auth_key(self, value):
|
||||
self._auth_key = value
|
||||
|
||||
@property
|
||||
def takeout_id(self):
|
||||
return self._takeout_id
|
||||
|
||||
@takeout_id.setter
|
||||
def takeout_id(self, value):
|
||||
self._takeout_id = value
|
||||
|
||||
def get_update_state(self, entity_id):
|
||||
return self._update_states.get(entity_id, None)
|
||||
|
||||
def set_update_state(self, entity_id, state):
|
||||
self._update_states[entity_id] = state
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def delete(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _entity_values_to_row(id, hash, username, phone, name):
|
||||
# While this is a simple implementation it might be overrode by,
|
||||
# other classes so they don't need to implement the plural form
|
||||
# of the method. Don't remove.
|
||||
return id, hash, username, phone, name
|
||||
|
||||
def _entity_to_row(self, e):
|
||||
if not isinstance(e, tlobject.TLObject):
|
||||
return
|
||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
||||
try:
|
||||
p = utils.get_input_peer(e, allow_self=False)
|
||||
marked_id = utils.get_peer_id(p)
|
||||
except TypeError:
|
||||
# Note: `get_input_peer` already checks for non-zero `access_hash`.
|
||||
# See issues #354 and #392. It also checks that the entity
|
||||
# is not `min`, because its `access_hash` cannot be used
|
||||
# anywhere (since layer 102, there are two access hashes).
|
||||
return
|
||||
|
||||
if isinstance(p, (_tl.InputPeerUser, _tl.InputPeerChannel)):
|
||||
p_hash = p.access_hash
|
||||
elif isinstance(p, _tl.InputPeerChat):
|
||||
p_hash = 0
|
||||
else:
|
||||
return
|
||||
|
||||
username = getattr(e, 'username', None) or None
|
||||
if username is not None:
|
||||
username = username.lower()
|
||||
phone = getattr(e, 'phone', None)
|
||||
name = utils.get_display_name(e) or None
|
||||
return self._entity_values_to_row(
|
||||
marked_id, p_hash, username, phone, name
|
||||
)
|
||||
|
||||
def _entities_to_rows(self, tlo):
|
||||
if not isinstance(tlo, tlobject.TLObject) and utils.is_list_like(tlo):
|
||||
# This may be a list of users already for instance
|
||||
entities = tlo
|
||||
else:
|
||||
entities = []
|
||||
if hasattr(tlo, 'user'):
|
||||
entities.append(tlo.user)
|
||||
if hasattr(tlo, 'chat'):
|
||||
entities.append(tlo.chat)
|
||||
if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
|
||||
entities.extend(tlo.chats)
|
||||
if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
|
||||
entities.extend(tlo.users)
|
||||
|
||||
rows = [] # Rows to add (id, hash, username, phone, name)
|
||||
for e in entities:
|
||||
row = self._entity_to_row(e)
|
||||
if row:
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
def process_entities(self, tlo):
|
||||
self._entities |= set(self._entities_to_rows(tlo))
|
||||
|
||||
def get_entity_rows_by_phone(self, phone):
|
||||
try:
|
||||
return next((id, hash) for id, hash, _, found_phone, _
|
||||
in self._entities if found_phone == phone)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
def get_entity_rows_by_username(self, username):
|
||||
try:
|
||||
return next((id, hash) for id, hash, found_username, _, _
|
||||
in self._entities if found_username == username)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
def get_entity_rows_by_name(self, name):
|
||||
try:
|
||||
return next((id, hash) for id, hash, _, _, found_name
|
||||
in self._entities if found_name == name)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
def get_entity_rows_by_id(self, id, exact=True):
|
||||
try:
|
||||
return next((id, hash) for found_id, hash, _, _, _
|
||||
in self._entities if found_id == id)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
def get_input_entity(self, key):
|
||||
try:
|
||||
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
|
||||
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
|
||||
# We already have an Input version, so nothing else required
|
||||
return key
|
||||
# Try to early return if this key can be casted as input peer
|
||||
return utils.get_input_peer(key)
|
||||
except (AttributeError, TypeError):
|
||||
# Not a TLObject or can't be cast into InputPeer
|
||||
if isinstance(key, tlobject.TLObject):
|
||||
key = utils.get_peer_id(key)
|
||||
exact = True
|
||||
else:
|
||||
exact = not isinstance(key, int) or key < 0
|
||||
|
||||
result = None
|
||||
if isinstance(key, str):
|
||||
phone = utils.parse_phone(key)
|
||||
if phone:
|
||||
result = self.get_entity_rows_by_phone(phone)
|
||||
else:
|
||||
username, invite = utils.parse_username(key)
|
||||
if username and not invite:
|
||||
result = self.get_entity_rows_by_username(username)
|
||||
|
||||
elif isinstance(key, int):
|
||||
result = self.get_entity_rows_by_id(key, exact)
|
||||
|
||||
if not result and isinstance(key, str):
|
||||
result = self.get_entity_rows_by_name(key)
|
||||
|
||||
if result:
|
||||
entity_id, entity_hash = result # unpack resulting tuple
|
||||
entity_id, kind = utils.resolve_id(entity_id)
|
||||
# removes the mark and returns type of entity
|
||||
if kind == _tl.PeerUser:
|
||||
return _tl.InputPeerUser(entity_id, entity_hash)
|
||||
elif kind == _tl.PeerChat:
|
||||
return _tl.InputPeerChat(entity_id)
|
||||
elif kind == _tl.PeerChannel:
|
||||
return _tl.InputPeerChannel(entity_id, entity_hash)
|
||||
else:
|
||||
raise ValueError('Could not find input entity with key ', key)
|
||||
|
||||
def cache_file(self, md5_digest, file_size, instance):
|
||||
if not isinstance(instance, (_tl.InputDocument, _tl.InputPhoto)):
|
||||
raise TypeError('Cannot cache %s instance' % type(instance))
|
||||
key = (md5_digest, file_size, _SentFileType.from_type(type(instance)))
|
||||
value = (instance.id, instance.access_hash)
|
||||
self._files[key] = value
|
||||
|
||||
def get_file(self, md5_digest, file_size, cls):
|
||||
key = (md5_digest, file_size, _SentFileType.from_type(cls))
|
||||
try:
|
||||
return cls(*self._files[key])
|
||||
ty, access_hash = self.entities[id]
|
||||
return Entity(ty, id, access_hash)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
async def save(self):
|
||||
pass
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import datetime
|
||||
import os
|
||||
import time
|
||||
import ipaddress
|
||||
from typing import Optional, List
|
||||
|
||||
from .memory import MemorySession, _SentFileType
|
||||
from .abstract import Session
|
||||
from .._misc import utils
|
||||
from .. import _tl
|
||||
from .._crypto import AuthKey
|
||||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
|
||||
try:
|
||||
import sqlite3
|
||||
|
@ -15,16 +17,17 @@ except ImportError as e:
|
|||
sqlite3_err = type(e)
|
||||
|
||||
EXTENSION = '.session'
|
||||
CURRENT_VERSION = 7 # database version
|
||||
CURRENT_VERSION = 8 # database version
|
||||
|
||||
|
||||
class SQLiteSession(MemorySession):
|
||||
"""This session contains the required information to login into your
|
||||
Telegram account. NEVER give the saved session file to anyone, since
|
||||
they would gain instant access to all your messages and contacts.
|
||||
class SQLiteSession(Session):
|
||||
"""
|
||||
This session contains the required information to login into your
|
||||
Telegram account. NEVER give the saved session file to anyone, since
|
||||
they would gain instant access to all your messages and contacts.
|
||||
|
||||
If you think the session has been compromised, close all the sessions
|
||||
through an official Telegram client to revoke the authorization.
|
||||
If you think the session has been compromised, close all the sessions
|
||||
through an official Telegram client to revoke the authorization.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id=None):
|
||||
|
@ -53,66 +56,13 @@ class SQLiteSession(MemorySession):
|
|||
c.execute("delete from version")
|
||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||
self.save()
|
||||
|
||||
# These values will be saved
|
||||
c.execute('select * from sessions')
|
||||
tuple_ = c.fetchone()
|
||||
if tuple_:
|
||||
self._dc_id, self._server_address, self._port, key, \
|
||||
self._takeout_id = tuple_
|
||||
self._auth_key = AuthKey(data=key)
|
||||
|
||||
c.close()
|
||||
else:
|
||||
# Tables don't exist, create new ones
|
||||
self._create_table(
|
||||
c,
|
||||
"version (version integer primary key)"
|
||||
,
|
||||
"""sessions (
|
||||
dc_id integer primary key,
|
||||
server_address text,
|
||||
port integer,
|
||||
auth_key blob,
|
||||
takeout_id integer
|
||||
)"""
|
||||
,
|
||||
"""entities (
|
||||
id integer primary key,
|
||||
hash integer not null,
|
||||
username text,
|
||||
phone integer,
|
||||
name text,
|
||||
date integer
|
||||
)"""
|
||||
,
|
||||
"""sent_files (
|
||||
md5_digest blob,
|
||||
file_size integer,
|
||||
type integer,
|
||||
id integer,
|
||||
hash integer,
|
||||
primary key(md5_digest, file_size, type)
|
||||
)"""
|
||||
,
|
||||
"""update_state (
|
||||
id integer primary key,
|
||||
pts integer,
|
||||
qts integer,
|
||||
date integer,
|
||||
seq integer
|
||||
)"""
|
||||
)
|
||||
self._mk_tables(c)
|
||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||
self._update_session_table()
|
||||
c.close()
|
||||
self.save()
|
||||
|
||||
def clone(self, to_instance=None):
|
||||
cloned = super().clone(to_instance)
|
||||
cloned.save_entities = self.save_entities
|
||||
return cloned
|
||||
|
||||
def _upgrade_database(self, old):
|
||||
c = self._cursor()
|
||||
if old == 1:
|
||||
|
@ -150,75 +100,164 @@ class SQLiteSession(MemorySession):
|
|||
if old == 6:
|
||||
old += 1
|
||||
c.execute("alter table entities add column date integer")
|
||||
if old == 7:
|
||||
self._mk_tables(c)
|
||||
c.execute('''
|
||||
insert into datacenter (id, ip, port, auth)
|
||||
select dc_id, server_address, port, auth_key
|
||||
from sessions
|
||||
''')
|
||||
c.execute('''
|
||||
insert into session (user_id, dc_id, bot, pts, qts, date, seq, takeout_id)
|
||||
select
|
||||
0,
|
||||
s.dc_id,
|
||||
0,
|
||||
coalesce(u.pts, 0),
|
||||
coalesce(u.qts, 0),
|
||||
coalesce(u.date, 0),
|
||||
coalesce(u.seq, 0),
|
||||
s.takeout_id
|
||||
from sessions s
|
||||
left join update_state u on u.id = 0
|
||||
limit 1
|
||||
''')
|
||||
c.execute('''
|
||||
insert into entity (id, access_hash, ty)
|
||||
select
|
||||
case
|
||||
when id < -1000000000000 then -(id + 1000000000000)
|
||||
when id < 0 then -id
|
||||
else id
|
||||
end,
|
||||
hash,
|
||||
case
|
||||
when id < -1000000000000 then 67
|
||||
when id < 0 then 71
|
||||
else 85
|
||||
end
|
||||
from entities
|
||||
''')
|
||||
c.execute('drop table sessions')
|
||||
c.execute('drop table entities')
|
||||
c.execute('drop table sent_files')
|
||||
c.execute('drop table update_state')
|
||||
|
||||
c.close()
|
||||
def _mk_tables(self, c):
|
||||
self._create_table(
|
||||
c,
|
||||
'''version (
|
||||
version integer primary key
|
||||
)''',
|
||||
'''datacenter (
|
||||
id integer primary key,
|
||||
ip text not null,
|
||||
port integer not null,
|
||||
auth blob not null
|
||||
)''',
|
||||
'''session (
|
||||
user_id integer primary key,
|
||||
dc_id integer not null,
|
||||
bot integer not null,
|
||||
pts integer not null,
|
||||
qts integer not null,
|
||||
date integer not null,
|
||||
seq integer not null,
|
||||
takeout_id integer
|
||||
)''',
|
||||
'''channel (
|
||||
channel_id integer primary key,
|
||||
pts integer not null
|
||||
)''',
|
||||
'''entity (
|
||||
id integer primary key,
|
||||
access_hash integer not null,
|
||||
ty integer not null
|
||||
)''',
|
||||
)
|
||||
|
||||
async def insert_dc(self, dc: DataCenter):
|
||||
self._execute(
|
||||
'insert or replace into datacenter values (?,?,?,?)',
|
||||
dc.id,
|
||||
str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)),
|
||||
dc.port,
|
||||
dc.auth
|
||||
)
|
||||
|
||||
async def get_all_dc(self) -> List[DataCenter]:
|
||||
c = self._cursor()
|
||||
res = []
|
||||
for (id, ip, port, auth) in c.execute('select * from datacenter'):
|
||||
ip = ipaddress.ip_address(ip)
|
||||
res.append(DataCenter(
|
||||
id=id,
|
||||
ipv4=int(ip) if ip.version == 4 else None,
|
||||
ipv6=int(ip) if ip.version == 6 else None,
|
||||
port=port,
|
||||
auth=auth,
|
||||
))
|
||||
return res
|
||||
|
||||
async def set_state(self, state: SessionState):
|
||||
self._execute(
|
||||
'insert or replace into session values (?,?,?,?,?,?,?,?)',
|
||||
state.user_id,
|
||||
state.dc_id,
|
||||
int(state.bot),
|
||||
state.pts,
|
||||
state.qts,
|
||||
state.date,
|
||||
state.seq,
|
||||
state.takeout_id,
|
||||
)
|
||||
|
||||
async def get_state(self) -> Optional[SessionState]:
|
||||
row = self._execute('select * from session')
|
||||
return SessionState(*row) if row else None
|
||||
|
||||
async def insert_channel_state(self, state: ChannelState):
|
||||
self._execute(
|
||||
'insert or replace into channel values (?,?)',
|
||||
state.channel_id,
|
||||
state.pts,
|
||||
)
|
||||
|
||||
async def get_all_channel_states(self) -> List[ChannelState]:
|
||||
c = self._cursor()
|
||||
try:
|
||||
return [
|
||||
ChannelState(*row)
|
||||
for row in c.execute('select * from channel')
|
||||
]
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
async def insert_entities(self, entities: List[Entity]):
|
||||
c = self._cursor()
|
||||
try:
|
||||
c.executemany(
|
||||
'insert or replace into entity values (?,?,?)',
|
||||
[(e.id, e.access_hash, e.ty) for e in entities]
|
||||
)
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
||||
row = self._execute('select ty, id, access_hash from entity where id = ?', id)
|
||||
return Entity(*row) if row else None
|
||||
|
||||
async def save(self):
|
||||
# This is a no-op if there are no changes to commit, so there's
|
||||
# no need for us to keep track of an "unsaved changes" variable.
|
||||
if self._conn is not None:
|
||||
self._conn.commit()
|
||||
|
||||
@staticmethod
|
||||
def _create_table(c, *definitions):
|
||||
for definition in definitions:
|
||||
c.execute('create table {}'.format(definition))
|
||||
|
||||
# Data from sessions should be kept as properties
|
||||
# not to fetch the database every time we need it
|
||||
def set_dc(self, dc_id, server_address, port):
|
||||
super().set_dc(dc_id, server_address, port)
|
||||
self._update_session_table()
|
||||
|
||||
# Fetch the auth_key corresponding to this data center
|
||||
row = self._execute('select auth_key from sessions')
|
||||
if row and row[0]:
|
||||
self._auth_key = AuthKey(data=row[0])
|
||||
else:
|
||||
self._auth_key = None
|
||||
|
||||
@MemorySession.auth_key.setter
|
||||
def auth_key(self, value):
|
||||
self._auth_key = value
|
||||
self._update_session_table()
|
||||
|
||||
@MemorySession.takeout_id.setter
|
||||
def takeout_id(self, value):
|
||||
self._takeout_id = value
|
||||
self._update_session_table()
|
||||
|
||||
def _update_session_table(self):
|
||||
c = self._cursor()
|
||||
# While we can save multiple rows into the sessions table
|
||||
# currently we only want to keep ONE as the tables don't
|
||||
# tell us which auth_key's are usable and will work. Needs
|
||||
# some more work before being able to save auth_key's for
|
||||
# multiple DCs. Probably done differently.
|
||||
c.execute('delete from sessions')
|
||||
c.execute('insert or replace into sessions values (?,?,?,?,?)', (
|
||||
self._dc_id,
|
||||
self._server_address,
|
||||
self._port,
|
||||
self._auth_key.key if self._auth_key else b'',
|
||||
self._takeout_id
|
||||
))
|
||||
c.close()
|
||||
|
||||
def get_update_state(self, entity_id):
|
||||
row = self._execute('select pts, qts, date, seq from update_state '
|
||||
'where id = ?', entity_id)
|
||||
if row:
|
||||
pts, qts, date, seq = row
|
||||
date = datetime.datetime.fromtimestamp(
|
||||
date, tz=datetime.timezone.utc)
|
||||
return _tl.updates.State(pts, qts, date, seq, unread_count=0)
|
||||
|
||||
def set_update_state(self, entity_id, state):
|
||||
self._execute('insert or replace into update_state values (?,?,?,?,?)',
|
||||
entity_id, state.pts, state.qts,
|
||||
state.date.timestamp(), state.seq)
|
||||
|
||||
def save(self):
|
||||
"""Saves the current session object as session_user_id.session"""
|
||||
# This is a no-op if there are no changes to commit, so there's
|
||||
# no need for us to keep track of an "unsaved changes" variable.
|
||||
if self._conn is not None:
|
||||
self._conn.commit()
|
||||
|
||||
def _cursor(self):
|
||||
"""Asserts that the connection is open and returns a cursor"""
|
||||
if self._conn is None:
|
||||
|
@ -236,108 +275,3 @@ class SQLiteSession(MemorySession):
|
|||
return c.execute(stmt, values).fetchone()
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
def close(self):
|
||||
"""Closes the connection unless we're working in-memory"""
|
||||
if self.filename != ':memory:':
|
||||
if self._conn is not None:
|
||||
self._conn.commit()
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
def delete(self):
|
||||
"""Deletes the current session file"""
|
||||
if self.filename == ':memory:':
|
||||
return True
|
||||
try:
|
||||
os.remove(self.filename)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def list_sessions(cls):
|
||||
"""Lists all the sessions of the users who have ever connected
|
||||
using this client and never logged out
|
||||
"""
|
||||
return [os.path.splitext(os.path.basename(f))[0]
|
||||
for f in os.listdir('.') if f.endswith(EXTENSION)]
|
||||
|
||||
# Entity processing
|
||||
|
||||
def process_entities(self, tlo):
|
||||
"""
|
||||
Processes all the found entities on the given TLObject,
|
||||
unless .save_entities is False.
|
||||
"""
|
||||
if not self.save_entities:
|
||||
return
|
||||
|
||||
rows = self._entities_to_rows(tlo)
|
||||
if not rows:
|
||||
return
|
||||
|
||||
c = self._cursor()
|
||||
try:
|
||||
now_tup = (int(time.time()),)
|
||||
rows = [row + now_tup for row in rows]
|
||||
c.executemany(
|
||||
'insert or replace into entities values (?,?,?,?,?,?)', rows)
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
def get_entity_rows_by_phone(self, phone):
|
||||
return self._execute(
|
||||
'select id, hash from entities where phone = ?', phone)
|
||||
|
||||
def get_entity_rows_by_username(self, username):
|
||||
c = self._cursor()
|
||||
try:
|
||||
results = c.execute(
|
||||
'select id, hash, date from entities where username = ?',
|
||||
(username,)
|
||||
).fetchall()
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# If there is more than one result for the same username, evict the oldest one
|
||||
if len(results) > 1:
|
||||
results.sort(key=lambda t: t[2] or 0)
|
||||
c.executemany('update entities set username = null where id = ?',
|
||||
[(t[0],) for t in results[:-1]])
|
||||
|
||||
return results[-1][0], results[-1][1]
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
def get_entity_rows_by_name(self, name):
|
||||
return self._execute(
|
||||
'select id, hash from entities where name = ?', name)
|
||||
|
||||
def get_entity_rows_by_id(self, id, exact=True):
|
||||
return self._execute(
|
||||
'select id, hash from entities where id = ?', id)
|
||||
|
||||
# File processing
|
||||
|
||||
def get_file(self, md5_digest, file_size, cls):
|
||||
row = self._execute(
|
||||
'select id, hash from sent_files '
|
||||
'where md5_digest = ? and file_size = ? and type = ?',
|
||||
md5_digest, file_size, _SentFileType.from_type(cls).value
|
||||
)
|
||||
if row:
|
||||
# Both allowed classes have (id, access_hash) as parameters
|
||||
return cls(row[0], row[1])
|
||||
|
||||
def cache_file(self, md5_digest, file_size, instance):
|
||||
if not isinstance(instance, (_tl.InputDocument, _tl.InputPhoto)):
|
||||
raise TypeError('Cannot cache %s instance' % type(instance))
|
||||
|
||||
self._execute(
|
||||
'insert or replace into sent_files values (?,?,?,?,?)',
|
||||
md5_digest, file_size,
|
||||
_SentFileType.from_type(type(instance)).value,
|
||||
instance.id, instance.access_hash
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@ import struct
|
|||
|
||||
from .abstract import Session
|
||||
from .memory import MemorySession
|
||||
from .._crypto import AuthKey
|
||||
from .types import DataCenter, ChannelState, SessionState, Entity
|
||||
|
||||
_STRUCT_PREFORMAT = '>B{}sH256s'
|
||||
|
||||
|
@ -34,12 +34,33 @@ class StringSession(MemorySession):
|
|||
|
||||
string = string[1:]
|
||||
ip_len = 4 if len(string) == 352 else 16
|
||||
self._dc_id, ip, self._port, key = struct.unpack(
|
||||
dc_id, ip, port, key = struct.unpack(
|
||||
_STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
|
||||
|
||||
self._server_address = ipaddress.ip_address(ip).compressed
|
||||
if any(key):
|
||||
self._auth_key = AuthKey(key)
|
||||
self.state = SessionState(
|
||||
dc_id=dc_id,
|
||||
user_id=0,
|
||||
bot=False,
|
||||
pts=0,
|
||||
qts=0,
|
||||
date=0,
|
||||
seq=0,
|
||||
takeout_id=0
|
||||
)
|
||||
if ip_len == 4:
|
||||
ipv4 = int.from_bytes(ip, 'big', False)
|
||||
ipv6 = None
|
||||
else:
|
||||
ipv4 = None
|
||||
ipv6 = int.from_bytes(ip, 'big', signed=False)
|
||||
|
||||
self.dcs[dc_id] = DataCenter(
|
||||
id=dc_id,
|
||||
ipv4=ipv4,
|
||||
ipv6=ipv6,
|
||||
port=port,
|
||||
auth=key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def encode(x: bytes) -> str:
|
||||
|
@ -50,14 +71,18 @@ class StringSession(MemorySession):
|
|||
return base64.urlsafe_b64decode(x)
|
||||
|
||||
def save(self: Session):
|
||||
if not self.auth_key:
|
||||
if not self.state:
|
||||
return ''
|
||||
|
||||
ip = ipaddress.ip_address(self.server_address).packed
|
||||
if self.state.ipv6 is not None:
|
||||
ip = self.state.ipv6.to_bytes(16, 'big', signed=False)
|
||||
else:
|
||||
ip = self.state.ipv6.to_bytes(4, 'big', signed=False)
|
||||
|
||||
return CURRENT_VERSION + StringSession.encode(struct.pack(
|
||||
_STRUCT_PREFORMAT.format(len(ip)),
|
||||
self.dc_id,
|
||||
self.state.dc_id,
|
||||
ip,
|
||||
self.port,
|
||||
self.auth_key.key
|
||||
self.state.port,
|
||||
self.dcs[self.state.dc_id].auth
|
||||
))
|
||||
|
|
Loading…
Reference in New Issue
Block a user