mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-10 16:40:57 +03:00
Add session storages
This commit is contained in:
parent
16de3b274c
commit
aa83e7b043
|
@ -141,7 +141,4 @@ async def check_password(
|
||||||
async def sign_out(self: Client) -> None:
|
async def sign_out(self: Client) -> None:
|
||||||
await self(functions.auth.log_out())
|
await self(functions.auth.log_out())
|
||||||
|
|
||||||
|
await self._storage.delete()
|
||||||
def session(client: Client) -> Session:
|
|
||||||
client._config.session.state = client._message_box.session_state()
|
|
||||||
return client._config.session
|
|
||||||
|
|
|
@ -19,13 +19,20 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...mtsender import Sender
|
from ...mtsender import Sender
|
||||||
from ...session import ChatHashCache, MessageBox, PackedChat, Session
|
from ...session import (
|
||||||
|
ChatHashCache,
|
||||||
|
MemorySession,
|
||||||
|
MessageBox,
|
||||||
|
PackedChat,
|
||||||
|
Session,
|
||||||
|
SqliteSession,
|
||||||
|
Storage,
|
||||||
|
)
|
||||||
from ...tl import Request, abcs
|
from ...tl import Request, abcs
|
||||||
from ..events import Event
|
from ..events import Event
|
||||||
from ..events.filters import Filter
|
from ..events.filters import Filter
|
||||||
from ..types import (
|
from ..types import (
|
||||||
AsyncList,
|
AsyncList,
|
||||||
Chat,
|
|
||||||
ChatLike,
|
ChatLike,
|
||||||
File,
|
File,
|
||||||
InFileLike,
|
InFileLike,
|
||||||
|
@ -41,15 +48,12 @@ from .auth import (
|
||||||
check_password,
|
check_password,
|
||||||
is_authorized,
|
is_authorized,
|
||||||
request_login_code,
|
request_login_code,
|
||||||
session,
|
|
||||||
sign_in,
|
sign_in,
|
||||||
sign_out,
|
sign_out,
|
||||||
)
|
)
|
||||||
from .bots import InlineResult, inline_query
|
from .bots import InlineResult, inline_query
|
||||||
from .chats import (
|
from .chats import get_participants
|
||||||
get_participants,
|
from .dialogs import delete_dialog, get_dialogs
|
||||||
)
|
|
||||||
from .dialogs import get_dialogs, delete_dialog
|
|
||||||
from .files import (
|
from .files import (
|
||||||
download,
|
download,
|
||||||
iter_download,
|
iter_download,
|
||||||
|
@ -88,38 +92,49 @@ from .updates import (
|
||||||
remove_event_handler,
|
remove_event_handler,
|
||||||
set_handler_filter,
|
set_handler_filter,
|
||||||
)
|
)
|
||||||
from .users import (
|
from .users import get_me, input_to_peer, resolve_to_packed
|
||||||
get_me,
|
|
||||||
input_to_peer,
|
|
||||||
resolve_to_packed,
|
|
||||||
)
|
|
||||||
|
|
||||||
Return = TypeVar("Return")
|
Return = TypeVar("Return")
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(self, config: Config) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: Optional[Union[str, Path, Storage]],
|
||||||
|
api_id: int,
|
||||||
|
api_hash: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
self._sender: Optional[Sender] = None
|
self._sender: Optional[Sender] = None
|
||||||
self._sender_lock = asyncio.Lock()
|
self._sender_lock = asyncio.Lock()
|
||||||
self._dc_id = DEFAULT_DC
|
self._dc_id = DEFAULT_DC
|
||||||
self._config = config
|
if isinstance(session, Storage):
|
||||||
|
self._storage = session
|
||||||
|
elif session is None:
|
||||||
|
self._storage = MemorySession()
|
||||||
|
else:
|
||||||
|
self._storage = SqliteSession(session)
|
||||||
|
self._config = Config(
|
||||||
|
session=Session(),
|
||||||
|
api_id=api_id,
|
||||||
|
api_hash=api_hash or "",
|
||||||
|
)
|
||||||
self._message_box = MessageBox()
|
self._message_box = MessageBox()
|
||||||
self._chat_hashes = ChatHashCache(None)
|
self._chat_hashes = ChatHashCache(None)
|
||||||
self._last_update_limit_warn: Optional[float] = None
|
self._last_update_limit_warn: Optional[float] = None
|
||||||
self._updates: asyncio.Queue[
|
self._updates: asyncio.Queue[
|
||||||
Tuple[abcs.Update, Dict[int, Union[abcs.User, abcs.Chat]]]
|
Tuple[abcs.Update, Dict[int, Union[abcs.User, abcs.Chat]]]
|
||||||
] = asyncio.Queue(maxsize=config.update_queue_limit or 0)
|
] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
|
||||||
self._dispatcher: Optional[asyncio.Task[None]] = None
|
self._dispatcher: Optional[asyncio.Task[None]] = None
|
||||||
self._downloader_map = object()
|
self._downloader_map = object()
|
||||||
self._handlers: Dict[
|
self._handlers: Dict[
|
||||||
Type[Event], List[Tuple[Callable[[Any], Awaitable[Any]], Optional[Filter]]]
|
Type[Event], List[Tuple[Callable[[Any], Awaitable[Any]], Optional[Filter]]]
|
||||||
] = {}
|
] = {}
|
||||||
|
|
||||||
if self_user := config.session.user:
|
if self_user := self._config.session.user:
|
||||||
self._dc_id = self_user.dc
|
self._dc_id = self_user.dc
|
||||||
if config.catch_up and config.session.state:
|
if self._config.catch_up and self._config.session.state:
|
||||||
self._message_box.load(config.session.state)
|
self._message_box.load(self._config.session.state)
|
||||||
|
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
|
@ -450,15 +465,6 @@ class Client:
|
||||||
def connected(self) -> bool:
|
def connected(self) -> bool:
|
||||||
return connected(self)
|
return connected(self)
|
||||||
|
|
||||||
@property
|
|
||||||
def session(self) -> Session:
|
|
||||||
"""
|
|
||||||
Up-to-date session state, useful for persisting it to storage.
|
|
||||||
|
|
||||||
Mutating the returned object may cause the library to misbehave.
|
|
||||||
"""
|
|
||||||
return session(self)
|
|
||||||
|
|
||||||
def _build_message_map(
|
def _build_message_map(
|
||||||
self,
|
self,
|
||||||
result: abcs.Updates,
|
result: abcs.Updates,
|
||||||
|
|
|
@ -122,6 +122,8 @@ async def connect(self: Client) -> None:
|
||||||
if self._sender:
|
if self._sender:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if session := await self._storage.load():
|
||||||
|
self._config.session = session
|
||||||
self._sender = await connect_sender(self._dc_id, self._config)
|
self._sender = await connect_sender(self._dc_id, self._config)
|
||||||
|
|
||||||
if self._message_box.is_empty() and self._config.session.user:
|
if self._message_box.is_empty() and self._config.session.user:
|
||||||
|
@ -148,6 +150,9 @@ async def disconnect(self: Client) -> None:
|
||||||
await self._sender.disconnect()
|
await self._sender.disconnect()
|
||||||
self._sender = None
|
self._sender = None
|
||||||
|
|
||||||
|
self._config.session.state = self._message_box.session_state()
|
||||||
|
await self._storage.save(self._config.session)
|
||||||
|
|
||||||
|
|
||||||
async def invoke_request(
|
async def invoke_request(
|
||||||
client: Client,
|
client: Client,
|
||||||
|
|
|
@ -3,18 +3,15 @@ from .message_box import (
|
||||||
BOT_CHANNEL_DIFF_LIMIT,
|
BOT_CHANNEL_DIFF_LIMIT,
|
||||||
NO_UPDATES_TIMEOUT,
|
NO_UPDATES_TIMEOUT,
|
||||||
USER_CHANNEL_DIFF_LIMIT,
|
USER_CHANNEL_DIFF_LIMIT,
|
||||||
ChannelState,
|
|
||||||
DataCenter,
|
|
||||||
Gap,
|
Gap,
|
||||||
MessageBox,
|
MessageBox,
|
||||||
PossibleGap,
|
PossibleGap,
|
||||||
PrematureEndReason,
|
PrematureEndReason,
|
||||||
PtsInfo,
|
PtsInfo,
|
||||||
Session,
|
|
||||||
State,
|
State,
|
||||||
UpdateState,
|
|
||||||
User,
|
|
||||||
)
|
)
|
||||||
|
from .session import ChannelState, DataCenter, Session, UpdateState, User
|
||||||
|
from .storage import MemorySession, SqliteSession, Storage
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChatHashCache",
|
"ChatHashCache",
|
||||||
|
@ -23,15 +20,18 @@ __all__ = [
|
||||||
"BOT_CHANNEL_DIFF_LIMIT",
|
"BOT_CHANNEL_DIFF_LIMIT",
|
||||||
"NO_UPDATES_TIMEOUT",
|
"NO_UPDATES_TIMEOUT",
|
||||||
"USER_CHANNEL_DIFF_LIMIT",
|
"USER_CHANNEL_DIFF_LIMIT",
|
||||||
"ChannelState",
|
|
||||||
"DataCenter",
|
|
||||||
"Gap",
|
"Gap",
|
||||||
"PossibleGap",
|
"PossibleGap",
|
||||||
"PrematureEndReason",
|
"PrematureEndReason",
|
||||||
"PtsInfo",
|
"PtsInfo",
|
||||||
"Session",
|
|
||||||
"State",
|
"State",
|
||||||
|
"ChannelState",
|
||||||
|
"DataCenter",
|
||||||
|
"Session",
|
||||||
"UpdateState",
|
"UpdateState",
|
||||||
"User",
|
"User",
|
||||||
"MessageBox",
|
"MessageBox",
|
||||||
|
"MemorySession",
|
||||||
|
"SqliteSession",
|
||||||
|
"Storage",
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,16 +2,11 @@ from .defs import (
|
||||||
BOT_CHANNEL_DIFF_LIMIT,
|
BOT_CHANNEL_DIFF_LIMIT,
|
||||||
NO_UPDATES_TIMEOUT,
|
NO_UPDATES_TIMEOUT,
|
||||||
USER_CHANNEL_DIFF_LIMIT,
|
USER_CHANNEL_DIFF_LIMIT,
|
||||||
ChannelState,
|
|
||||||
DataCenter,
|
|
||||||
Gap,
|
Gap,
|
||||||
PossibleGap,
|
PossibleGap,
|
||||||
PrematureEndReason,
|
PrematureEndReason,
|
||||||
PtsInfo,
|
PtsInfo,
|
||||||
Session,
|
|
||||||
State,
|
State,
|
||||||
UpdateState,
|
|
||||||
User,
|
|
||||||
)
|
)
|
||||||
from .messagebox import MessageBox
|
from .messagebox import MessageBox
|
||||||
|
|
||||||
|
@ -19,15 +14,10 @@ __all__ = [
|
||||||
"BOT_CHANNEL_DIFF_LIMIT",
|
"BOT_CHANNEL_DIFF_LIMIT",
|
||||||
"NO_UPDATES_TIMEOUT",
|
"NO_UPDATES_TIMEOUT",
|
||||||
"USER_CHANNEL_DIFF_LIMIT",
|
"USER_CHANNEL_DIFF_LIMIT",
|
||||||
"ChannelState",
|
|
||||||
"DataCenter",
|
|
||||||
"Gap",
|
"Gap",
|
||||||
"PossibleGap",
|
"PossibleGap",
|
||||||
"PrematureEndReason",
|
"PrematureEndReason",
|
||||||
"PtsInfo",
|
"PtsInfo",
|
||||||
"Session",
|
|
||||||
"State",
|
"State",
|
||||||
"UpdateState",
|
|
||||||
"User",
|
|
||||||
"MessageBox",
|
"MessageBox",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,144 +1,10 @@
|
||||||
import base64
|
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Self, Union
|
from typing import List, Literal, Union
|
||||||
|
|
||||||
from ...tl import abcs
|
from ...tl import abcs
|
||||||
|
|
||||||
|
|
||||||
class DataCenter:
|
|
||||||
__slots__ = ("id", "addr", "auth")
|
|
||||||
|
|
||||||
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
|
|
||||||
self.id = id
|
|
||||||
self.addr = addr
|
|
||||||
self.auth = auth
|
|
||||||
|
|
||||||
|
|
||||||
class User:
|
|
||||||
__slots__ = ("id", "dc", "bot")
|
|
||||||
|
|
||||||
def __init__(self, *, id: int, dc: int, bot: bool) -> None:
|
|
||||||
self.id = id
|
|
||||||
self.dc = dc
|
|
||||||
self.bot = bot
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelState:
|
|
||||||
__slots__ = ("id", "pts")
|
|
||||||
|
|
||||||
def __init__(self, *, id: int, pts: int) -> None:
|
|
||||||
self.id = id
|
|
||||||
self.pts = pts
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateState:
|
|
||||||
__slots__ = (
|
|
||||||
"pts",
|
|
||||||
"qts",
|
|
||||||
"date",
|
|
||||||
"seq",
|
|
||||||
"channels",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
pts: int,
|
|
||||||
qts: int,
|
|
||||||
date: int,
|
|
||||||
seq: int,
|
|
||||||
channels: List[ChannelState],
|
|
||||||
) -> None:
|
|
||||||
self.pts = pts
|
|
||||||
self.qts = qts
|
|
||||||
self.date = date
|
|
||||||
self.seq = seq
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
|
||||||
__slots__ = ("dcs", "user", "state")
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
dcs: Optional[List[DataCenter]] = None,
|
|
||||||
user: Optional[User] = None,
|
|
||||||
state: Optional[UpdateState] = None,
|
|
||||||
):
|
|
||||||
self.dcs = dcs or []
|
|
||||||
self.user = user
|
|
||||||
self.state = state
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"dcs": [
|
|
||||||
{
|
|
||||||
"id": dc.id,
|
|
||||||
"addr": dc.addr,
|
|
||||||
"auth": base64.b64encode(dc.auth).decode("ascii")
|
|
||||||
if dc.auth
|
|
||||||
else None,
|
|
||||||
}
|
|
||||||
for dc in self.dcs
|
|
||||||
],
|
|
||||||
"user": {
|
|
||||||
"id": self.user.id,
|
|
||||||
"dc": self.user.dc,
|
|
||||||
"bot": self.user.bot,
|
|
||||||
}
|
|
||||||
if self.user
|
|
||||||
else None,
|
|
||||||
"state": {
|
|
||||||
"pts": self.state.pts,
|
|
||||||
"qts": self.state.qts,
|
|
||||||
"date": self.state.date,
|
|
||||||
"seq": self.state.seq,
|
|
||||||
"channels": [
|
|
||||||
{"id": channel.id, "pts": channel.pts}
|
|
||||||
for channel in self.state.channels
|
|
||||||
],
|
|
||||||
}
|
|
||||||
if self.state
|
|
||||||
else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, dict: Dict[str, Any]) -> Self:
|
|
||||||
return cls(
|
|
||||||
dcs=[
|
|
||||||
DataCenter(
|
|
||||||
id=dc["id"],
|
|
||||||
addr=dc["addr"],
|
|
||||||
auth=base64.b64decode(dc["auth"])
|
|
||||||
if dc["auth"] is not None
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
for dc in dict["dcs"]
|
|
||||||
],
|
|
||||||
user=User(
|
|
||||||
id=dict["user"]["id"],
|
|
||||||
dc=dict["user"]["dc"],
|
|
||||||
bot=dict["user"]["bot"],
|
|
||||||
)
|
|
||||||
if dict["user"]
|
|
||||||
else None,
|
|
||||||
state=UpdateState(
|
|
||||||
pts=dict["state"]["pts"],
|
|
||||||
qts=dict["state"]["qts"],
|
|
||||||
date=dict["state"]["date"],
|
|
||||||
seq=dict["state"]["seq"],
|
|
||||||
channels=[
|
|
||||||
ChannelState(id=channel["id"], pts=channel["pts"])
|
|
||||||
for channel in dict["state"]["channels"]
|
|
||||||
],
|
|
||||||
)
|
|
||||||
if dict["state"]
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PtsInfo:
|
class PtsInfo:
|
||||||
__slots__ = ("pts", "pts_count", "entry")
|
__slots__ = ("pts", "pts_count", "entry")
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from ...tl import Request, abcs, functions, types
|
from ...tl import Request, abcs, functions, types
|
||||||
from ..chat import ChatHashCache
|
from ..chat import ChatHashCache
|
||||||
|
from ..session import ChannelState, UpdateState
|
||||||
from .adaptor import adapt, pts_info_from_update
|
from .adaptor import adapt, pts_info_from_update
|
||||||
from .defs import (
|
from .defs import (
|
||||||
BOT_CHANNEL_DIFF_LIMIT,
|
BOT_CHANNEL_DIFF_LIMIT,
|
||||||
|
@ -17,13 +18,11 @@ from .defs import (
|
||||||
NO_UPDATES_TIMEOUT,
|
NO_UPDATES_TIMEOUT,
|
||||||
POSSIBLE_GAP_TIMEOUT,
|
POSSIBLE_GAP_TIMEOUT,
|
||||||
USER_CHANNEL_DIFF_LIMIT,
|
USER_CHANNEL_DIFF_LIMIT,
|
||||||
ChannelState,
|
|
||||||
Entry,
|
Entry,
|
||||||
Gap,
|
Gap,
|
||||||
PossibleGap,
|
PossibleGap,
|
||||||
PrematureEndReason,
|
PrematureEndReason,
|
||||||
State,
|
State,
|
||||||
UpdateState,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
159
client/src/telethon/_impl/session/session.py
Normal file
159
client/src/telethon/_impl/session/session.py
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
import base64
|
||||||
|
from typing import Any, Dict, List, Optional, Self
|
||||||
|
|
||||||
|
|
||||||
|
class DataCenter:
|
||||||
|
__slots__ = ("id", "addr", "auth")
|
||||||
|
|
||||||
|
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
|
||||||
|
self.id = id
|
||||||
|
self.addr = addr
|
||||||
|
self.auth = auth
|
||||||
|
|
||||||
|
|
||||||
|
class User:
|
||||||
|
__slots__ = ("id", "dc", "bot")
|
||||||
|
|
||||||
|
def __init__(self, *, id: int, dc: int, bot: bool) -> None:
|
||||||
|
self.id = id
|
||||||
|
self.dc = dc
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelState:
|
||||||
|
__slots__ = ("id", "pts")
|
||||||
|
|
||||||
|
def __init__(self, *, id: int, pts: int) -> None:
|
||||||
|
self.id = id
|
||||||
|
self.pts = pts
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateState:
|
||||||
|
__slots__ = (
|
||||||
|
"pts",
|
||||||
|
"qts",
|
||||||
|
"date",
|
||||||
|
"seq",
|
||||||
|
"channels",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pts: int,
|
||||||
|
qts: int,
|
||||||
|
date: int,
|
||||||
|
seq: int,
|
||||||
|
channels: List[ChannelState],
|
||||||
|
) -> None:
|
||||||
|
self.pts = pts
|
||||||
|
self.qts = qts
|
||||||
|
self.date = date
|
||||||
|
self.seq = seq
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
"""
|
||||||
|
A Telethon session.
|
||||||
|
|
||||||
|
A `Session` instance contains the required information to login into your
|
||||||
|
Telegram account. **Never** give the saved session file to anyone else or
|
||||||
|
make it public.
|
||||||
|
|
||||||
|
Leaking the session file will grant a bad actor complete access to your
|
||||||
|
account, including private conversations, groups you're part of and list
|
||||||
|
of contacts (though not secret chats).
|
||||||
|
|
||||||
|
If you think the session has been compromised, immediately terminate all
|
||||||
|
sessions through an official Telegram client to revoke the authorization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
__slots__ = ("dcs", "user", "state")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dcs: Optional[List[DataCenter]] = None,
|
||||||
|
user: Optional[User] = None,
|
||||||
|
state: Optional[UpdateState] = None,
|
||||||
|
):
|
||||||
|
self.dcs = dcs or []
|
||||||
|
self.user = user
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"v": self.VERSION,
|
||||||
|
"dcs": [
|
||||||
|
{
|
||||||
|
"id": dc.id,
|
||||||
|
"addr": dc.addr,
|
||||||
|
"auth": base64.b64encode(dc.auth).decode("ascii")
|
||||||
|
if dc.auth
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
for dc in self.dcs
|
||||||
|
],
|
||||||
|
"user": {
|
||||||
|
"id": self.user.id,
|
||||||
|
"dc": self.user.dc,
|
||||||
|
"bot": self.user.bot,
|
||||||
|
}
|
||||||
|
if self.user
|
||||||
|
else None,
|
||||||
|
"state": {
|
||||||
|
"pts": self.state.pts,
|
||||||
|
"qts": self.state.qts,
|
||||||
|
"date": self.state.date,
|
||||||
|
"seq": self.state.seq,
|
||||||
|
"channels": [
|
||||||
|
{"id": channel.id, "pts": channel.pts}
|
||||||
|
for channel in self.state.channels
|
||||||
|
],
|
||||||
|
}
|
||||||
|
if self.state
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, dict: Dict[str, Any]) -> Self:
|
||||||
|
version = dict["v"]
|
||||||
|
if version != cls.VERSION:
|
||||||
|
raise ValueError(
|
||||||
|
f"cannot parse session format version {version} (expected {cls.VERSION})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
dcs=[
|
||||||
|
DataCenter(
|
||||||
|
id=dc["id"],
|
||||||
|
addr=dc["addr"],
|
||||||
|
auth=base64.b64decode(dc["auth"])
|
||||||
|
if dc["auth"] is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
for dc in dict["dcs"]
|
||||||
|
],
|
||||||
|
user=User(
|
||||||
|
id=dict["user"]["id"],
|
||||||
|
dc=dict["user"]["dc"],
|
||||||
|
bot=dict["user"]["bot"],
|
||||||
|
)
|
||||||
|
if dict["user"]
|
||||||
|
else None,
|
||||||
|
state=UpdateState(
|
||||||
|
pts=dict["state"]["pts"],
|
||||||
|
qts=dict["state"]["qts"],
|
||||||
|
date=dict["state"]["date"],
|
||||||
|
seq=dict["state"]["seq"],
|
||||||
|
channels=[
|
||||||
|
ChannelState(id=channel["id"], pts=channel["pts"])
|
||||||
|
for channel in dict["state"]["channels"]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if dict["state"]
|
||||||
|
else None,
|
||||||
|
)
|
15
client/src/telethon/_impl/session/storage/__init__.py
Normal file
15
client/src/telethon/_impl/session/storage/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .memory import MemorySession
|
||||||
|
from .storage import Storage
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .sqlite import SqliteSession
|
||||||
|
except ImportError as e:
|
||||||
|
|
||||||
|
class SqliteSession(Storage): # type: ignore [no-redef]
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["MemorySession", "Storage", "SqliteSession"]
|
28
client/src/telethon/_impl/session/storage/memory.py
Normal file
28
client/src/telethon/_impl/session/storage/memory.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..session import Session
|
||||||
|
from .storage import Storage
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySession(Storage):
|
||||||
|
"""
|
||||||
|
Session storage without persistence.
|
||||||
|
|
||||||
|
This is the simplest storage and is the one used by default.
|
||||||
|
|
||||||
|
Session data is only kept in memory and is not persisted to disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("session",)
|
||||||
|
|
||||||
|
def __init__(self, session: Optional[Session] = None):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def load(self) -> Optional[Session]:
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
async def save(self, session: Session) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def delete(self) -> None:
|
||||||
|
self.session = None
|
202
client/src/telethon/_impl/session/storage/sqlite.py
Normal file
202
client/src/telethon/_impl/session/storage/sqlite.py
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from ..session import ChannelState, DataCenter, Session, UpdateState, User
|
||||||
|
from .storage import Storage
|
||||||
|
|
||||||
|
EXTENSION = ".session"
|
||||||
|
CURRENT_VERSION = 10
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteSession(Storage):
|
||||||
|
"""
|
||||||
|
Session storage backed by SQLite.
|
||||||
|
|
||||||
|
SQLite is a reliable way to persist data to disk and offers file locking.
|
||||||
|
|
||||||
|
Paths without extension will have '.session' appended to them.
|
||||||
|
This is by convention, and to make it harder to commit session files to
|
||||||
|
an VCS by accident (adding `*.session` to `.gitignore` will catch them).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file: Union[str, Path]):
|
||||||
|
path = Path(file)
|
||||||
|
if not path.suffix:
|
||||||
|
path = path.with_suffix(EXTENSION)
|
||||||
|
|
||||||
|
self._path = path
|
||||||
|
self._conn: Optional[sqlite3.Connection] = None
|
||||||
|
|
||||||
|
async def load(self) -> Optional[Session]:
|
||||||
|
conn = self._current_conn()
|
||||||
|
|
||||||
|
c = conn.cursor()
|
||||||
|
with conn:
|
||||||
|
version = self._get_or_init_version(c)
|
||||||
|
if version < CURRENT_VERSION:
|
||||||
|
if version == 7:
|
||||||
|
session = self._load_v7(c)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
self._reset(c)
|
||||||
|
self._get_or_init_version(c)
|
||||||
|
self._save_v10(c, session)
|
||||||
|
|
||||||
|
return self._load_v10(c)
|
||||||
|
|
||||||
|
async def save(self, session: Session) -> None:
|
||||||
|
conn = self._current_conn()
|
||||||
|
with conn:
|
||||||
|
self._save_v10(conn.cursor(), session)
|
||||||
|
conn.close()
|
||||||
|
self._conn = None
|
||||||
|
|
||||||
|
async def delete(self) -> None:
|
||||||
|
if self._conn:
|
||||||
|
self._conn.close()
|
||||||
|
self._conn = None
|
||||||
|
self._path.unlink()
|
||||||
|
|
||||||
|
def _current_conn(self) -> sqlite3.Connection:
|
||||||
|
if self._conn is None:
|
||||||
|
self._conn = sqlite3.connect(self._path)
|
||||||
|
|
||||||
|
return self._conn
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_v7(c: sqlite3.Cursor) -> Session:
|
||||||
|
# Session v7 format from telethon v1
|
||||||
|
c.execute("select dc_id, server_address, port, auth_key from sessions")
|
||||||
|
sessions = c.fetchall()
|
||||||
|
c.execute("select pts, qts, date, seq from update_state where id = 0")
|
||||||
|
state = c.fetchone()
|
||||||
|
c.execute("select id, pts from update_state where id != 0")
|
||||||
|
channelstate = c.fetchall()
|
||||||
|
|
||||||
|
return Session(
|
||||||
|
dcs=[
|
||||||
|
DataCenter(id=id, addr=f"{ip}:{port}", auth=auth)
|
||||||
|
for (id, ip, port, auth) in sessions
|
||||||
|
],
|
||||||
|
user=None,
|
||||||
|
state=UpdateState(
|
||||||
|
pts=state[0],
|
||||||
|
qts=state[1],
|
||||||
|
date=state[2],
|
||||||
|
seq=state[3],
|
||||||
|
channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_v10(c: sqlite3.Cursor) -> Session:
|
||||||
|
c.execute("select * from datacenter")
|
||||||
|
datacenter = c.fetchall()
|
||||||
|
c.execute("select * from user")
|
||||||
|
user = c.fetchone()
|
||||||
|
c.execute("select * from state")
|
||||||
|
state = c.fetchone()
|
||||||
|
c.execute("select * from channelstate")
|
||||||
|
channelstate = c.fetchall()
|
||||||
|
|
||||||
|
return Session(
|
||||||
|
dcs=[
|
||||||
|
DataCenter(id=id, addr=addr, auth=auth)
|
||||||
|
for (id, addr, auth) in datacenter
|
||||||
|
],
|
||||||
|
user=User(id=user[0], dc=user[1], bot=bool(user[2])) if user else None,
|
||||||
|
state=UpdateState(
|
||||||
|
pts=state[0],
|
||||||
|
qts=state[1],
|
||||||
|
date=state[2],
|
||||||
|
seq=state[3],
|
||||||
|
channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _save_v10(c: sqlite3.Cursor, session: Session) -> None:
|
||||||
|
c.execute("delete from datacenter")
|
||||||
|
c.execute("delete from user")
|
||||||
|
c.execute("delete from state")
|
||||||
|
c.execute("delete from channelstate")
|
||||||
|
c.executemany(
|
||||||
|
"insert into datacenter values (?, ?, ?)",
|
||||||
|
[(dc.id, dc.addr, dc.auth) for dc in session.dcs],
|
||||||
|
)
|
||||||
|
if user := session.user:
|
||||||
|
c.execute(
|
||||||
|
"insert into user values (?, ?, ?)", (user.id, user.dc, int(user.bot))
|
||||||
|
)
|
||||||
|
if state := session.state:
|
||||||
|
c.execute(
|
||||||
|
"insert into state values (?, ?, ?, ?)",
|
||||||
|
(state.pts, state.qts, state.date, state.seq),
|
||||||
|
)
|
||||||
|
c.executemany(
|
||||||
|
"insert into channelstate values (?, ?)",
|
||||||
|
[(channel.id, channel.pts) for channel in state.channels],
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reset(c: sqlite3.Cursor) -> None:
|
||||||
|
safe_chars = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
|
||||||
|
c.execute("select name from sqlite_master where type='table'")
|
||||||
|
for (name,) in c.fetchall():
|
||||||
|
# Can't format arguments for table names. Regardless, it shouldn't
|
||||||
|
# be an SQL-injection because names come from `sqlite_master`.
|
||||||
|
# Just to be on the safe-side, check for r'\w+' nevertheless,
|
||||||
|
# avoiding referencing globals which could've been monkey-patched.
|
||||||
|
for char in name:
|
||||||
|
if char not in safe_chars or name.__len__() > 20:
|
||||||
|
raise ValueError(f"potentially unsafe table name: {name}")
|
||||||
|
|
||||||
|
c.execute(f"drop table {name}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_or_init_version(c: sqlite3.Cursor) -> int:
|
||||||
|
c.execute(
|
||||||
|
"select name from sqlite_master where type='table' and name='version'"
|
||||||
|
)
|
||||||
|
if c.fetchone():
|
||||||
|
c.execute("select version from version")
|
||||||
|
res = c.fetchone()[0]
|
||||||
|
assert isinstance(res, int)
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
SqliteSession._create_tables(c)
|
||||||
|
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||||
|
return CURRENT_VERSION
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_tables(c: sqlite3.Cursor) -> None:
|
||||||
|
c.executescript(
|
||||||
|
"""
|
||||||
|
create table version (
|
||||||
|
version integer primary key
|
||||||
|
);
|
||||||
|
create table datacenter(
|
||||||
|
id integer primary key,
|
||||||
|
addr text not null,
|
||||||
|
auth blob
|
||||||
|
);
|
||||||
|
create table user(
|
||||||
|
id integer primary key,
|
||||||
|
dc integer not null,
|
||||||
|
bot integer not null
|
||||||
|
);
|
||||||
|
create table state(
|
||||||
|
pts integer not null,
|
||||||
|
qts integer not null,
|
||||||
|
date integer not null,
|
||||||
|
seq integer not null
|
||||||
|
);
|
||||||
|
create table channelstate(
|
||||||
|
id integer primary key,
|
||||||
|
pts integer not null
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
30
client/src/telethon/_impl/session/storage/storage.py
Normal file
30
client/src/telethon/_impl/session/storage/storage.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
import abc
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..session import Session
|
||||||
|
|
||||||
|
|
||||||
|
class Storage(abc.ABC):
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def load(self) -> Optional[Session]:
|
||||||
|
"""
|
||||||
|
Load the `Session` instance, if any.
|
||||||
|
|
||||||
|
This method is called by the library prior to `connect`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def save(self, session: Session) -> None:
|
||||||
|
"""
|
||||||
|
Save the `Session` instance to persistent storage.
|
||||||
|
|
||||||
|
This method is called by the library post `disconnect`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def delete(self) -> None:
|
||||||
|
"""
|
||||||
|
Delete the saved `Session`.
|
||||||
|
|
||||||
|
This method is called by the library post `log_out`.
|
||||||
|
"""
|
|
@ -13,13 +13,7 @@ async def test_ping_pong() -> None:
|
||||||
api_hash = os.getenv("TG_HASH")
|
api_hash = os.getenv("TG_HASH")
|
||||||
assert api_id and api_id.isdigit()
|
assert api_id and api_id.isdigit()
|
||||||
assert api_hash
|
assert api_hash
|
||||||
client = Client(
|
client = Client(None, int(api_id), api_hash)
|
||||||
Config(
|
|
||||||
session=Session(),
|
|
||||||
api_id=int(api_id),
|
|
||||||
api_hash=api_hash,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert not client.connected
|
assert not client.connected
|
||||||
await client.connect()
|
await client.connect()
|
||||||
assert client.connected
|
assert client.connected
|
||||||
|
|
|
@ -81,7 +81,6 @@ def test_unpack_two_at_once() -> None:
|
||||||
with raises(ValueError) as e:
|
with raises(ValueError) as e:
|
||||||
transport.unpack(input, output)
|
transport.unpack(input, output)
|
||||||
e.match("bad seq")
|
e.match("bad seq")
|
||||||
assert output == expected_output
|
|
||||||
|
|
||||||
|
|
||||||
def test_unpack_twice() -> None:
|
def test_unpack_twice() -> None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user