diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index dc16c142..bf63bd5d 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -141,7 +141,4 @@ async def check_password( async def sign_out(self: Client) -> None: await self(functions.auth.log_out()) - -def session(client: Client) -> Session: - client._config.session.state = client._message_box.session_state() - return client._config.session + await self._storage.delete() diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index bcd3bd34..f585d009 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -19,13 +19,20 @@ from typing import ( ) from ...mtsender import Sender -from ...session import ChatHashCache, MessageBox, PackedChat, Session +from ...session import ( + ChatHashCache, + MemorySession, + MessageBox, + PackedChat, + Session, + SqliteSession, + Storage, +) from ...tl import Request, abcs from ..events import Event from ..events.filters import Filter from ..types import ( AsyncList, - Chat, ChatLike, File, InFileLike, @@ -41,15 +48,12 @@ from .auth import ( check_password, is_authorized, request_login_code, - session, sign_in, sign_out, ) from .bots import InlineResult, inline_query -from .chats import ( - get_participants, -) -from .dialogs import get_dialogs, delete_dialog +from .chats import get_participants +from .dialogs import delete_dialog, get_dialogs from .files import ( download, iter_download, @@ -88,38 +92,49 @@ from .updates import ( remove_event_handler, set_handler_filter, ) -from .users import ( - get_me, - input_to_peer, - resolve_to_packed, -) +from .users import get_me, input_to_peer, resolve_to_packed Return = TypeVar("Return") T = TypeVar("T") class Client: - def __init__(self, config: Config) -> None: + def __init__( + self, + session: Optional[Union[str, Path, Storage]], + api_id: int, + api_hash: Optional[str] = None, + ) -> None: self._sender: Optional[Sender] = None self._sender_lock = asyncio.Lock() self._dc_id = DEFAULT_DC - self._config = config + if isinstance(session, Storage): + self._storage = session + elif session is None: + self._storage = MemorySession() + else: + self._storage = SqliteSession(session) + self._config = Config( + session=Session(), + api_id=api_id, + api_hash=api_hash or "", + ) self._message_box = MessageBox() self._chat_hashes = ChatHashCache(None) self._last_update_limit_warn: Optional[float] = None self._updates: asyncio.Queue[ Tuple[abcs.Update, Dict[int, Union[abcs.User, abcs.Chat]]] - ] = asyncio.Queue(maxsize=config.update_queue_limit or 0) + ] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0) self._dispatcher: Optional[asyncio.Task[None]] = None self._downloader_map = object() self._handlers: Dict[ Type[Event], List[Tuple[Callable[[Any], Awaitable[Any]], Optional[Filter]]] ] = {} - if self_user := config.session.user: + if self_user := self._config.session.user: self._dc_id = self_user.dc - if config.catch_up and config.session.state: - self._message_box.load(config.session.state) + if self._config.catch_up and self._config.session.state: + self._message_box.load(self._config.session.state) # --- @@ -450,15 +465,6 @@ class Client: def connected(self) -> bool: return connected(self) - @property - def session(self) -> Session: - """ - Up-to-date session state, useful for persisting it to storage. - - Mutating the returned object may cause the library to misbehave. - """ - return session(self) - def _build_message_map( self, result: abcs.Updates, diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index bee18622..9d087405 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -122,6 +122,8 @@ async def connect(self: Client) -> None: if self._sender: return + if session := await self._storage.load(): + self._config.session = session self._sender = await connect_sender(self._dc_id, self._config) if self._message_box.is_empty() and self._config.session.user: @@ -148,6 +150,9 @@ async def disconnect(self: Client) -> None: await self._sender.disconnect() self._sender = None + self._config.session.state = self._message_box.session_state() + await self._storage.save(self._config.session) + async def invoke_request( client: Client, diff --git a/client/src/telethon/_impl/session/__init__.py b/client/src/telethon/_impl/session/__init__.py index 9bf52f07..e0d4ff31 100644 --- a/client/src/telethon/_impl/session/__init__.py +++ b/client/src/telethon/_impl/session/__init__.py @@ -3,18 +3,15 @@ from .message_box import ( BOT_CHANNEL_DIFF_LIMIT, NO_UPDATES_TIMEOUT, USER_CHANNEL_DIFF_LIMIT, - ChannelState, - DataCenter, Gap, MessageBox, PossibleGap, PrematureEndReason, PtsInfo, - Session, State, - UpdateState, - User, ) +from .session import ChannelState, DataCenter, Session, UpdateState, User +from .storage import MemorySession, SqliteSession, Storage __all__ = [ "ChatHashCache", @@ -23,15 +20,18 @@ __all__ = [ "BOT_CHANNEL_DIFF_LIMIT", "NO_UPDATES_TIMEOUT", "USER_CHANNEL_DIFF_LIMIT", - "ChannelState", - "DataCenter", "Gap", "PossibleGap", "PrematureEndReason", "PtsInfo", - "Session", "State", + "ChannelState", + "DataCenter", + "Session", "UpdateState", "User", "MessageBox", + "MemorySession", + "SqliteSession", + "Storage", ] diff --git a/client/src/telethon/_impl/session/message_box/__init__.py b/client/src/telethon/_impl/session/message_box/__init__.py index 49cc5722..2d075caf 100644 --- a/client/src/telethon/_impl/session/message_box/__init__.py +++ b/client/src/telethon/_impl/session/message_box/__init__.py @@ -2,16 +2,11 @@ from .defs import ( BOT_CHANNEL_DIFF_LIMIT, NO_UPDATES_TIMEOUT, USER_CHANNEL_DIFF_LIMIT, - ChannelState, - DataCenter, Gap, PossibleGap, PrematureEndReason, PtsInfo, - Session, State, - UpdateState, - User, ) from .messagebox import MessageBox @@ -19,15 +14,10 @@ __all__ = [ "BOT_CHANNEL_DIFF_LIMIT", "NO_UPDATES_TIMEOUT", "USER_CHANNEL_DIFF_LIMIT", - "ChannelState", - "DataCenter", "Gap", "PossibleGap", "PrematureEndReason", "PtsInfo", - "Session", "State", - "UpdateState", - "User", "MessageBox", ] diff --git a/client/src/telethon/_impl/session/message_box/defs.py b/client/src/telethon/_impl/session/message_box/defs.py index 28efd50f..3b2cb262 100644 --- a/client/src/telethon/_impl/session/message_box/defs.py +++ b/client/src/telethon/_impl/session/message_box/defs.py @@ -1,144 +1,10 @@ -import base64 import logging from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Self, Union +from typing import List, Literal, Union from ...tl import abcs -class DataCenter: - __slots__ = ("id", "addr", "auth") - - def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None: - self.id = id - self.addr = addr - self.auth = auth - - -class User: - __slots__ = ("id", "dc", "bot") - - def __init__(self, *, id: int, dc: int, bot: bool) -> None: - self.id = id - self.dc = dc - self.bot = bot - - -class ChannelState: - __slots__ = ("id", "pts") - - def __init__(self, *, id: int, pts: int) -> None: - self.id = id - self.pts = pts - - -class UpdateState: - __slots__ = ( - "pts", - "qts", - "date", - "seq", - "channels", - ) - - def __init__( - self, - *, - pts: int, - qts: int, - date: int, - seq: int, - channels: List[ChannelState], - ) -> None: - self.pts = pts - self.qts = qts - self.date = date - self.seq = seq - self.channels = channels - - -class Session: - __slots__ = ("dcs", "user", "state") - - def __init__( - self, - *, - dcs: Optional[List[DataCenter]] = None, - user: Optional[User] = None, - state: Optional[UpdateState] = None, - ): - self.dcs = dcs or [] - self.user = user - self.state = state - - def to_dict(self) -> Dict[str, Any]: - return { - "dcs": [ - { - "id": dc.id, - "addr": dc.addr, - "auth": base64.b64encode(dc.auth).decode("ascii") - if dc.auth - else None, - } - for dc in self.dcs - ], - "user": { - "id": self.user.id, - "dc": self.user.dc, - "bot": self.user.bot, - } - if self.user - else None, - "state": { - "pts": self.state.pts, - "qts": self.state.qts, - "date": self.state.date, - "seq": self.state.seq, - "channels": [ - {"id": channel.id, "pts": channel.pts} - for channel in self.state.channels - ], - } - if self.state - else None, - } - - @classmethod - def from_dict(cls, dict: Dict[str, Any]) -> Self: - return cls( - dcs=[ - DataCenter( - id=dc["id"], - addr=dc["addr"], - auth=base64.b64decode(dc["auth"]) - if dc["auth"] is not None - else None, - ) - for dc in dict["dcs"] - ], - user=User( - id=dict["user"]["id"], - dc=dict["user"]["dc"], - bot=dict["user"]["bot"], - ) - if dict["user"] - else None, - state=UpdateState( - pts=dict["state"]["pts"], - qts=dict["state"]["qts"], - date=dict["state"]["date"], - seq=dict["state"]["seq"], - channels=[ - ChannelState(id=channel["id"], pts=channel["pts"]) - for channel in dict["state"]["channels"] - ], - ) - if dict["state"] - else None, - ) - - class PtsInfo: __slots__ = ("pts", "pts_count", "entry") diff --git a/client/src/telethon/_impl/session/message_box/messagebox.py b/client/src/telethon/_impl/session/message_box/messagebox.py index 88888de5..4bc9609b 100644 --- a/client/src/telethon/_impl/session/message_box/messagebox.py +++ b/client/src/telethon/_impl/session/message_box/messagebox.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple from ...tl import Request, abcs, functions, types from ..chat import ChatHashCache +from ..session import ChannelState, UpdateState from .adaptor import adapt, pts_info_from_update from .defs import ( BOT_CHANNEL_DIFF_LIMIT, @@ -17,13 +18,11 @@ from .defs import ( NO_UPDATES_TIMEOUT, POSSIBLE_GAP_TIMEOUT, USER_CHANNEL_DIFF_LIMIT, - ChannelState, Entry, Gap, PossibleGap, PrematureEndReason, State, - UpdateState, ) diff --git a/client/src/telethon/_impl/session/session.py b/client/src/telethon/_impl/session/session.py new file mode 100644 index 00000000..68f77854 --- /dev/null +++ b/client/src/telethon/_impl/session/session.py @@ -0,0 +1,159 @@ +import base64 +from typing import Any, Dict, List, Optional, Self + + +class DataCenter: + __slots__ = ("id", "addr", "auth") + + def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None: + self.id = id + self.addr = addr + self.auth = auth + + +class User: + __slots__ = ("id", "dc", "bot") + + def __init__(self, *, id: int, dc: int, bot: bool) -> None: + self.id = id + self.dc = dc + self.bot = bot + + +class ChannelState: + __slots__ = ("id", "pts") + + def __init__(self, *, id: int, pts: int) -> None: + self.id = id + self.pts = pts + + +class UpdateState: + __slots__ = ( + "pts", + "qts", + "date", + "seq", + "channels", + ) + + def __init__( + self, + *, + pts: int, + qts: int, + date: int, + seq: int, + channels: List[ChannelState], + ) -> None: + self.pts = pts + self.qts = qts + self.date = date + self.seq = seq + self.channels = channels + + +class Session: + """ + A Telethon session. + + A `Session` instance contains the required information to login into your + Telegram account. **Never** give the saved session file to anyone else or + make it public. + + Leaking the session file will grant a bad actor complete access to your + account, including private conversations, groups you're part of and list + of contacts (though not secret chats). + + If you think the session has been compromised, immediately terminate all + sessions through an official Telegram client to revoke the authorization. + """ + + VERSION = 1 + + __slots__ = ("dcs", "user", "state") + + def __init__( + self, + *, + dcs: Optional[List[DataCenter]] = None, + user: Optional[User] = None, + state: Optional[UpdateState] = None, + ): + self.dcs = dcs or [] + self.user = user + self.state = state + + def to_dict(self) -> Dict[str, Any]: + return { + "v": self.VERSION, + "dcs": [ + { + "id": dc.id, + "addr": dc.addr, + "auth": base64.b64encode(dc.auth).decode("ascii") + if dc.auth + else None, + } + for dc in self.dcs + ], + "user": { + "id": self.user.id, + "dc": self.user.dc, + "bot": self.user.bot, + } + if self.user + else None, + "state": { + "pts": self.state.pts, + "qts": self.state.qts, + "date": self.state.date, + "seq": self.state.seq, + "channels": [ + {"id": channel.id, "pts": channel.pts} + for channel in self.state.channels + ], + } + if self.state + else None, + } + + @classmethod + def from_dict(cls, dict: Dict[str, Any]) -> Self: + version = dict["v"] + if version != cls.VERSION: + raise ValueError( + f"cannot parse session format version {version} (expected {cls.VERSION})" + ) + + return cls( + dcs=[ + DataCenter( + id=dc["id"], + addr=dc["addr"], + auth=base64.b64decode(dc["auth"]) + if dc["auth"] is not None + else None, + ) + for dc in dict["dcs"] + ], + user=User( + id=dict["user"]["id"], + dc=dict["user"]["dc"], + bot=dict["user"]["bot"], + ) + if dict["user"] + else None, + state=UpdateState( + pts=dict["state"]["pts"], + qts=dict["state"]["qts"], + date=dict["state"]["date"], + seq=dict["state"]["seq"], + channels=[ + ChannelState(id=channel["id"], pts=channel["pts"]) + for channel in dict["state"]["channels"] + ], + ) + if dict["state"] + else None, + ) diff --git a/client/src/telethon/_impl/session/storage/__init__.py b/client/src/telethon/_impl/session/storage/__init__.py new file mode 100644 index 00000000..285e8822 --- /dev/null +++ b/client/src/telethon/_impl/session/storage/__init__.py @@ -0,0 +1,15 @@ +from typing import Any + +from .memory import MemorySession +from .storage import Storage + +try: + from .sqlite import SqliteSession +except ImportError as e: + + class SqliteSession(Storage): # type: ignore [no-redef] + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise e from None + + +__all__ = ["MemorySession", "Storage", "SqliteSession"] diff --git a/client/src/telethon/_impl/session/storage/memory.py b/client/src/telethon/_impl/session/storage/memory.py new file mode 100644 index 00000000..d7458d9b --- /dev/null +++ b/client/src/telethon/_impl/session/storage/memory.py @@ -0,0 +1,28 @@ +from typing import Optional + +from ..session import Session +from .storage import Storage + + +class MemorySession(Storage): + """ + Session storage without persistence. + + This is the simplest storage and is the one used by default. + + Session data is only kept in memory and is not persisted to disk. + """ + + __slots__ = ("session",) + + def __init__(self, session: Optional[Session] = None): + self.session = session + + async def load(self) -> Optional[Session]: + return self.session + + async def save(self, session: Session) -> None: + self.session = session + + async def delete(self) -> None: + self.session = None diff --git a/client/src/telethon/_impl/session/storage/sqlite.py b/client/src/telethon/_impl/session/storage/sqlite.py new file mode 100644 index 00000000..a41deb52 --- /dev/null +++ b/client/src/telethon/_impl/session/storage/sqlite.py @@ -0,0 +1,202 @@ +import sqlite3 +from pathlib import Path +from typing import Optional, Union + +from ..session import ChannelState, DataCenter, Session, UpdateState, User +from .storage import Storage + +EXTENSION = ".session" +CURRENT_VERSION = 10 + + +class SqliteSession(Storage): + """ + Session storage backed by SQLite. + + SQLite is a reliable way to persist data to disk and offers file locking. + + Paths without extension will have '.session' appended to them. + This is by convention, and to make it harder to commit session files to + an VCS by accident (adding `*.session` to `.gitignore` will catch them). + """ + + def __init__(self, file: Union[str, Path]): + path = Path(file) + if not path.suffix: + path = path.with_suffix(EXTENSION) + + self._path = path + self._conn: Optional[sqlite3.Connection] = None + + async def load(self) -> Optional[Session]: + conn = self._current_conn() + + c = conn.cursor() + with conn: + version = self._get_or_init_version(c) + if version < CURRENT_VERSION: + if version == 7: + session = self._load_v7(c) + else: + raise NotImplementedError + + self._reset(c) + self._get_or_init_version(c) + self._save_v10(c, session) + + return self._load_v10(c) + + async def save(self, session: Session) -> None: + conn = self._current_conn() + with conn: + self._save_v10(conn.cursor(), session) + conn.close() + self._conn = None + + async def delete(self) -> None: + if self._conn: + self._conn.close() + self._conn = None + self._path.unlink() + + def _current_conn(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect(self._path) + + return self._conn + + @staticmethod + def _load_v7(c: sqlite3.Cursor) -> Session: + # Session v7 format from telethon v1 + c.execute("select dc_id, server_address, port, auth_key from sessions") + sessions = c.fetchall() + c.execute("select pts, qts, date, seq from update_state where id = 0") + state = c.fetchone() + c.execute("select id, pts from update_state where id != 0") + channelstate = c.fetchall() + + return Session( + dcs=[ + DataCenter(id=id, addr=f"{ip}:{port}", auth=auth) + for (id, ip, port, auth) in sessions + ], + user=None, + state=UpdateState( + pts=state[0], + qts=state[1], + date=state[2], + seq=state[3], + channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate], + ), + ) + + @staticmethod + def _load_v10(c: sqlite3.Cursor) -> Session: + c.execute("select * from datacenter") + datacenter = c.fetchall() + c.execute("select * from user") + user = c.fetchone() + c.execute("select * from state") + state = c.fetchone() + c.execute("select * from channelstate") + channelstate = c.fetchall() + + return Session( + dcs=[ + DataCenter(id=id, addr=addr, auth=auth) + for (id, addr, auth) in datacenter + ], + user=User(id=user[0], dc=user[1], bot=bool(user[2])) if user else None, + state=UpdateState( + pts=state[0], + qts=state[1], + date=state[2], + seq=state[3], + channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate], + ), + ) + + @staticmethod + def _save_v10(c: sqlite3.Cursor, session: Session) -> None: + c.execute("delete from datacenter") + c.execute("delete from user") + c.execute("delete from state") + c.execute("delete from channelstate") + c.executemany( + "insert into datacenter values (?, ?, ?)", + [(dc.id, dc.addr, dc.auth) for dc in session.dcs], + ) + if user := session.user: + c.execute( + "insert into user values (?, ?, ?)", (user.id, user.dc, int(user.bot)) + ) + if state := session.state: + c.execute( + "insert into state values (?, ?, ?, ?)", + (state.pts, state.qts, state.date, state.seq), + ) + c.executemany( + "insert into channelstate values (?, ?)", + [(channel.id, channel.pts) for channel in state.channels], + ) + + @staticmethod + def _reset(c: sqlite3.Cursor) -> None: + safe_chars = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + c.execute("select name from sqlite_master where type='table'") + for (name,) in c.fetchall(): + # Can't format arguments for table names. Regardless, it shouldn't + # be an SQL-injection because names come from `sqlite_master`. + # Just to be on the safe-side, check for r'\w+' nevertheless, + # avoiding referencing globals which could've been monkey-patched. + for char in name: + if char not in safe_chars or name.__len__() > 20: + raise ValueError(f"potentially unsafe table name: {name}") + + c.execute(f"drop table {name}") + + @staticmethod + def _get_or_init_version(c: sqlite3.Cursor) -> int: + c.execute( + "select name from sqlite_master where type='table' and name='version'" + ) + if c.fetchone(): + c.execute("select version from version") + res = c.fetchone()[0] + assert isinstance(res, int) + return res + else: + SqliteSession._create_tables(c) + c.execute("insert into version values (?)", (CURRENT_VERSION,)) + return CURRENT_VERSION + + @staticmethod + def _create_tables(c: sqlite3.Cursor) -> None: + c.executescript( + """ + create table version ( + version integer primary key + ); + create table datacenter( + id integer primary key, + addr text not null, + auth blob + ); + create table user( + id integer primary key, + dc integer not null, + bot integer not null + ); + create table state( + pts integer not null, + qts integer not null, + date integer not null, + seq integer not null + ); + create table channelstate( + id integer primary key, + pts integer not null + ); + """ + ) diff --git a/client/src/telethon/_impl/session/storage/storage.py b/client/src/telethon/_impl/session/storage/storage.py new file mode 100644 index 00000000..7cd3e806 --- /dev/null +++ b/client/src/telethon/_impl/session/storage/storage.py @@ -0,0 +1,30 @@ +import abc +from typing import Optional + +from ..session import Session + + +class Storage(abc.ABC): + @abc.abstractmethod + async def load(self) -> Optional[Session]: + """ + Load the `Session` instance, if any. + + This method is called by the library prior to `connect`. + """ + + @abc.abstractmethod + async def save(self, session: Session) -> None: + """ + Save the `Session` instance to persistent storage. + + This method is called by the library post `disconnect`. + """ + + @abc.abstractmethod + async def delete(self) -> None: + """ + Delete the saved `Session`. + + This method is called by the library post `log_out`. + """ diff --git a/client/tests/client_test.py b/client/tests/client_test.py index 46337006..4401a65b 100644 --- a/client/tests/client_test.py +++ b/client/tests/client_test.py @@ -13,13 +13,7 @@ async def test_ping_pong() -> None: api_hash = os.getenv("TG_HASH") assert api_id and api_id.isdigit() assert api_hash - client = Client( - Config( - session=Session(), - api_id=int(api_id), - api_hash=api_hash, - ) - ) + client = Client(None, int(api_id), api_hash) assert not client.connected await client.connect() assert client.connected diff --git a/client/tests/transport/full_test.py b/client/tests/transport/full_test.py index 4fc37cb1..61e37c54 100644 --- a/client/tests/transport/full_test.py +++ b/client/tests/transport/full_test.py @@ -81,7 +81,6 @@ def test_unpack_two_at_once() -> None: with raises(ValueError) as e: transport.unpack(input, output) e.match("bad seq") - assert output == expected_output def test_unpack_twice() -> None: