Add session storages

This commit is contained in:
Lonami Exo 2023-09-10 19:54:05 +02:00
parent 16de3b274c
commit aa83e7b043
14 changed files with 484 additions and 194 deletions

View File

@ -141,7 +141,4 @@ async def check_password(
async def sign_out(self: Client) -> None:
await self(functions.auth.log_out())
def session(client: Client) -> Session:
client._config.session.state = client._message_box.session_state()
return client._config.session
await self._storage.delete()

View File

@ -19,13 +19,20 @@ from typing import (
)
from ...mtsender import Sender
from ...session import ChatHashCache, MessageBox, PackedChat, Session
from ...session import (
ChatHashCache,
MemorySession,
MessageBox,
PackedChat,
Session,
SqliteSession,
Storage,
)
from ...tl import Request, abcs
from ..events import Event
from ..events.filters import Filter
from ..types import (
AsyncList,
Chat,
ChatLike,
File,
InFileLike,
@ -41,15 +48,12 @@ from .auth import (
check_password,
is_authorized,
request_login_code,
session,
sign_in,
sign_out,
)
from .bots import InlineResult, inline_query
from .chats import (
get_participants,
)
from .dialogs import get_dialogs, delete_dialog
from .chats import get_participants
from .dialogs import delete_dialog, get_dialogs
from .files import (
download,
iter_download,
@ -88,38 +92,49 @@ from .updates import (
remove_event_handler,
set_handler_filter,
)
from .users import (
get_me,
input_to_peer,
resolve_to_packed,
)
from .users import get_me, input_to_peer, resolve_to_packed
Return = TypeVar("Return")
T = TypeVar("T")
class Client:
def __init__(self, config: Config) -> None:
def __init__(
self,
session: Optional[Union[str, Path, Storage]],
api_id: int,
api_hash: Optional[str] = None,
) -> None:
self._sender: Optional[Sender] = None
self._sender_lock = asyncio.Lock()
self._dc_id = DEFAULT_DC
self._config = config
if isinstance(session, Storage):
self._storage = session
elif session is None:
self._storage = MemorySession()
else:
self._storage = SqliteSession(session)
self._config = Config(
session=Session(),
api_id=api_id,
api_hash=api_hash or "",
)
self._message_box = MessageBox()
self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[
Tuple[abcs.Update, Dict[int, Union[abcs.User, abcs.Chat]]]
] = asyncio.Queue(maxsize=config.update_queue_limit or 0)
] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
self._dispatcher: Optional[asyncio.Task[None]] = None
self._downloader_map = object()
self._handlers: Dict[
Type[Event], List[Tuple[Callable[[Any], Awaitable[Any]], Optional[Filter]]]
] = {}
if self_user := config.session.user:
if self_user := self._config.session.user:
self._dc_id = self_user.dc
if config.catch_up and config.session.state:
self._message_box.load(config.session.state)
if self._config.catch_up and self._config.session.state:
self._message_box.load(self._config.session.state)
# ---
@ -450,15 +465,6 @@ class Client:
def connected(self) -> bool:
return connected(self)
@property
def session(self) -> Session:
"""
Up-to-date session state, useful for persisting it to storage.
Mutating the returned object may cause the library to misbehave.
"""
return session(self)
def _build_message_map(
self,
result: abcs.Updates,

View File

@ -122,6 +122,8 @@ async def connect(self: Client) -> None:
if self._sender:
return
if session := await self._storage.load():
self._config.session = session
self._sender = await connect_sender(self._dc_id, self._config)
if self._message_box.is_empty() and self._config.session.user:
@ -148,6 +150,9 @@ async def disconnect(self: Client) -> None:
await self._sender.disconnect()
self._sender = None
self._config.session.state = self._message_box.session_state()
await self._storage.save(self._config.session)
async def invoke_request(
client: Client,

View File

@ -3,18 +3,15 @@ from .message_box import (
BOT_CHANNEL_DIFF_LIMIT,
NO_UPDATES_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT,
ChannelState,
DataCenter,
Gap,
MessageBox,
PossibleGap,
PrematureEndReason,
PtsInfo,
Session,
State,
UpdateState,
User,
)
from .session import ChannelState, DataCenter, Session, UpdateState, User
from .storage import MemorySession, SqliteSession, Storage
__all__ = [
"ChatHashCache",
@ -23,15 +20,18 @@ __all__ = [
"BOT_CHANNEL_DIFF_LIMIT",
"NO_UPDATES_TIMEOUT",
"USER_CHANNEL_DIFF_LIMIT",
"ChannelState",
"DataCenter",
"Gap",
"PossibleGap",
"PrematureEndReason",
"PtsInfo",
"Session",
"State",
"ChannelState",
"DataCenter",
"Session",
"UpdateState",
"User",
"MessageBox",
"MemorySession",
"SqliteSession",
"Storage",
]

View File

@ -2,16 +2,11 @@ from .defs import (
BOT_CHANNEL_DIFF_LIMIT,
NO_UPDATES_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT,
ChannelState,
DataCenter,
Gap,
PossibleGap,
PrematureEndReason,
PtsInfo,
Session,
State,
UpdateState,
User,
)
from .messagebox import MessageBox
@ -19,15 +14,10 @@ __all__ = [
"BOT_CHANNEL_DIFF_LIMIT",
"NO_UPDATES_TIMEOUT",
"USER_CHANNEL_DIFF_LIMIT",
"ChannelState",
"DataCenter",
"Gap",
"PossibleGap",
"PrematureEndReason",
"PtsInfo",
"Session",
"State",
"UpdateState",
"User",
"MessageBox",
]

View File

@ -1,144 +1,10 @@
import base64
import logging
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Self, Union
from typing import List, Literal, Union
from ...tl import abcs
class DataCenter:
__slots__ = ("id", "addr", "auth")
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
self.id = id
self.addr = addr
self.auth = auth
class User:
__slots__ = ("id", "dc", "bot")
def __init__(self, *, id: int, dc: int, bot: bool) -> None:
self.id = id
self.dc = dc
self.bot = bot
class ChannelState:
__slots__ = ("id", "pts")
def __init__(self, *, id: int, pts: int) -> None:
self.id = id
self.pts = pts
class UpdateState:
__slots__ = (
"pts",
"qts",
"date",
"seq",
"channels",
)
def __init__(
self,
*,
pts: int,
qts: int,
date: int,
seq: int,
channels: List[ChannelState],
) -> None:
self.pts = pts
self.qts = qts
self.date = date
self.seq = seq
self.channels = channels
class Session:
__slots__ = ("dcs", "user", "state")
def __init__(
self,
*,
dcs: Optional[List[DataCenter]] = None,
user: Optional[User] = None,
state: Optional[UpdateState] = None,
):
self.dcs = dcs or []
self.user = user
self.state = state
def to_dict(self) -> Dict[str, Any]:
return {
"dcs": [
{
"id": dc.id,
"addr": dc.addr,
"auth": base64.b64encode(dc.auth).decode("ascii")
if dc.auth
else None,
}
for dc in self.dcs
],
"user": {
"id": self.user.id,
"dc": self.user.dc,
"bot": self.user.bot,
}
if self.user
else None,
"state": {
"pts": self.state.pts,
"qts": self.state.qts,
"date": self.state.date,
"seq": self.state.seq,
"channels": [
{"id": channel.id, "pts": channel.pts}
for channel in self.state.channels
],
}
if self.state
else None,
}
@classmethod
def from_dict(cls, dict: Dict[str, Any]) -> Self:
return cls(
dcs=[
DataCenter(
id=dc["id"],
addr=dc["addr"],
auth=base64.b64decode(dc["auth"])
if dc["auth"] is not None
else None,
)
for dc in dict["dcs"]
],
user=User(
id=dict["user"]["id"],
dc=dict["user"]["dc"],
bot=dict["user"]["bot"],
)
if dict["user"]
else None,
state=UpdateState(
pts=dict["state"]["pts"],
qts=dict["state"]["qts"],
date=dict["state"]["date"],
seq=dict["state"]["seq"],
channels=[
ChannelState(id=channel["id"], pts=channel["pts"])
for channel in dict["state"]["channels"]
],
)
if dict["state"]
else None,
)
class PtsInfo:
__slots__ = ("pts", "pts_count", "entry")

View File

@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
from ...tl import Request, abcs, functions, types
from ..chat import ChatHashCache
from ..session import ChannelState, UpdateState
from .adaptor import adapt, pts_info_from_update
from .defs import (
BOT_CHANNEL_DIFF_LIMIT,
@ -17,13 +18,11 @@ from .defs import (
NO_UPDATES_TIMEOUT,
POSSIBLE_GAP_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT,
ChannelState,
Entry,
Gap,
PossibleGap,
PrematureEndReason,
State,
UpdateState,
)

View File

@ -0,0 +1,159 @@
import base64
from typing import Any, Dict, List, Optional, Self
class DataCenter:
    """
    Plain record describing one datacenter entry of a session.

    All fields are keyword-only: the datacenter `id`, its `addr` string,
    and the optional `auth` bytes associated with it.
    """

    __slots__ = ("id", "addr", "auth")

    def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
        self.id, self.addr, self.auth = id, addr, auth
class User:
    """
    Plain record describing the user a session belongs to.

    All fields are keyword-only: the user `id`, the `dc` number stored
    for that user, and whether the account is a `bot`.
    """

    __slots__ = ("id", "dc", "bot")

    def __init__(self, *, id: int, dc: int, bot: bool) -> None:
        self.id, self.dc, self.bot = id, dc, bot
class ChannelState:
    """
    Plain record pairing a channel `id` with its `pts` value.

    Both fields are keyword-only.
    """

    __slots__ = ("id", "pts")

    def __init__(self, *, id: int, pts: int) -> None:
        self.id, self.pts = id, pts
class UpdateState:
    """
    Plain record with the update counters stored in a session.

    All fields are keyword-only: the `pts`, `qts`, `date` and `seq`
    integers, plus the list of per-channel states in `channels`.
    """

    __slots__ = ("pts", "qts", "date", "seq", "channels")

    def __init__(
        self,
        *,
        pts: int,
        qts: int,
        date: int,
        seq: int,
        channels: List[ChannelState],
    ) -> None:
        self.pts, self.qts = pts, qts
        self.date, self.seq = date, seq
        # The list is stored as given (no copy is made).
        self.channels = channels
class Session:
"""
A Telethon session.
A `Session` instance contains the required information to login into your
Telegram account. **Never** give the saved session file to anyone else or
make it public.
Leaking the session file will grant a bad actor complete access to your
account, including private conversations, groups you're part of and list
of contacts (though not secret chats).
If you think the session has been compromised, immediately terminate all
sessions through an official Telegram client to revoke the authorization.
"""
VERSION = 1
__slots__ = ("dcs", "user", "state")
def __init__(
self,
*,
dcs: Optional[List[DataCenter]] = None,
user: Optional[User] = None,
state: Optional[UpdateState] = None,
):
self.dcs = dcs or []
self.user = user
self.state = state
def to_dict(self) -> Dict[str, Any]:
return {
"v": self.VERSION,
"dcs": [
{
"id": dc.id,
"addr": dc.addr,
"auth": base64.b64encode(dc.auth).decode("ascii")
if dc.auth
else None,
}
for dc in self.dcs
],
"user": {
"id": self.user.id,
"dc": self.user.dc,
"bot": self.user.bot,
}
if self.user
else None,
"state": {
"pts": self.state.pts,
"qts": self.state.qts,
"date": self.state.date,
"seq": self.state.seq,
"channels": [
{"id": channel.id, "pts": channel.pts}
for channel in self.state.channels
],
}
if self.state
else None,
}
@classmethod
def from_dict(cls, dict: Dict[str, Any]) -> Self:
version = dict["v"]
if version != cls.VERSION:
raise ValueError(
f"cannot parse session format version {version} (expected {cls.VERSION})"
)
return cls(
dcs=[
DataCenter(
id=dc["id"],
addr=dc["addr"],
auth=base64.b64decode(dc["auth"])
if dc["auth"] is not None
else None,
)
for dc in dict["dcs"]
],
user=User(
id=dict["user"]["id"],
dc=dict["user"]["dc"],
bot=dict["user"]["bot"],
)
if dict["user"]
else None,
state=UpdateState(
pts=dict["state"]["pts"],
qts=dict["state"]["qts"],
date=dict["state"]["date"],
seq=dict["state"]["seq"],
channels=[
ChannelState(id=channel["id"], pts=channel["pts"])
for channel in dict["state"]["channels"]
],
)
if dict["state"]
else None,
)

View File

@ -0,0 +1,15 @@
from typing import Any
from .memory import MemorySession
from .storage import Storage
try:
    from .sqlite import SqliteSession
except ImportError as e:
    # Python unbinds the `as` target when the `except` clause finishes, so
    # the exception must be kept under a different name for the placeholder
    # class below to be able to re-raise it later.
    _sqlite_import_error = e

    class SqliteSession(Storage):  # type: ignore [no-redef]
        """
        Placeholder used when the SQLite-backed storage cannot be imported.

        Importing the package still succeeds; the original `ImportError` is
        raised as soon as somebody tries to create a `SqliteSession`.
        """

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            raise _sqlite_import_error from None

        # `Storage` declares these methods as abstract. Without concrete
        # definitions, instantiation would fail with the ABC's `TypeError`
        # before `__init__` gets a chance to raise the import error.
        async def load(self) -> Any:
            raise _sqlite_import_error from None

        async def save(self, session: Any) -> None:
            raise _sqlite_import_error from None

        async def delete(self) -> None:
            raise _sqlite_import_error from None
__all__ = ["MemorySession", "Storage", "SqliteSession"]

View File

@ -0,0 +1,28 @@
from typing import Optional
from ..session import Session
from .storage import Storage
class MemorySession(Storage):
    """
    Session storage without persistence.

    This is the simplest storage and is the one used by default.

    The session is only held in an attribute of this object; nothing is
    ever written to disk.
    """

    __slots__ = ("session",)

    def __init__(self, session: Optional[Session] = None):
        # Optional initial session; `load` returns it until replaced.
        self.session = session

    async def load(self) -> Optional[Session]:
        """Return the session currently held in memory, if any."""
        return self.session

    async def save(self, session: Session) -> None:
        """Replace the session held in memory with `session`."""
        self.session = session

    async def delete(self) -> None:
        """Forget the session held in memory."""
        self.session = None

View File

@ -0,0 +1,202 @@
import sqlite3
from pathlib import Path
from typing import Optional, Union
from ..session import ChannelState, DataCenter, Session, UpdateState, User
from .storage import Storage
EXTENSION = ".session"
CURRENT_VERSION = 10
class SqliteSession(Storage):
    """
    Session storage backed by SQLite.

    SQLite is a reliable way to persist data to disk and offers file locking.

    Paths without extension will have '.session' appended to them.
    This is by convention, and to make it harder to commit session files to
    a VCS by accident (adding `*.session` to `.gitignore` will catch them).
    """

    def __init__(self, file: Union[str, Path]):
        """
        Remember where the database lives; the file is not opened yet.

        :param file: Path of the session file. `EXTENSION` is appended when
            the path has no suffix.
        """
        path = Path(file)
        if not path.suffix:
            path = path.with_suffix(EXTENSION)

        self._path = path
        # Opened lazily by `_current_conn`, reset by `save` and `delete`.
        self._conn: Optional[sqlite3.Connection] = None

    async def load(self) -> Optional[Session]:
        """
        Load the session stored in the database, creating the tables on
        first use.

        A database in the version 7 format is converted in place to the
        current format before being read.

        :raises NotImplementedError: If the stored version is older than
            `CURRENT_VERSION` and is not 7.
        """
        conn = self._current_conn()

        c = conn.cursor()
        # The connection context manager commits on success and rolls back
        # if an exception escapes; it does not close the connection.
        with conn:
            version = self._get_or_init_version(c)
            if version < CURRENT_VERSION:
                if version == 7:
                    session = self._load_v7(c)
                else:
                    raise NotImplementedError

                # Drop the old tables, recreate them in the current layout
                # and write back what was read from the old one.
                self._reset(c)
                self._get_or_init_version(c)
                self._save_v10(c, session)

            return self._load_v10(c)

    async def save(self, session: Session) -> None:
        """
        Write `session` to the database, then close the connection.
        """
        conn = self._current_conn()
        with conn:
            self._save_v10(conn.cursor(), session)
        conn.close()
        self._conn = None

    async def delete(self) -> None:
        """
        Close the connection (if open) and remove the session file.

        Deleting a session whose file does not exist is not an error.
        """
        if self._conn:
            self._conn.close()
            self._conn = None
        self._path.unlink(missing_ok=True)

    def _current_conn(self) -> sqlite3.Connection:
        """Return the open connection, opening one first if needed."""
        if self._conn is None:
            self._conn = sqlite3.connect(self._path)
        return self._conn

    @staticmethod
    def _build_state(state, channelstate) -> Optional[UpdateState]:
        """
        Build an `UpdateState` from a `(pts, qts, date, seq)` row and the
        `(id, pts)` channel rows, or return `None` when there is no row.
        """
        if not state:
            # `fetchone()` returns `None` when nothing was stored, which is
            # the case for a database whose tables were just created.
            return None

        return UpdateState(
            pts=state[0],
            qts=state[1],
            date=state[2],
            seq=state[3],
            channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate],
        )

    @staticmethod
    def _load_v7(c: sqlite3.Cursor) -> Session:
        """
        Read a session stored in the version 7 layout.

        That layout has no user information, so `user` is always `None`.
        """
        # Session v7 format from telethon v1
        c.execute("select dc_id, server_address, port, auth_key from sessions")
        sessions = c.fetchall()
        c.execute("select pts, qts, date, seq from update_state where id = 0")
        state = c.fetchone()
        c.execute("select id, pts from update_state where id != 0")
        channelstate = c.fetchall()

        return Session(
            dcs=[
                DataCenter(id=id, addr=f"{ip}:{port}", auth=auth)
                for (id, ip, port, auth) in sessions
            ],
            user=None,
            state=SqliteSession._build_state(state, channelstate),
        )

    @staticmethod
    def _load_v10(c: sqlite3.Cursor) -> Session:
        """
        Read a session stored in the current layout.

        `user` and `state` are `None` when their tables hold no row.
        """
        # `select *` relies on the column order declared in `_create_tables`.
        c.execute("select * from datacenter")
        datacenter = c.fetchall()
        c.execute("select * from user")
        user = c.fetchone()
        c.execute("select * from state")
        state = c.fetchone()
        c.execute("select * from channelstate")
        channelstate = c.fetchall()

        return Session(
            dcs=[
                DataCenter(id=id, addr=addr, auth=auth)
                for (id, addr, auth) in datacenter
            ],
            user=User(id=user[0], dc=user[1], bot=bool(user[2])) if user else None,
            state=SqliteSession._build_state(state, channelstate),
        )

    @staticmethod
    def _save_v10(c: sqlite3.Cursor, session: Session) -> None:
        """
        Replace the contents of every data table with those of `session`.
        """
        c.execute("delete from datacenter")
        c.execute("delete from user")
        c.execute("delete from state")
        c.execute("delete from channelstate")
        c.executemany(
            "insert into datacenter values (?, ?, ?)",
            [(dc.id, dc.addr, dc.auth) for dc in session.dcs],
        )
        if user := session.user:
            c.execute(
                "insert into user values (?, ?, ?)", (user.id, user.dc, int(user.bot))
            )
        if state := session.state:
            c.execute(
                "insert into state values (?, ?, ?, ?)",
                (state.pts, state.qts, state.date, state.seq),
            )
            c.executemany(
                "insert into channelstate values (?, ?)",
                [(channel.id, channel.pts) for channel in state.channels],
            )

    @staticmethod
    def _reset(c: sqlite3.Cursor) -> None:
        """
        Drop every table listed in `sqlite_master`.

        :raises ValueError: If a table name is longer than 20 characters or
            contains anything other than ASCII letters and underscores.
        """
        safe_chars = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
        c.execute("select name from sqlite_master where type='table'")
        for (name,) in c.fetchall():
            # Can't format arguments for table names. Regardless, it shouldn't
            # be an SQL-injection because names come from `sqlite_master`.
            # Just to be on the safe-side, check for r'\w+' nevertheless,
            # avoiding referencing globals which could've been monkey-patched.
            for char in name:
                if char not in safe_chars or name.__len__() > 20:
                    raise ValueError(f"potentially unsafe table name: {name}")
            c.execute(f"drop table {name}")

    @staticmethod
    def _get_or_init_version(c: sqlite3.Cursor) -> int:
        """
        Return the stored format version.

        When there is no `version` table, the tables are created, tagged
        with `CURRENT_VERSION`, and that version is returned.
        """
        c.execute(
            "select name from sqlite_master where type='table' and name='version'"
        )
        if c.fetchone():
            c.execute("select version from version")
            res = c.fetchone()[0]
            assert isinstance(res, int)
            return res
        else:
            SqliteSession._create_tables(c)
            c.execute("insert into version values (?)", (CURRENT_VERSION,))
            return CURRENT_VERSION

    @staticmethod
    def _create_tables(c: sqlite3.Cursor) -> None:
        """Create the tables of the current layout."""
        c.executescript(
            """
            create table version (
                version integer primary key
            );
            create table datacenter(
                id integer primary key,
                addr text not null,
                auth blob
            );
            create table user(
                id integer primary key,
                dc integer not null,
                bot integer not null
            );
            create table state(
                pts integer not null,
                qts integer not null,
                date integer not null,
                seq integer not null
            );
            create table channelstate(
                id integer primary key,
                pts integer not null
            );
            """
        )

View File

@ -0,0 +1,30 @@
import abc
from typing import Optional
from ..session import Session
class Storage(abc.ABC):
    """
    Interface that every session storage has to implement.

    A storage knows how to `load`, `save` and `delete` a single `Session`.
    """

    @abc.abstractmethod
    async def load(self) -> Optional[Session]:
        """
        Return the stored `Session` instance, or `None` if there is none.

        The library calls this method prior to `connect`.
        """

    @abc.abstractmethod
    async def save(self, session: Session) -> None:
        """
        Persist the given `Session` instance.

        The library calls this method post `disconnect`.
        """

    @abc.abstractmethod
    async def delete(self) -> None:
        """
        Remove the saved `Session`.

        The library calls this method post `log_out`.
        """

View File

@ -13,13 +13,7 @@ async def test_ping_pong() -> None:
api_hash = os.getenv("TG_HASH")
assert api_id and api_id.isdigit()
assert api_hash
client = Client(
Config(
session=Session(),
api_id=int(api_id),
api_hash=api_hash,
)
)
client = Client(None, int(api_id), api_hash)
assert not client.connected
await client.connect()
assert client.connected

View File

@ -81,7 +81,6 @@ def test_unpack_two_at_once() -> None:
with raises(ValueError) as e:
transport.unpack(input, output)
e.match("bad seq")
assert output == expected_output
def test_unpack_twice() -> None: