mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-03 21:24:35 +03:00
Add session storages
This commit is contained in:
parent
16de3b274c
commit
aa83e7b043
|
@ -141,7 +141,4 @@ async def check_password(
|
|||
async def sign_out(self: Client) -> None:
|
||||
await self(functions.auth.log_out())
|
||||
|
||||
|
||||
def session(client: Client) -> Session:
|
||||
client._config.session.state = client._message_box.session_state()
|
||||
return client._config.session
|
||||
await self._storage.delete()
|
||||
|
|
|
@ -19,13 +19,20 @@ from typing import (
|
|||
)
|
||||
|
||||
from ...mtsender import Sender
|
||||
from ...session import ChatHashCache, MessageBox, PackedChat, Session
|
||||
from ...session import (
|
||||
ChatHashCache,
|
||||
MemorySession,
|
||||
MessageBox,
|
||||
PackedChat,
|
||||
Session,
|
||||
SqliteSession,
|
||||
Storage,
|
||||
)
|
||||
from ...tl import Request, abcs
|
||||
from ..events import Event
|
||||
from ..events.filters import Filter
|
||||
from ..types import (
|
||||
AsyncList,
|
||||
Chat,
|
||||
ChatLike,
|
||||
File,
|
||||
InFileLike,
|
||||
|
@ -41,15 +48,12 @@ from .auth import (
|
|||
check_password,
|
||||
is_authorized,
|
||||
request_login_code,
|
||||
session,
|
||||
sign_in,
|
||||
sign_out,
|
||||
)
|
||||
from .bots import InlineResult, inline_query
|
||||
from .chats import (
|
||||
get_participants,
|
||||
)
|
||||
from .dialogs import get_dialogs, delete_dialog
|
||||
from .chats import get_participants
|
||||
from .dialogs import delete_dialog, get_dialogs
|
||||
from .files import (
|
||||
download,
|
||||
iter_download,
|
||||
|
@ -88,38 +92,49 @@ from .updates import (
|
|||
remove_event_handler,
|
||||
set_handler_filter,
|
||||
)
|
||||
from .users import (
|
||||
get_me,
|
||||
input_to_peer,
|
||||
resolve_to_packed,
|
||||
)
|
||||
from .users import get_me, input_to_peer, resolve_to_packed
|
||||
|
||||
Return = TypeVar("Return")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, config: Config) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
session: Optional[Union[str, Path, Storage]],
|
||||
api_id: int,
|
||||
api_hash: Optional[str] = None,
|
||||
) -> None:
|
||||
self._sender: Optional[Sender] = None
|
||||
self._sender_lock = asyncio.Lock()
|
||||
self._dc_id = DEFAULT_DC
|
||||
self._config = config
|
||||
if isinstance(session, Storage):
|
||||
self._storage = session
|
||||
elif session is None:
|
||||
self._storage = MemorySession()
|
||||
else:
|
||||
self._storage = SqliteSession(session)
|
||||
self._config = Config(
|
||||
session=Session(),
|
||||
api_id=api_id,
|
||||
api_hash=api_hash or "",
|
||||
)
|
||||
self._message_box = MessageBox()
|
||||
self._chat_hashes = ChatHashCache(None)
|
||||
self._last_update_limit_warn: Optional[float] = None
|
||||
self._updates: asyncio.Queue[
|
||||
Tuple[abcs.Update, Dict[int, Union[abcs.User, abcs.Chat]]]
|
||||
] = asyncio.Queue(maxsize=config.update_queue_limit or 0)
|
||||
] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
|
||||
self._dispatcher: Optional[asyncio.Task[None]] = None
|
||||
self._downloader_map = object()
|
||||
self._handlers: Dict[
|
||||
Type[Event], List[Tuple[Callable[[Any], Awaitable[Any]], Optional[Filter]]]
|
||||
] = {}
|
||||
|
||||
if self_user := config.session.user:
|
||||
if self_user := self._config.session.user:
|
||||
self._dc_id = self_user.dc
|
||||
if config.catch_up and config.session.state:
|
||||
self._message_box.load(config.session.state)
|
||||
if self._config.catch_up and self._config.session.state:
|
||||
self._message_box.load(self._config.session.state)
|
||||
|
||||
# ---
|
||||
|
||||
|
@ -450,15 +465,6 @@ class Client:
|
|||
def connected(self) -> bool:
|
||||
return connected(self)
|
||||
|
||||
@property
|
||||
def session(self) -> Session:
|
||||
"""
|
||||
Up-to-date session state, useful for persisting it to storage.
|
||||
|
||||
Mutating the returned object may cause the library to misbehave.
|
||||
"""
|
||||
return session(self)
|
||||
|
||||
def _build_message_map(
|
||||
self,
|
||||
result: abcs.Updates,
|
||||
|
|
|
@ -122,6 +122,8 @@ async def connect(self: Client) -> None:
|
|||
if self._sender:
|
||||
return
|
||||
|
||||
if session := await self._storage.load():
|
||||
self._config.session = session
|
||||
self._sender = await connect_sender(self._dc_id, self._config)
|
||||
|
||||
if self._message_box.is_empty() and self._config.session.user:
|
||||
|
@ -148,6 +150,9 @@ async def disconnect(self: Client) -> None:
|
|||
await self._sender.disconnect()
|
||||
self._sender = None
|
||||
|
||||
self._config.session.state = self._message_box.session_state()
|
||||
await self._storage.save(self._config.session)
|
||||
|
||||
|
||||
async def invoke_request(
|
||||
client: Client,
|
||||
|
|
|
@ -3,18 +3,15 @@ from .message_box import (
|
|||
BOT_CHANNEL_DIFF_LIMIT,
|
||||
NO_UPDATES_TIMEOUT,
|
||||
USER_CHANNEL_DIFF_LIMIT,
|
||||
ChannelState,
|
||||
DataCenter,
|
||||
Gap,
|
||||
MessageBox,
|
||||
PossibleGap,
|
||||
PrematureEndReason,
|
||||
PtsInfo,
|
||||
Session,
|
||||
State,
|
||||
UpdateState,
|
||||
User,
|
||||
)
|
||||
from .session import ChannelState, DataCenter, Session, UpdateState, User
|
||||
from .storage import MemorySession, SqliteSession, Storage
|
||||
|
||||
__all__ = [
|
||||
"ChatHashCache",
|
||||
|
@ -23,15 +20,18 @@ __all__ = [
|
|||
"BOT_CHANNEL_DIFF_LIMIT",
|
||||
"NO_UPDATES_TIMEOUT",
|
||||
"USER_CHANNEL_DIFF_LIMIT",
|
||||
"ChannelState",
|
||||
"DataCenter",
|
||||
"Gap",
|
||||
"PossibleGap",
|
||||
"PrematureEndReason",
|
||||
"PtsInfo",
|
||||
"Session",
|
||||
"State",
|
||||
"ChannelState",
|
||||
"DataCenter",
|
||||
"Session",
|
||||
"UpdateState",
|
||||
"User",
|
||||
"MessageBox",
|
||||
"MemorySession",
|
||||
"SqliteSession",
|
||||
"Storage",
|
||||
]
|
||||
|
|
|
@ -2,16 +2,11 @@ from .defs import (
|
|||
BOT_CHANNEL_DIFF_LIMIT,
|
||||
NO_UPDATES_TIMEOUT,
|
||||
USER_CHANNEL_DIFF_LIMIT,
|
||||
ChannelState,
|
||||
DataCenter,
|
||||
Gap,
|
||||
PossibleGap,
|
||||
PrematureEndReason,
|
||||
PtsInfo,
|
||||
Session,
|
||||
State,
|
||||
UpdateState,
|
||||
User,
|
||||
)
|
||||
from .messagebox import MessageBox
|
||||
|
||||
|
@ -19,15 +14,10 @@ __all__ = [
|
|||
"BOT_CHANNEL_DIFF_LIMIT",
|
||||
"NO_UPDATES_TIMEOUT",
|
||||
"USER_CHANNEL_DIFF_LIMIT",
|
||||
"ChannelState",
|
||||
"DataCenter",
|
||||
"Gap",
|
||||
"PossibleGap",
|
||||
"PrematureEndReason",
|
||||
"PtsInfo",
|
||||
"Session",
|
||||
"State",
|
||||
"UpdateState",
|
||||
"User",
|
||||
"MessageBox",
|
||||
]
|
||||
|
|
|
@ -1,144 +1,10 @@
|
|||
import base64
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Self, Union
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from ...tl import abcs
|
||||
|
||||
|
||||
class DataCenter:
|
||||
__slots__ = ("id", "addr", "auth")
|
||||
|
||||
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
|
||||
self.id = id
|
||||
self.addr = addr
|
||||
self.auth = auth
|
||||
|
||||
|
||||
class User:
|
||||
__slots__ = ("id", "dc", "bot")
|
||||
|
||||
def __init__(self, *, id: int, dc: int, bot: bool) -> None:
|
||||
self.id = id
|
||||
self.dc = dc
|
||||
self.bot = bot
|
||||
|
||||
|
||||
class ChannelState:
|
||||
__slots__ = ("id", "pts")
|
||||
|
||||
def __init__(self, *, id: int, pts: int) -> None:
|
||||
self.id = id
|
||||
self.pts = pts
|
||||
|
||||
|
||||
class UpdateState:
|
||||
__slots__ = (
|
||||
"pts",
|
||||
"qts",
|
||||
"date",
|
||||
"seq",
|
||||
"channels",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pts: int,
|
||||
qts: int,
|
||||
date: int,
|
||||
seq: int,
|
||||
channels: List[ChannelState],
|
||||
) -> None:
|
||||
self.pts = pts
|
||||
self.qts = qts
|
||||
self.date = date
|
||||
self.seq = seq
|
||||
self.channels = channels
|
||||
|
||||
|
||||
class Session:
|
||||
__slots__ = ("dcs", "user", "state")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dcs: Optional[List[DataCenter]] = None,
|
||||
user: Optional[User] = None,
|
||||
state: Optional[UpdateState] = None,
|
||||
):
|
||||
self.dcs = dcs or []
|
||||
self.user = user
|
||||
self.state = state
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"dcs": [
|
||||
{
|
||||
"id": dc.id,
|
||||
"addr": dc.addr,
|
||||
"auth": base64.b64encode(dc.auth).decode("ascii")
|
||||
if dc.auth
|
||||
else None,
|
||||
}
|
||||
for dc in self.dcs
|
||||
],
|
||||
"user": {
|
||||
"id": self.user.id,
|
||||
"dc": self.user.dc,
|
||||
"bot": self.user.bot,
|
||||
}
|
||||
if self.user
|
||||
else None,
|
||||
"state": {
|
||||
"pts": self.state.pts,
|
||||
"qts": self.state.qts,
|
||||
"date": self.state.date,
|
||||
"seq": self.state.seq,
|
||||
"channels": [
|
||||
{"id": channel.id, "pts": channel.pts}
|
||||
for channel in self.state.channels
|
||||
],
|
||||
}
|
||||
if self.state
|
||||
else None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, dict: Dict[str, Any]) -> Self:
|
||||
return cls(
|
||||
dcs=[
|
||||
DataCenter(
|
||||
id=dc["id"],
|
||||
addr=dc["addr"],
|
||||
auth=base64.b64decode(dc["auth"])
|
||||
if dc["auth"] is not None
|
||||
else None,
|
||||
)
|
||||
for dc in dict["dcs"]
|
||||
],
|
||||
user=User(
|
||||
id=dict["user"]["id"],
|
||||
dc=dict["user"]["dc"],
|
||||
bot=dict["user"]["bot"],
|
||||
)
|
||||
if dict["user"]
|
||||
else None,
|
||||
state=UpdateState(
|
||||
pts=dict["state"]["pts"],
|
||||
qts=dict["state"]["qts"],
|
||||
date=dict["state"]["date"],
|
||||
seq=dict["state"]["seq"],
|
||||
channels=[
|
||||
ChannelState(id=channel["id"], pts=channel["pts"])
|
||||
for channel in dict["state"]["channels"]
|
||||
],
|
||||
)
|
||||
if dict["state"]
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
class PtsInfo:
|
||||
__slots__ = ("pts", "pts_count", "entry")
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
|
|||
|
||||
from ...tl import Request, abcs, functions, types
|
||||
from ..chat import ChatHashCache
|
||||
from ..session import ChannelState, UpdateState
|
||||
from .adaptor import adapt, pts_info_from_update
|
||||
from .defs import (
|
||||
BOT_CHANNEL_DIFF_LIMIT,
|
||||
|
@ -17,13 +18,11 @@ from .defs import (
|
|||
NO_UPDATES_TIMEOUT,
|
||||
POSSIBLE_GAP_TIMEOUT,
|
||||
USER_CHANNEL_DIFF_LIMIT,
|
||||
ChannelState,
|
||||
Entry,
|
||||
Gap,
|
||||
PossibleGap,
|
||||
PrematureEndReason,
|
||||
State,
|
||||
UpdateState,
|
||||
)
|
||||
|
||||
|
||||
|
|
159
client/src/telethon/_impl/session/session.py
Normal file
159
client/src/telethon/_impl/session/session.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
import base64
|
||||
from typing import Any, Dict, List, Optional, Self
|
||||
|
||||
|
||||
class DataCenter:
|
||||
__slots__ = ("id", "addr", "auth")
|
||||
|
||||
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
|
||||
self.id = id
|
||||
self.addr = addr
|
||||
self.auth = auth
|
||||
|
||||
|
||||
class User:
|
||||
__slots__ = ("id", "dc", "bot")
|
||||
|
||||
def __init__(self, *, id: int, dc: int, bot: bool) -> None:
|
||||
self.id = id
|
||||
self.dc = dc
|
||||
self.bot = bot
|
||||
|
||||
|
||||
class ChannelState:
|
||||
__slots__ = ("id", "pts")
|
||||
|
||||
def __init__(self, *, id: int, pts: int) -> None:
|
||||
self.id = id
|
||||
self.pts = pts
|
||||
|
||||
|
||||
class UpdateState:
|
||||
__slots__ = (
|
||||
"pts",
|
||||
"qts",
|
||||
"date",
|
||||
"seq",
|
||||
"channels",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pts: int,
|
||||
qts: int,
|
||||
date: int,
|
||||
seq: int,
|
||||
channels: List[ChannelState],
|
||||
) -> None:
|
||||
self.pts = pts
|
||||
self.qts = qts
|
||||
self.date = date
|
||||
self.seq = seq
|
||||
self.channels = channels
|
||||
|
||||
|
||||
class Session:
|
||||
"""
|
||||
A Telethon session.
|
||||
|
||||
A `Session` instance contains the required information to login into your
|
||||
Telegram account. **Never** give the saved session file to anyone else or
|
||||
make it public.
|
||||
|
||||
Leaking the session file will grant a bad actor complete access to your
|
||||
account, including private conversations, groups you're part of and list
|
||||
of contacts (though not secret chats).
|
||||
|
||||
If you think the session has been compromised, immediately terminate all
|
||||
sessions through an official Telegram client to revoke the authorization.
|
||||
"""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
__slots__ = ("dcs", "user", "state")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dcs: Optional[List[DataCenter]] = None,
|
||||
user: Optional[User] = None,
|
||||
state: Optional[UpdateState] = None,
|
||||
):
|
||||
self.dcs = dcs or []
|
||||
self.user = user
|
||||
self.state = state
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"v": self.VERSION,
|
||||
"dcs": [
|
||||
{
|
||||
"id": dc.id,
|
||||
"addr": dc.addr,
|
||||
"auth": base64.b64encode(dc.auth).decode("ascii")
|
||||
if dc.auth
|
||||
else None,
|
||||
}
|
||||
for dc in self.dcs
|
||||
],
|
||||
"user": {
|
||||
"id": self.user.id,
|
||||
"dc": self.user.dc,
|
||||
"bot": self.user.bot,
|
||||
}
|
||||
if self.user
|
||||
else None,
|
||||
"state": {
|
||||
"pts": self.state.pts,
|
||||
"qts": self.state.qts,
|
||||
"date": self.state.date,
|
||||
"seq": self.state.seq,
|
||||
"channels": [
|
||||
{"id": channel.id, "pts": channel.pts}
|
||||
for channel in self.state.channels
|
||||
],
|
||||
}
|
||||
if self.state
|
||||
else None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, dict: Dict[str, Any]) -> Self:
|
||||
version = dict["v"]
|
||||
if version != cls.VERSION:
|
||||
raise ValueError(
|
||||
f"cannot parse session format version {version} (expected {cls.VERSION})"
|
||||
)
|
||||
|
||||
return cls(
|
||||
dcs=[
|
||||
DataCenter(
|
||||
id=dc["id"],
|
||||
addr=dc["addr"],
|
||||
auth=base64.b64decode(dc["auth"])
|
||||
if dc["auth"] is not None
|
||||
else None,
|
||||
)
|
||||
for dc in dict["dcs"]
|
||||
],
|
||||
user=User(
|
||||
id=dict["user"]["id"],
|
||||
dc=dict["user"]["dc"],
|
||||
bot=dict["user"]["bot"],
|
||||
)
|
||||
if dict["user"]
|
||||
else None,
|
||||
state=UpdateState(
|
||||
pts=dict["state"]["pts"],
|
||||
qts=dict["state"]["qts"],
|
||||
date=dict["state"]["date"],
|
||||
seq=dict["state"]["seq"],
|
||||
channels=[
|
||||
ChannelState(id=channel["id"], pts=channel["pts"])
|
||||
for channel in dict["state"]["channels"]
|
||||
],
|
||||
)
|
||||
if dict["state"]
|
||||
else None,
|
||||
)
|
15
client/src/telethon/_impl/session/storage/__init__.py
Normal file
15
client/src/telethon/_impl/session/storage/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from typing import Any
|
||||
|
||||
from .memory import MemorySession
|
||||
from .storage import Storage
|
||||
|
||||
try:
|
||||
from .sqlite import SqliteSession
|
||||
except ImportError as e:
|
||||
|
||||
class SqliteSession(Storage): # type: ignore [no-redef]
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise e from None
|
||||
|
||||
|
||||
__all__ = ["MemorySession", "Storage", "SqliteSession"]
|
28
client/src/telethon/_impl/session/storage/memory.py
Normal file
28
client/src/telethon/_impl/session/storage/memory.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from typing import Optional
|
||||
|
||||
from ..session import Session
|
||||
from .storage import Storage
|
||||
|
||||
|
||||
class MemorySession(Storage):
|
||||
"""
|
||||
Session storage without persistence.
|
||||
|
||||
This is the simplest storage and is the one used by default.
|
||||
|
||||
Session data is only kept in memory and is not persisted to disk.
|
||||
"""
|
||||
|
||||
__slots__ = ("session",)
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
self.session = session
|
||||
|
||||
async def load(self) -> Optional[Session]:
|
||||
return self.session
|
||||
|
||||
async def save(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
async def delete(self) -> None:
|
||||
self.session = None
|
202
client/src/telethon/_impl/session/storage/sqlite.py
Normal file
202
client/src/telethon/_impl/session/storage/sqlite.py
Normal file
|
@ -0,0 +1,202 @@
|
|||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from ..session import ChannelState, DataCenter, Session, UpdateState, User
|
||||
from .storage import Storage
|
||||
|
||||
EXTENSION = ".session"
|
||||
CURRENT_VERSION = 10
|
||||
|
||||
|
||||
class SqliteSession(Storage):
|
||||
"""
|
||||
Session storage backed by SQLite.
|
||||
|
||||
SQLite is a reliable way to persist data to disk and offers file locking.
|
||||
|
||||
Paths without extension will have '.session' appended to them.
|
||||
This is by convention, and to make it harder to commit session files to
|
||||
an VCS by accident (adding `*.session` to `.gitignore` will catch them).
|
||||
"""
|
||||
|
||||
def __init__(self, file: Union[str, Path]):
|
||||
path = Path(file)
|
||||
if not path.suffix:
|
||||
path = path.with_suffix(EXTENSION)
|
||||
|
||||
self._path = path
|
||||
self._conn: Optional[sqlite3.Connection] = None
|
||||
|
||||
async def load(self) -> Optional[Session]:
|
||||
conn = self._current_conn()
|
||||
|
||||
c = conn.cursor()
|
||||
with conn:
|
||||
version = self._get_or_init_version(c)
|
||||
if version < CURRENT_VERSION:
|
||||
if version == 7:
|
||||
session = self._load_v7(c)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self._reset(c)
|
||||
self._get_or_init_version(c)
|
||||
self._save_v10(c, session)
|
||||
|
||||
return self._load_v10(c)
|
||||
|
||||
async def save(self, session: Session) -> None:
|
||||
conn = self._current_conn()
|
||||
with conn:
|
||||
self._save_v10(conn.cursor(), session)
|
||||
conn.close()
|
||||
self._conn = None
|
||||
|
||||
async def delete(self) -> None:
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
self._path.unlink()
|
||||
|
||||
def _current_conn(self) -> sqlite3.Connection:
|
||||
if self._conn is None:
|
||||
self._conn = sqlite3.connect(self._path)
|
||||
|
||||
return self._conn
|
||||
|
||||
@staticmethod
|
||||
def _load_v7(c: sqlite3.Cursor) -> Session:
|
||||
# Session v7 format from telethon v1
|
||||
c.execute("select dc_id, server_address, port, auth_key from sessions")
|
||||
sessions = c.fetchall()
|
||||
c.execute("select pts, qts, date, seq from update_state where id = 0")
|
||||
state = c.fetchone()
|
||||
c.execute("select id, pts from update_state where id != 0")
|
||||
channelstate = c.fetchall()
|
||||
|
||||
return Session(
|
||||
dcs=[
|
||||
DataCenter(id=id, addr=f"{ip}:{port}", auth=auth)
|
||||
for (id, ip, port, auth) in sessions
|
||||
],
|
||||
user=None,
|
||||
state=UpdateState(
|
||||
pts=state[0],
|
||||
qts=state[1],
|
||||
date=state[2],
|
||||
seq=state[3],
|
||||
channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate],
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _load_v10(c: sqlite3.Cursor) -> Session:
|
||||
c.execute("select * from datacenter")
|
||||
datacenter = c.fetchall()
|
||||
c.execute("select * from user")
|
||||
user = c.fetchone()
|
||||
c.execute("select * from state")
|
||||
state = c.fetchone()
|
||||
c.execute("select * from channelstate")
|
||||
channelstate = c.fetchall()
|
||||
|
||||
return Session(
|
||||
dcs=[
|
||||
DataCenter(id=id, addr=addr, auth=auth)
|
||||
for (id, addr, auth) in datacenter
|
||||
],
|
||||
user=User(id=user[0], dc=user[1], bot=bool(user[2])) if user else None,
|
||||
state=UpdateState(
|
||||
pts=state[0],
|
||||
qts=state[1],
|
||||
date=state[2],
|
||||
seq=state[3],
|
||||
channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate],
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _save_v10(c: sqlite3.Cursor, session: Session) -> None:
|
||||
c.execute("delete from datacenter")
|
||||
c.execute("delete from user")
|
||||
c.execute("delete from state")
|
||||
c.execute("delete from channelstate")
|
||||
c.executemany(
|
||||
"insert into datacenter values (?, ?, ?)",
|
||||
[(dc.id, dc.addr, dc.auth) for dc in session.dcs],
|
||||
)
|
||||
if user := session.user:
|
||||
c.execute(
|
||||
"insert into user values (?, ?, ?)", (user.id, user.dc, int(user.bot))
|
||||
)
|
||||
if state := session.state:
|
||||
c.execute(
|
||||
"insert into state values (?, ?, ?, ?)",
|
||||
(state.pts, state.qts, state.date, state.seq),
|
||||
)
|
||||
c.executemany(
|
||||
"insert into channelstate values (?, ?)",
|
||||
[(channel.id, channel.pts) for channel in state.channels],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reset(c: sqlite3.Cursor) -> None:
|
||||
safe_chars = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
c.execute("select name from sqlite_master where type='table'")
|
||||
for (name,) in c.fetchall():
|
||||
# Can't format arguments for table names. Regardless, it shouldn't
|
||||
# be an SQL-injection because names come from `sqlite_master`.
|
||||
# Just to be on the safe-side, check for r'\w+' nevertheless,
|
||||
# avoiding referencing globals which could've been monkey-patched.
|
||||
for char in name:
|
||||
if char not in safe_chars or name.__len__() > 20:
|
||||
raise ValueError(f"potentially unsafe table name: {name}")
|
||||
|
||||
c.execute(f"drop table {name}")
|
||||
|
||||
@staticmethod
|
||||
def _get_or_init_version(c: sqlite3.Cursor) -> int:
|
||||
c.execute(
|
||||
"select name from sqlite_master where type='table' and name='version'"
|
||||
)
|
||||
if c.fetchone():
|
||||
c.execute("select version from version")
|
||||
res = c.fetchone()[0]
|
||||
assert isinstance(res, int)
|
||||
return res
|
||||
else:
|
||||
SqliteSession._create_tables(c)
|
||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||
return CURRENT_VERSION
|
||||
|
||||
@staticmethod
|
||||
def _create_tables(c: sqlite3.Cursor) -> None:
|
||||
c.executescript(
|
||||
"""
|
||||
create table version (
|
||||
version integer primary key
|
||||
);
|
||||
create table datacenter(
|
||||
id integer primary key,
|
||||
addr text not null,
|
||||
auth blob
|
||||
);
|
||||
create table user(
|
||||
id integer primary key,
|
||||
dc integer not null,
|
||||
bot integer not null
|
||||
);
|
||||
create table state(
|
||||
pts integer not null,
|
||||
qts integer not null,
|
||||
date integer not null,
|
||||
seq integer not null
|
||||
);
|
||||
create table channelstate(
|
||||
id integer primary key,
|
||||
pts integer not null
|
||||
);
|
||||
"""
|
||||
)
|
30
client/src/telethon/_impl/session/storage/storage.py
Normal file
30
client/src/telethon/_impl/session/storage/storage.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import abc
|
||||
from typing import Optional
|
||||
|
||||
from ..session import Session
|
||||
|
||||
|
||||
class Storage(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def load(self) -> Optional[Session]:
|
||||
"""
|
||||
Load the `Session` instance, if any.
|
||||
|
||||
This method is called by the library prior to `connect`.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def save(self, session: Session) -> None:
|
||||
"""
|
||||
Save the `Session` instance to persistent storage.
|
||||
|
||||
This method is called by the library post `disconnect`.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self) -> None:
|
||||
"""
|
||||
Delete the saved `Session`.
|
||||
|
||||
This method is called by the library post `log_out`.
|
||||
"""
|
|
@ -13,13 +13,7 @@ async def test_ping_pong() -> None:
|
|||
api_hash = os.getenv("TG_HASH")
|
||||
assert api_id and api_id.isdigit()
|
||||
assert api_hash
|
||||
client = Client(
|
||||
Config(
|
||||
session=Session(),
|
||||
api_id=int(api_id),
|
||||
api_hash=api_hash,
|
||||
)
|
||||
)
|
||||
client = Client(None, int(api_id), api_hash)
|
||||
assert not client.connected
|
||||
await client.connect()
|
||||
assert client.connected
|
||||
|
|
|
@ -81,7 +81,6 @@ def test_unpack_two_at_once() -> None:
|
|||
with raises(ValueError) as e:
|
||||
transport.unpack(input, output)
|
||||
e.match("bad seq")
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_unpack_twice() -> None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user