Completely overhaul sessions

This commit is contained in:
Lonami Exo 2021-09-19 13:45:19 +02:00
parent 0b54fa7a25
commit 684f640b60
5 changed files with 339 additions and 574 deletions

View File

@ -73,6 +73,72 @@ removed. This implies:
// TODO provide standalone alternative for this? // TODO provide standalone alternative for this?
Complete overhaul of session files
----------------------------------
If you were using third-party libraries to deal with sessions, you will need to wait for those to
be updated. The library will automatically upgrade the SQLite session files to the new version,
and the ``StringSession`` remains backward-compatible. The sessions can now be async.
In case you were relying on the tables used by SQLite (even though these should have been, and
will still need to be, treated as an implementation detail), here are the changes:
* The ``sessions`` table is now correctly split into ``datacenter`` and ``session``.
``datacenter`` contains information about a Telegram datacenter, along with its corresponding
authorization key, and ``session`` contains information about the update state and user.
* The ``entities`` table is now called ``entity`` and stores the ``type`` separatedly.
* The ``update_state`` table is now split into ``session`` and ``channel``, which can contain
a per-channel ``pts``.
Because **the new version does not cache usernames, phone numbers and display names**, using these
in method calls is now quite expensive. You *should* migrate your code to do the Right Thing and
start using identifiers rather than usernames, phone numbers or invite links. This is both simpler
and more reliable, because while a user identifier won't change, their username could.
You can use the following snippet to make a JSON backup (alternatively, you could just copy the
``.session`` file and keep it around) in case you want to preserve the cached usernames:
.. code-block:: python
import sqlite, json
with sqlite3.connect('your.session') as conn, open('entities.json', 'w', encoding='utf-8') as fp:
json.dump([
{'id': id, 'hash': hash, 'username': username, 'phone': phone, 'name': name, 'date': date}
for (id, hash, username, phone, name, date)
in conn.execute('select id, hash, username, phone, name, date from entities')
], fp)
The following public methods or properties have also been removed from ``SQLiteSession`` because
they no longer make sense:
* ``list_sessions``. You can ``glob.glob('*.session')`` instead.
* ``clone``.
And the following, which were inherited from ``MemorySession``:
* ``delete``. You can ``os.remove`` the file instead (preferably after ``client.log_out()``).
* ``set_dc``.
* ``dc_id``.
* ``server_address``.
* ``port``.
* ``auth_key``.
* ``takeout_id``.
* ``get_update_state``.
* ``set_update_state``.
* ``process_entities``.
* ``get_entity_rows_by_phone``.
* ``get_entity_rows_by_username``.
* ``get_entity_rows_by_name``.
* ``get_entity_rows_by_id``.
* ``get_input_entity``.
* ``cache_file``.
* ``get_file``.
You also can no longer set ``client.session.save_entities = False``. The entities must be saved
for the library to work properly. If you still don't want it, you should subclass the session and
override the methods to do nothing.
The "iter" variant of the client methods have been removed The "iter" variant of the client methods have been removed
---------------------------------------------------------- ----------------------------------------------------------

View File

@ -1,167 +1,90 @@
from .types import DataCenter, ChannelState, SessionState, Entity
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional
class Session(ABC): class Session(ABC):
def __init__(self):
pass
def clone(self, to_instance=None):
"""
Creates a clone of this session file.
"""
return to_instance or self.__class__()
@abstractmethod @abstractmethod
def set_dc(self, dc_id, server_address, port): async def insert_dc(self, dc: DataCenter):
""" """
Sets the information of the data center address and port that Store a new or update an existing `DataCenter` with matching ``id``.
the library should connect to, as well as the data center ID,
which is currently unused.
"""
raise NotImplementedError
@property
@abstractmethod
def dc_id(self):
"""
Returns the currently-used data center ID.
"""
raise NotImplementedError
@property
@abstractmethod
def server_address(self):
"""
Returns the server address where the library should connect to.
"""
raise NotImplementedError
@property
@abstractmethod
def port(self):
"""
Returns the port to which the library should connect to.
"""
raise NotImplementedError
@property
@abstractmethod
def auth_key(self):
"""
Returns an ``AuthKey`` instance associated with the saved
data center, or `None` if a new one should be generated.
"""
raise NotImplementedError
@auth_key.setter
@abstractmethod
def auth_key(self, value):
"""
Sets the ``AuthKey`` to be used for the saved data center.
"""
raise NotImplementedError
@property
@abstractmethod
def takeout_id(self):
"""
Returns an ID of the takeout process initialized for this session,
or `None` if there's no were any unfinished takeout requests.
"""
raise NotImplementedError
@takeout_id.setter
@abstractmethod
def takeout_id(self, value):
"""
Sets the ID of the unfinished takeout process for this session.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_update_state(self, entity_id): async def get_all_dc(self) -> List[DataCenter]:
""" """
Returns the ``UpdateState`` associated with the given `entity_id`. Get a list of all currently-stored `DataCenter`. Should not contain duplicate ``id``.
If the `entity_id` is 0, it should return the ``UpdateState`` for
no specific channel (the "general" state). If no state is known
it should ``return None``.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def set_update_state(self, entity_id, state): async def set_state(self, state: SessionState):
""" """
Sets the given ``UpdateState`` for the specified `entity_id`, which Set the state about the current session.
should be 0 if the ``UpdateState`` is the "general" state (and not
for any specific channel).
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def close(self): async def get_state(self) -> Optional[SessionState]:
""" """
Called on client disconnection. Should be used to Get the state about the current session.
free any used resources. Can be left empty if none.
"""
@abstractmethod
def save(self):
"""
Called whenever important properties change. It should
make persist the relevant session information to disk.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def delete(self): async def insert_channel_state(self, state: ChannelState):
""" """
Called upon client.log_out(). Should delete the stored Store a new or update an existing `ChannelState` with matching ``id``.
information from disk since it's not valid anymore.
"""
raise NotImplementedError
@classmethod
def list_sessions(cls):
"""
Lists available sessions. Not used by the library itself.
"""
return []
@abstractmethod
def process_entities(self, tlo):
"""
Processes the input ``TLObject`` or ``list`` and saves
whatever information is relevant (e.g., ID or access hash).
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_input_entity(self, key): async def get_all_channel_states(self) -> List[ChannelState]:
""" """
Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``). Get a list of all currently-stored `ChannelState`. Should not contain duplicate ``id``.
The library uses this method whenever an ``InputPeer`` is needed
to suit several purposes (e.g. user only provided its ID or wishes
to use a cached username to avoid extra RPC).
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def cache_file(self, md5_digest, file_size, instance): async def insert_entities(self, entities: List[Entity]):
""" """
Caches the given file information persistently, so that it Store new or update existing `Entity` with matching ``id``.
doesn't need to be re-uploaded in case the file is used again.
The ``instance`` will be either an ``InputPhoto`` or ``InputDocument``, Entities should be saved on a best-effort. It is okay to not save them, although the
both with an ``.id`` and ``.access_hash`` attributes. library may need to do extra work if a previously-saved entity is missing, or even be
unable to continue without the entity.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_file(self, md5_digest, file_size, cls): async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
""" """
Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size`` Get the `Entity` with matching ``ty`` and ``id``.
match an existing saved record. The class will either be an
``InputPhoto`` or ``InputDocument``, both with two parameters The following groups of ``ty`` should be treated to be equivalent, that is, for a given
``id`` and ``access_hash`` in that order. ``ty`` and ``id``, if the ``ty`` is in a given group, a matching ``access_hash`` with
that ``id`` from within any ``ty`` in that group should be returned.
* ``'U'`` and ``'B'`` (user and bot).
* ``'G'`` (small group chat).
* ``'C'``, ``'M'`` and ``'E'`` (broadcast channel, megagroup channel, and gigagroup channel).
For example, if a ``ty`` representing a bot is stored but the asking ``ty`` is a user,
the corresponding ``access_hash`` should still be returned.
You may use `types.canonical_entity_type` to find out the canonical type.
"""
raise NotImplementedError
@abstractmethod
async def save(self):
"""
Save the session.
May do nothing if the other methods already saved when they were called.
May return custom data when manual saving is intended.
""" """
raise NotImplementedError raise NotImplementedError

View File

@ -1,230 +1,47 @@
from enum import Enum from .types import DataCenter, ChannelState, SessionState, Entity
from .abstract import Session from .abstract import Session
from .._misc import utils, tlobject from .._misc import utils, tlobject
from .. import _tl from .. import _tl
from typing import List, Optional
class _SentFileType(Enum):
DOCUMENT = 0
PHOTO = 1
@staticmethod
def from_type(cls):
if cls == _tl.InputDocument:
return _SentFileType.DOCUMENT
elif cls == _tl.InputPhoto:
return _SentFileType.PHOTO
else:
raise ValueError('The cls must be either InputDocument/InputPhoto')
class MemorySession(Session): class MemorySession(Session):
__slots__ = ('dcs', 'state', 'channel_states', 'entities')
def __init__(self): def __init__(self):
super().__init__() self.dcs = {}
self.state = None
self.channel_states = {}
self.entities = {}
self._dc_id = 0 async def insert_dc(self, dc: DataCenter):
self._server_address = None self.dcs[dc.id] = dc
self._port = None
self._auth_key = None
self._takeout_id = None
self._files = {} async def get_all_dc(self) -> List[DataCenter]:
self._entities = set() return list(self.dcs.values())
self._update_states = {}
def set_dc(self, dc_id, server_address, port): async def set_state(self, state: SessionState):
self._dc_id = dc_id or 0 self.state = state
self._server_address = server_address
self._port = port
@property async def get_state(self) -> Optional[SessionState]:
def dc_id(self): return self.state
return self._dc_id
@property async def insert_channel_state(self, state: ChannelState):
def server_address(self): self.channel_states[state.channel_id] = state
return self._server_address
@property async def get_all_channel_states(self) -> List[ChannelState]:
def port(self): return list(self.channel_states.values())
return self._port
@property async def insert_entities(self, entities: List[Entity]):
def auth_key(self): self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
return self._auth_key
@auth_key.setter async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
def auth_key(self, value):
self._auth_key = value
@property
def takeout_id(self):
return self._takeout_id
@takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None)
def set_update_state(self, entity_id, state):
self._update_states[entity_id] = state
def close(self):
pass
def save(self):
pass
def delete(self):
pass
@staticmethod
def _entity_values_to_row(id, hash, username, phone, name):
# While this is a simple implementation it might be overrode by,
# other classes so they don't need to implement the plural form
# of the method. Don't remove.
return id, hash, username, phone, name
def _entity_to_row(self, e):
if not isinstance(e, tlobject.TLObject):
return
try: try:
p = utils.get_input_peer(e, allow_self=False) ty, access_hash = self.entities[id]
marked_id = utils.get_peer_id(p) return Entity(ty, id, access_hash)
except TypeError:
# Note: `get_input_peer` already checks for non-zero `access_hash`.
# See issues #354 and #392. It also checks that the entity
# is not `min`, because its `access_hash` cannot be used
# anywhere (since layer 102, there are two access hashes).
return
if isinstance(p, (_tl.InputPeerUser, _tl.InputPeerChannel)):
p_hash = p.access_hash
elif isinstance(p, _tl.InputPeerChat):
p_hash = 0
else:
return
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
return self._entity_values_to_row(
marked_id, p_hash, username, phone, name
)
def _entities_to_rows(self, tlo):
if not isinstance(tlo, tlobject.TLObject) and utils.is_list_like(tlo):
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'user'):
entities.append(tlo.user)
if hasattr(tlo, 'chat'):
entities.append(tlo.chat)
if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
entities.extend(tlo.users)
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
row = self._entity_to_row(e)
if row:
rows.append(row)
return rows
def process_entities(self, tlo):
self._entities |= set(self._entities_to_rows(tlo))
def get_entity_rows_by_phone(self, phone):
try:
return next((id, hash) for id, hash, _, found_phone, _
in self._entities if found_phone == phone)
except StopIteration:
pass
def get_entity_rows_by_username(self, username):
try:
return next((id, hash) for id, hash, found_username, _, _
in self._entities if found_username == username)
except StopIteration:
pass
def get_entity_rows_by_name(self, name):
try:
return next((id, hash) for id, hash, _, _, found_name
in self._entities if found_name == name)
except StopIteration:
pass
def get_entity_rows_by_id(self, id, exact=True):
try:
return next((id, hash) for found_id, hash, _, _, _
in self._entities if found_id == id)
except StopIteration:
pass
def get_input_entity(self, key):
try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
# We already have an Input version, so nothing else required
return key
# Try to early return if this key can be casted as input peer
return utils.get_input_peer(key)
except (AttributeError, TypeError):
# Not a TLObject or can't be cast into InputPeer
if isinstance(key, tlobject.TLObject):
key = utils.get_peer_id(key)
exact = True
else:
exact = not isinstance(key, int) or key < 0
result = None
if isinstance(key, str):
phone = utils.parse_phone(key)
if phone:
result = self.get_entity_rows_by_phone(phone)
else:
username, invite = utils.parse_username(key)
if username and not invite:
result = self.get_entity_rows_by_username(username)
elif isinstance(key, int):
result = self.get_entity_rows_by_id(key, exact)
if not result and isinstance(key, str):
result = self.get_entity_rows_by_name(key)
if result:
entity_id, entity_hash = result # unpack resulting tuple
entity_id, kind = utils.resolve_id(entity_id)
# removes the mark and returns type of entity
if kind == _tl.PeerUser:
return _tl.InputPeerUser(entity_id, entity_hash)
elif kind == _tl.PeerChat:
return _tl.InputPeerChat(entity_id)
elif kind == _tl.PeerChannel:
return _tl.InputPeerChannel(entity_id, entity_hash)
else:
raise ValueError('Could not find input entity with key ', key)
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (_tl.InputDocument, _tl.InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
key = (md5_digest, file_size, _SentFileType.from_type(type(instance)))
value = (instance.id, instance.access_hash)
self._files[key] = value
def get_file(self, md5_digest, file_size, cls):
key = (md5_digest, file_size, _SentFileType.from_type(cls))
try:
return cls(*self._files[key])
except KeyError: except KeyError:
return None return None
async def save(self):
pass

View File

@ -1,11 +1,13 @@
import datetime import datetime
import os import os
import time import time
import ipaddress
from typing import Optional, List
from .memory import MemorySession, _SentFileType from .abstract import Session
from .._misc import utils from .._misc import utils
from .. import _tl from .. import _tl
from .._crypto import AuthKey from .types import DataCenter, ChannelState, SessionState, Entity
try: try:
import sqlite3 import sqlite3
@ -15,16 +17,17 @@ except ImportError as e:
sqlite3_err = type(e) sqlite3_err = type(e)
EXTENSION = '.session' EXTENSION = '.session'
CURRENT_VERSION = 7 # database version CURRENT_VERSION = 8 # database version
class SQLiteSession(MemorySession): class SQLiteSession(Session):
"""This session contains the required information to login into your """
Telegram account. NEVER give the saved session file to anyone, since This session contains the required information to login into your
they would gain instant access to all your messages and contacts. Telegram account. NEVER give the saved session file to anyone, since
they would gain instant access to all your messages and contacts.
If you think the session has been compromised, close all the sessions If you think the session has been compromised, close all the sessions
through an official Telegram client to revoke the authorization. through an official Telegram client to revoke the authorization.
""" """
def __init__(self, session_id=None): def __init__(self, session_id=None):
@ -53,66 +56,13 @@ class SQLiteSession(MemorySession):
c.execute("delete from version") c.execute("delete from version")
c.execute("insert into version values (?)", (CURRENT_VERSION,)) c.execute("insert into version values (?)", (CURRENT_VERSION,))
self.save() self.save()
# These values will be saved
c.execute('select * from sessions')
tuple_ = c.fetchone()
if tuple_:
self._dc_id, self._server_address, self._port, key, \
self._takeout_id = tuple_
self._auth_key = AuthKey(data=key)
c.close()
else: else:
# Tables don't exist, create new ones # Tables don't exist, create new ones
self._create_table( self._mk_tables(c)
c,
"version (version integer primary key)"
,
"""sessions (
dc_id integer primary key,
server_address text,
port integer,
auth_key blob,
takeout_id integer
)"""
,
"""entities (
id integer primary key,
hash integer not null,
username text,
phone integer,
name text,
date integer
)"""
,
"""sent_files (
md5_digest blob,
file_size integer,
type integer,
id integer,
hash integer,
primary key(md5_digest, file_size, type)
)"""
,
"""update_state (
id integer primary key,
pts integer,
qts integer,
date integer,
seq integer
)"""
)
c.execute("insert into version values (?)", (CURRENT_VERSION,)) c.execute("insert into version values (?)", (CURRENT_VERSION,))
self._update_session_table()
c.close() c.close()
self.save() self.save()
def clone(self, to_instance=None):
cloned = super().clone(to_instance)
cloned.save_entities = self.save_entities
return cloned
def _upgrade_database(self, old): def _upgrade_database(self, old):
c = self._cursor() c = self._cursor()
if old == 1: if old == 1:
@ -150,75 +100,164 @@ class SQLiteSession(MemorySession):
if old == 6: if old == 6:
old += 1 old += 1
c.execute("alter table entities add column date integer") c.execute("alter table entities add column date integer")
if old == 7:
self._mk_tables(c)
c.execute('''
insert into datacenter (id, ip, port, auth)
select dc_id, server_address, port, auth_key
from sessions
''')
c.execute('''
insert into session (user_id, dc_id, bot, pts, qts, date, seq, takeout_id)
select
0,
s.dc_id,
0,
coalesce(u.pts, 0),
coalesce(u.qts, 0),
coalesce(u.date, 0),
coalesce(u.seq, 0),
s.takeout_id
from sessions s
left join update_state u on u.id = 0
limit 1
''')
c.execute('''
insert into entity (id, access_hash, ty)
select
case
when id < -1000000000000 then -(id + 1000000000000)
when id < 0 then -id
else id
end,
hash,
case
when id < -1000000000000 then 67
when id < 0 then 71
else 85
end
from entities
''')
c.execute('drop table sessions')
c.execute('drop table entities')
c.execute('drop table sent_files')
c.execute('drop table update_state')
c.close() def _mk_tables(self, c):
self._create_table(
c,
'''version (
version integer primary key
)''',
'''datacenter (
id integer primary key,
ip text not null,
port integer not null,
auth blob not null
)''',
'''session (
user_id integer primary key,
dc_id integer not null,
bot integer not null,
pts integer not null,
qts integer not null,
date integer not null,
seq integer not null,
takeout_id integer
)''',
'''channel (
channel_id integer primary key,
pts integer not null
)''',
'''entity (
id integer primary key,
access_hash integer not null,
ty integer not null
)''',
)
async def insert_dc(self, dc: DataCenter):
self._execute(
'insert or replace into datacenter values (?,?,?,?)',
dc.id,
str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)),
dc.port,
dc.auth
)
async def get_all_dc(self) -> List[DataCenter]:
c = self._cursor()
res = []
for (id, ip, port, auth) in c.execute('select * from datacenter'):
ip = ipaddress.ip_address(ip)
res.append(DataCenter(
id=id,
ipv4=int(ip) if ip.version == 4 else None,
ipv6=int(ip) if ip.version == 6 else None,
port=port,
auth=auth,
))
return res
async def set_state(self, state: SessionState):
self._execute(
'insert or replace into session values (?,?,?,?,?,?,?,?)',
state.user_id,
state.dc_id,
int(state.bot),
state.pts,
state.qts,
state.date,
state.seq,
state.takeout_id,
)
async def get_state(self) -> Optional[SessionState]:
row = self._execute('select * from session')
return SessionState(*row) if row else None
async def insert_channel_state(self, state: ChannelState):
self._execute(
'insert or replace into channel values (?,?)',
state.channel_id,
state.pts,
)
async def get_all_channel_states(self) -> List[ChannelState]:
c = self._cursor()
try:
return [
ChannelState(*row)
for row in c.execute('select * from channel')
]
finally:
c.close()
async def insert_entities(self, entities: List[Entity]):
c = self._cursor()
try:
c.executemany(
'insert or replace into entity values (?,?,?)',
[(e.id, e.access_hash, e.ty) for e in entities]
)
finally:
c.close()
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
row = self._execute('select ty, id, access_hash from entity where id = ?', id)
return Entity(*row) if row else None
async def save(self):
# This is a no-op if there are no changes to commit, so there's
# no need for us to keep track of an "unsaved changes" variable.
if self._conn is not None:
self._conn.commit()
@staticmethod @staticmethod
def _create_table(c, *definitions): def _create_table(c, *definitions):
for definition in definitions: for definition in definitions:
c.execute('create table {}'.format(definition)) c.execute('create table {}'.format(definition))
# Data from sessions should be kept as properties
# not to fetch the database every time we need it
def set_dc(self, dc_id, server_address, port):
super().set_dc(dc_id, server_address, port)
self._update_session_table()
# Fetch the auth_key corresponding to this data center
row = self._execute('select auth_key from sessions')
if row and row[0]:
self._auth_key = AuthKey(data=row[0])
else:
self._auth_key = None
@MemorySession.auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
@MemorySession.takeout_id.setter
def takeout_id(self, value):
self._takeout_id = value
self._update_session_table()
def _update_session_table(self):
c = self._cursor()
# While we can save multiple rows into the sessions table
# currently we only want to keep ONE as the tables don't
# tell us which auth_key's are usable and will work. Needs
# some more work before being able to save auth_key's for
# multiple DCs. Probably done differently.
c.execute('delete from sessions')
c.execute('insert or replace into sessions values (?,?,?,?,?)', (
self._dc_id,
self._server_address,
self._port,
self._auth_key.key if self._auth_key else b'',
self._takeout_id
))
c.close()
def get_update_state(self, entity_id):
row = self._execute('select pts, qts, date, seq from update_state '
'where id = ?', entity_id)
if row:
pts, qts, date, seq = row
date = datetime.datetime.fromtimestamp(
date, tz=datetime.timezone.utc)
return _tl.updates.State(pts, qts, date, seq, unread_count=0)
def set_update_state(self, entity_id, state):
self._execute('insert or replace into update_state values (?,?,?,?,?)',
entity_id, state.pts, state.qts,
state.date.timestamp(), state.seq)
def save(self):
"""Saves the current session object as session_user_id.session"""
# This is a no-op if there are no changes to commit, so there's
# no need for us to keep track of an "unsaved changes" variable.
if self._conn is not None:
self._conn.commit()
def _cursor(self): def _cursor(self):
"""Asserts that the connection is open and returns a cursor""" """Asserts that the connection is open and returns a cursor"""
if self._conn is None: if self._conn is None:
@ -236,108 +275,3 @@ class SQLiteSession(MemorySession):
return c.execute(stmt, values).fetchone() return c.execute(stmt, values).fetchone()
finally: finally:
c.close() c.close()
def close(self):
"""Closes the connection unless we're working in-memory"""
if self.filename != ':memory:':
if self._conn is not None:
self._conn.commit()
self._conn.close()
self._conn = None
def delete(self):
"""Deletes the current session file"""
if self.filename == ':memory:':
return True
try:
os.remove(self.filename)
return True
except OSError:
return False
@classmethod
def list_sessions(cls):
"""Lists all the sessions of the users who have ever connected
using this client and never logged out
"""
return [os.path.splitext(os.path.basename(f))[0]
for f in os.listdir('.') if f.endswith(EXTENSION)]
# Entity processing
def process_entities(self, tlo):
"""
Processes all the found entities on the given TLObject,
unless .save_entities is False.
"""
if not self.save_entities:
return
rows = self._entities_to_rows(tlo)
if not rows:
return
c = self._cursor()
try:
now_tup = (int(time.time()),)
rows = [row + now_tup for row in rows]
c.executemany(
'insert or replace into entities values (?,?,?,?,?,?)', rows)
finally:
c.close()
def get_entity_rows_by_phone(self, phone):
return self._execute(
'select id, hash from entities where phone = ?', phone)
def get_entity_rows_by_username(self, username):
c = self._cursor()
try:
results = c.execute(
'select id, hash, date from entities where username = ?',
(username,)
).fetchall()
if not results:
return None
# If there is more than one result for the same username, evict the oldest one
if len(results) > 1:
results.sort(key=lambda t: t[2] or 0)
c.executemany('update entities set username = null where id = ?',
[(t[0],) for t in results[:-1]])
return results[-1][0], results[-1][1]
finally:
c.close()
def get_entity_rows_by_name(self, name):
return self._execute(
'select id, hash from entities where name = ?', name)
def get_entity_rows_by_id(self, id, exact=True):
return self._execute(
'select id, hash from entities where id = ?', id)
# File processing
def get_file(self, md5_digest, file_size, cls):
row = self._execute(
'select id, hash from sent_files '
'where md5_digest = ? and file_size = ? and type = ?',
md5_digest, file_size, _SentFileType.from_type(cls).value
)
if row:
# Both allowed classes have (id, access_hash) as parameters
return cls(row[0], row[1])
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (_tl.InputDocument, _tl.InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
self._execute(
'insert or replace into sent_files values (?,?,?,?,?)',
md5_digest, file_size,
_SentFileType.from_type(type(instance)).value,
instance.id, instance.access_hash
)

View File

@ -4,7 +4,7 @@ import struct
from .abstract import Session from .abstract import Session
from .memory import MemorySession from .memory import MemorySession
from .._crypto import AuthKey from .types import DataCenter, ChannelState, SessionState, Entity
_STRUCT_PREFORMAT = '>B{}sH256s' _STRUCT_PREFORMAT = '>B{}sH256s'
@ -34,12 +34,33 @@ class StringSession(MemorySession):
string = string[1:] string = string[1:]
ip_len = 4 if len(string) == 352 else 16 ip_len = 4 if len(string) == 352 else 16
self._dc_id, ip, self._port, key = struct.unpack( dc_id, ip, port, key = struct.unpack(
_STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string)) _STRUCT_PREFORMAT.format(ip_len), StringSession.decode(string))
self._server_address = ipaddress.ip_address(ip).compressed self.state = SessionState(
if any(key): dc_id=dc_id,
self._auth_key = AuthKey(key) user_id=0,
bot=False,
pts=0,
qts=0,
date=0,
seq=0,
takeout_id=0
)
if ip_len == 4:
ipv4 = int.from_bytes(ip, 'big', False)
ipv6 = None
else:
ipv4 = None
ipv6 = int.from_bytes(ip, 'big', signed=False)
self.dcs[dc_id] = DataCenter(
id=dc_id,
ipv4=ipv4,
ipv6=ipv6,
port=port,
auth=key
)
@staticmethod @staticmethod
def encode(x: bytes) -> str: def encode(x: bytes) -> str:
@ -50,14 +71,18 @@ class StringSession(MemorySession):
return base64.urlsafe_b64decode(x) return base64.urlsafe_b64decode(x)
def save(self: Session): def save(self: Session):
if not self.auth_key: if not self.state:
return '' return ''
ip = ipaddress.ip_address(self.server_address).packed if self.state.ipv6 is not None:
ip = self.state.ipv6.to_bytes(16, 'big', signed=False)
else:
ip = self.state.ipv6.to_bytes(4, 'big', signed=False)
return CURRENT_VERSION + StringSession.encode(struct.pack( return CURRENT_VERSION + StringSession.encode(struct.pack(
_STRUCT_PREFORMAT.format(len(ip)), _STRUCT_PREFORMAT.format(len(ip)),
self.dc_id, self.state.dc_id,
ip, ip,
self.port, self.state.port,
self.auth_key.key self.dcs[self.state.dc_id].auth
)) ))