mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-11-04 01:47:27 +03:00 
			
		
		
		
	Add session storages
This commit is contained in:
		
							parent
							
								
									16de3b274c
								
							
						
					
					
						commit
						aa83e7b043
					
				| 
						 | 
				
			
			@ -141,7 +141,4 @@ async def check_password(
 | 
			
		|||
async def sign_out(self: Client) -> None:
 | 
			
		||||
    await self(functions.auth.log_out())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def session(client: Client) -> Session:
 | 
			
		||||
    client._config.session.state = client._message_box.session_state()
 | 
			
		||||
    return client._config.session
 | 
			
		||||
    await self._storage.delete()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,13 +19,20 @@ from typing import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
from ...mtsender import Sender
 | 
			
		||||
from ...session import ChatHashCache, MessageBox, PackedChat, Session
 | 
			
		||||
from ...session import (
 | 
			
		||||
    ChatHashCache,
 | 
			
		||||
    MemorySession,
 | 
			
		||||
    MessageBox,
 | 
			
		||||
    PackedChat,
 | 
			
		||||
    Session,
 | 
			
		||||
    SqliteSession,
 | 
			
		||||
    Storage,
 | 
			
		||||
)
 | 
			
		||||
from ...tl import Request, abcs
 | 
			
		||||
from ..events import Event
 | 
			
		||||
from ..events.filters import Filter
 | 
			
		||||
from ..types import (
 | 
			
		||||
    AsyncList,
 | 
			
		||||
    Chat,
 | 
			
		||||
    ChatLike,
 | 
			
		||||
    File,
 | 
			
		||||
    InFileLike,
 | 
			
		||||
| 
						 | 
				
			
			@ -41,15 +48,12 @@ from .auth import (
 | 
			
		|||
    check_password,
 | 
			
		||||
    is_authorized,
 | 
			
		||||
    request_login_code,
 | 
			
		||||
    session,
 | 
			
		||||
    sign_in,
 | 
			
		||||
    sign_out,
 | 
			
		||||
)
 | 
			
		||||
from .bots import InlineResult, inline_query
 | 
			
		||||
from .chats import (
 | 
			
		||||
    get_participants,
 | 
			
		||||
)
 | 
			
		||||
from .dialogs import get_dialogs, delete_dialog
 | 
			
		||||
from .chats import get_participants
 | 
			
		||||
from .dialogs import delete_dialog, get_dialogs
 | 
			
		||||
from .files import (
 | 
			
		||||
    download,
 | 
			
		||||
    iter_download,
 | 
			
		||||
| 
						 | 
				
			
			@ -88,38 +92,49 @@ from .updates import (
 | 
			
		|||
    remove_event_handler,
 | 
			
		||||
    set_handler_filter,
 | 
			
		||||
)
 | 
			
		||||
from .users import (
 | 
			
		||||
    get_me,
 | 
			
		||||
    input_to_peer,
 | 
			
		||||
    resolve_to_packed,
 | 
			
		||||
)
 | 
			
		||||
from .users import get_me, input_to_peer, resolve_to_packed
 | 
			
		||||
 | 
			
		||||
Return = TypeVar("Return")
 | 
			
		||||
T = TypeVar("T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Client:
 | 
			
		||||
    def __init__(self, config: Config) -> None:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        session: Optional[Union[str, Path, Storage]],
 | 
			
		||||
        api_id: int,
 | 
			
		||||
        api_hash: Optional[str] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self._sender: Optional[Sender] = None
 | 
			
		||||
        self._sender_lock = asyncio.Lock()
 | 
			
		||||
        self._dc_id = DEFAULT_DC
 | 
			
		||||
        self._config = config
 | 
			
		||||
        if isinstance(session, Storage):
 | 
			
		||||
            self._storage = session
 | 
			
		||||
        elif session is None:
 | 
			
		||||
            self._storage = MemorySession()
 | 
			
		||||
        else:
 | 
			
		||||
            self._storage = SqliteSession(session)
 | 
			
		||||
        self._config = Config(
 | 
			
		||||
            session=Session(),
 | 
			
		||||
            api_id=api_id,
 | 
			
		||||
            api_hash=api_hash or "",
 | 
			
		||||
        )
 | 
			
		||||
        self._message_box = MessageBox()
 | 
			
		||||
        self._chat_hashes = ChatHashCache(None)
 | 
			
		||||
        self._last_update_limit_warn: Optional[float] = None
 | 
			
		||||
        self._updates: asyncio.Queue[
 | 
			
		||||
            Tuple[abcs.Update, Dict[int, Union[abcs.User, abcs.Chat]]]
 | 
			
		||||
        ] = asyncio.Queue(maxsize=config.update_queue_limit or 0)
 | 
			
		||||
        ] = asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
 | 
			
		||||
        self._dispatcher: Optional[asyncio.Task[None]] = None
 | 
			
		||||
        self._downloader_map = object()
 | 
			
		||||
        self._handlers: Dict[
 | 
			
		||||
            Type[Event], List[Tuple[Callable[[Any], Awaitable[Any]], Optional[Filter]]]
 | 
			
		||||
        ] = {}
 | 
			
		||||
 | 
			
		||||
        if self_user := config.session.user:
 | 
			
		||||
        if self_user := self._config.session.user:
 | 
			
		||||
            self._dc_id = self_user.dc
 | 
			
		||||
            if config.catch_up and config.session.state:
 | 
			
		||||
                self._message_box.load(config.session.state)
 | 
			
		||||
            if self._config.catch_up and self._config.session.state:
 | 
			
		||||
                self._message_box.load(self._config.session.state)
 | 
			
		||||
 | 
			
		||||
    # ---
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -450,15 +465,6 @@ class Client:
 | 
			
		|||
    def connected(self) -> bool:
 | 
			
		||||
        return connected(self)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def session(self) -> Session:
 | 
			
		||||
        """
 | 
			
		||||
        Up-to-date session state, useful for persisting it to storage.
 | 
			
		||||
 | 
			
		||||
        Mutating the returned object may cause the library to misbehave.
 | 
			
		||||
        """
 | 
			
		||||
        return session(self)
 | 
			
		||||
 | 
			
		||||
    def _build_message_map(
 | 
			
		||||
        self,
 | 
			
		||||
        result: abcs.Updates,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -122,6 +122,8 @@ async def connect(self: Client) -> None:
 | 
			
		|||
    if self._sender:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    if session := await self._storage.load():
 | 
			
		||||
        self._config.session = session
 | 
			
		||||
    self._sender = await connect_sender(self._dc_id, self._config)
 | 
			
		||||
 | 
			
		||||
    if self._message_box.is_empty() and self._config.session.user:
 | 
			
		||||
| 
						 | 
				
			
			@ -148,6 +150,9 @@ async def disconnect(self: Client) -> None:
 | 
			
		|||
    await self._sender.disconnect()
 | 
			
		||||
    self._sender = None
 | 
			
		||||
 | 
			
		||||
    self._config.session.state = self._message_box.session_state()
 | 
			
		||||
    await self._storage.save(self._config.session)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def invoke_request(
 | 
			
		||||
    client: Client,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,18 +3,15 @@ from .message_box import (
 | 
			
		|||
    BOT_CHANNEL_DIFF_LIMIT,
 | 
			
		||||
    NO_UPDATES_TIMEOUT,
 | 
			
		||||
    USER_CHANNEL_DIFF_LIMIT,
 | 
			
		||||
    ChannelState,
 | 
			
		||||
    DataCenter,
 | 
			
		||||
    Gap,
 | 
			
		||||
    MessageBox,
 | 
			
		||||
    PossibleGap,
 | 
			
		||||
    PrematureEndReason,
 | 
			
		||||
    PtsInfo,
 | 
			
		||||
    Session,
 | 
			
		||||
    State,
 | 
			
		||||
    UpdateState,
 | 
			
		||||
    User,
 | 
			
		||||
)
 | 
			
		||||
from .session import ChannelState, DataCenter, Session, UpdateState, User
 | 
			
		||||
from .storage import MemorySession, SqliteSession, Storage
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    "ChatHashCache",
 | 
			
		||||
| 
						 | 
				
			
			@ -23,15 +20,18 @@ __all__ = [
 | 
			
		|||
    "BOT_CHANNEL_DIFF_LIMIT",
 | 
			
		||||
    "NO_UPDATES_TIMEOUT",
 | 
			
		||||
    "USER_CHANNEL_DIFF_LIMIT",
 | 
			
		||||
    "ChannelState",
 | 
			
		||||
    "DataCenter",
 | 
			
		||||
    "Gap",
 | 
			
		||||
    "PossibleGap",
 | 
			
		||||
    "PrematureEndReason",
 | 
			
		||||
    "PtsInfo",
 | 
			
		||||
    "Session",
 | 
			
		||||
    "State",
 | 
			
		||||
    "ChannelState",
 | 
			
		||||
    "DataCenter",
 | 
			
		||||
    "Session",
 | 
			
		||||
    "UpdateState",
 | 
			
		||||
    "User",
 | 
			
		||||
    "MessageBox",
 | 
			
		||||
    "MemorySession",
 | 
			
		||||
    "SqliteSession",
 | 
			
		||||
    "Storage",
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,16 +2,11 @@ from .defs import (
 | 
			
		|||
    BOT_CHANNEL_DIFF_LIMIT,
 | 
			
		||||
    NO_UPDATES_TIMEOUT,
 | 
			
		||||
    USER_CHANNEL_DIFF_LIMIT,
 | 
			
		||||
    ChannelState,
 | 
			
		||||
    DataCenter,
 | 
			
		||||
    Gap,
 | 
			
		||||
    PossibleGap,
 | 
			
		||||
    PrematureEndReason,
 | 
			
		||||
    PtsInfo,
 | 
			
		||||
    Session,
 | 
			
		||||
    State,
 | 
			
		||||
    UpdateState,
 | 
			
		||||
    User,
 | 
			
		||||
)
 | 
			
		||||
from .messagebox import MessageBox
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -19,15 +14,10 @@ __all__ = [
 | 
			
		|||
    "BOT_CHANNEL_DIFF_LIMIT",
 | 
			
		||||
    "NO_UPDATES_TIMEOUT",
 | 
			
		||||
    "USER_CHANNEL_DIFF_LIMIT",
 | 
			
		||||
    "ChannelState",
 | 
			
		||||
    "DataCenter",
 | 
			
		||||
    "Gap",
 | 
			
		||||
    "PossibleGap",
 | 
			
		||||
    "PrematureEndReason",
 | 
			
		||||
    "PtsInfo",
 | 
			
		||||
    "Session",
 | 
			
		||||
    "State",
 | 
			
		||||
    "UpdateState",
 | 
			
		||||
    "User",
 | 
			
		||||
    "MessageBox",
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,144 +1,10 @@
 | 
			
		|||
import base64
 | 
			
		||||
import logging
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Any, Dict, List, Literal, Optional, Self, Union
 | 
			
		||||
from typing import List, Literal, Union
 | 
			
		||||
 | 
			
		||||
from ...tl import abcs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DataCenter:
 | 
			
		||||
    __slots__ = ("id", "addr", "auth")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.addr = addr
 | 
			
		||||
        self.auth = auth
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class User:
 | 
			
		||||
    __slots__ = ("id", "dc", "bot")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *, id: int, dc: int, bot: bool) -> None:
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.dc = dc
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChannelState:
 | 
			
		||||
    __slots__ = ("id", "pts")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *, id: int, pts: int) -> None:
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.pts = pts
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UpdateState:
 | 
			
		||||
    __slots__ = (
 | 
			
		||||
        "pts",
 | 
			
		||||
        "qts",
 | 
			
		||||
        "date",
 | 
			
		||||
        "seq",
 | 
			
		||||
        "channels",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        pts: int,
 | 
			
		||||
        qts: int,
 | 
			
		||||
        date: int,
 | 
			
		||||
        seq: int,
 | 
			
		||||
        channels: List[ChannelState],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.pts = pts
 | 
			
		||||
        self.qts = qts
 | 
			
		||||
        self.date = date
 | 
			
		||||
        self.seq = seq
 | 
			
		||||
        self.channels = channels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Session:
 | 
			
		||||
    __slots__ = ("dcs", "user", "state")
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        dcs: Optional[List[DataCenter]] = None,
 | 
			
		||||
        user: Optional[User] = None,
 | 
			
		||||
        state: Optional[UpdateState] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self.dcs = dcs or []
 | 
			
		||||
        self.user = user
 | 
			
		||||
        self.state = state
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        return {
 | 
			
		||||
            "dcs": [
 | 
			
		||||
                {
 | 
			
		||||
                    "id": dc.id,
 | 
			
		||||
                    "addr": dc.addr,
 | 
			
		||||
                    "auth": base64.b64encode(dc.auth).decode("ascii")
 | 
			
		||||
                    if dc.auth
 | 
			
		||||
                    else None,
 | 
			
		||||
                }
 | 
			
		||||
                for dc in self.dcs
 | 
			
		||||
            ],
 | 
			
		||||
            "user": {
 | 
			
		||||
                "id": self.user.id,
 | 
			
		||||
                "dc": self.user.dc,
 | 
			
		||||
                "bot": self.user.bot,
 | 
			
		||||
            }
 | 
			
		||||
            if self.user
 | 
			
		||||
            else None,
 | 
			
		||||
            "state": {
 | 
			
		||||
                "pts": self.state.pts,
 | 
			
		||||
                "qts": self.state.qts,
 | 
			
		||||
                "date": self.state.date,
 | 
			
		||||
                "seq": self.state.seq,
 | 
			
		||||
                "channels": [
 | 
			
		||||
                    {"id": channel.id, "pts": channel.pts}
 | 
			
		||||
                    for channel in self.state.channels
 | 
			
		||||
                ],
 | 
			
		||||
            }
 | 
			
		||||
            if self.state
 | 
			
		||||
            else None,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, dict: Dict[str, Any]) -> Self:
 | 
			
		||||
        return cls(
 | 
			
		||||
            dcs=[
 | 
			
		||||
                DataCenter(
 | 
			
		||||
                    id=dc["id"],
 | 
			
		||||
                    addr=dc["addr"],
 | 
			
		||||
                    auth=base64.b64decode(dc["auth"])
 | 
			
		||||
                    if dc["auth"] is not None
 | 
			
		||||
                    else None,
 | 
			
		||||
                )
 | 
			
		||||
                for dc in dict["dcs"]
 | 
			
		||||
            ],
 | 
			
		||||
            user=User(
 | 
			
		||||
                id=dict["user"]["id"],
 | 
			
		||||
                dc=dict["user"]["dc"],
 | 
			
		||||
                bot=dict["user"]["bot"],
 | 
			
		||||
            )
 | 
			
		||||
            if dict["user"]
 | 
			
		||||
            else None,
 | 
			
		||||
            state=UpdateState(
 | 
			
		||||
                pts=dict["state"]["pts"],
 | 
			
		||||
                qts=dict["state"]["qts"],
 | 
			
		||||
                date=dict["state"]["date"],
 | 
			
		||||
                seq=dict["state"]["seq"],
 | 
			
		||||
                channels=[
 | 
			
		||||
                    ChannelState(id=channel["id"], pts=channel["pts"])
 | 
			
		||||
                    for channel in dict["state"]["channels"]
 | 
			
		||||
                ],
 | 
			
		||||
            )
 | 
			
		||||
            if dict["state"]
 | 
			
		||||
            else None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PtsInfo:
 | 
			
		||||
    __slots__ = ("pts", "pts_count", "entry")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
 | 
			
		|||
 | 
			
		||||
from ...tl import Request, abcs, functions, types
 | 
			
		||||
from ..chat import ChatHashCache
 | 
			
		||||
from ..session import ChannelState, UpdateState
 | 
			
		||||
from .adaptor import adapt, pts_info_from_update
 | 
			
		||||
from .defs import (
 | 
			
		||||
    BOT_CHANNEL_DIFF_LIMIT,
 | 
			
		||||
| 
						 | 
				
			
			@ -17,13 +18,11 @@ from .defs import (
 | 
			
		|||
    NO_UPDATES_TIMEOUT,
 | 
			
		||||
    POSSIBLE_GAP_TIMEOUT,
 | 
			
		||||
    USER_CHANNEL_DIFF_LIMIT,
 | 
			
		||||
    ChannelState,
 | 
			
		||||
    Entry,
 | 
			
		||||
    Gap,
 | 
			
		||||
    PossibleGap,
 | 
			
		||||
    PrematureEndReason,
 | 
			
		||||
    State,
 | 
			
		||||
    UpdateState,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										159
									
								
								client/src/telethon/_impl/session/session.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								client/src/telethon/_impl/session/session.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,159 @@
 | 
			
		|||
import base64
 | 
			
		||||
from typing import Any, Dict, List, Optional, Self
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DataCenter:
 | 
			
		||||
    __slots__ = ("id", "addr", "auth")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.addr = addr
 | 
			
		||||
        self.auth = auth
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class User:
 | 
			
		||||
    __slots__ = ("id", "dc", "bot")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *, id: int, dc: int, bot: bool) -> None:
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.dc = dc
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChannelState:
 | 
			
		||||
    __slots__ = ("id", "pts")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *, id: int, pts: int) -> None:
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.pts = pts
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UpdateState:
 | 
			
		||||
    __slots__ = (
 | 
			
		||||
        "pts",
 | 
			
		||||
        "qts",
 | 
			
		||||
        "date",
 | 
			
		||||
        "seq",
 | 
			
		||||
        "channels",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        pts: int,
 | 
			
		||||
        qts: int,
 | 
			
		||||
        date: int,
 | 
			
		||||
        seq: int,
 | 
			
		||||
        channels: List[ChannelState],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.pts = pts
 | 
			
		||||
        self.qts = qts
 | 
			
		||||
        self.date = date
 | 
			
		||||
        self.seq = seq
 | 
			
		||||
        self.channels = channels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Session:
 | 
			
		||||
    """
 | 
			
		||||
    A Telethon session.
 | 
			
		||||
 | 
			
		||||
    A `Session` instance contains the required information to login into your
 | 
			
		||||
    Telegram account. **Never** give the saved session file to anyone else or
 | 
			
		||||
    make it public.
 | 
			
		||||
 | 
			
		||||
    Leaking the session file will grant a bad actor complete access to your
 | 
			
		||||
    account, including private conversations, groups you're part of and list
 | 
			
		||||
    of contacts (though not secret chats).
 | 
			
		||||
 | 
			
		||||
    If you think the session has been compromised, immediately terminate all
 | 
			
		||||
    sessions through an official Telegram client to revoke the authorization.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    VERSION = 1
 | 
			
		||||
 | 
			
		||||
    __slots__ = ("dcs", "user", "state")
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        dcs: Optional[List[DataCenter]] = None,
 | 
			
		||||
        user: Optional[User] = None,
 | 
			
		||||
        state: Optional[UpdateState] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self.dcs = dcs or []
 | 
			
		||||
        self.user = user
 | 
			
		||||
        self.state = state
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        return {
 | 
			
		||||
            "v": self.VERSION,
 | 
			
		||||
            "dcs": [
 | 
			
		||||
                {
 | 
			
		||||
                    "id": dc.id,
 | 
			
		||||
                    "addr": dc.addr,
 | 
			
		||||
                    "auth": base64.b64encode(dc.auth).decode("ascii")
 | 
			
		||||
                    if dc.auth
 | 
			
		||||
                    else None,
 | 
			
		||||
                }
 | 
			
		||||
                for dc in self.dcs
 | 
			
		||||
            ],
 | 
			
		||||
            "user": {
 | 
			
		||||
                "id": self.user.id,
 | 
			
		||||
                "dc": self.user.dc,
 | 
			
		||||
                "bot": self.user.bot,
 | 
			
		||||
            }
 | 
			
		||||
            if self.user
 | 
			
		||||
            else None,
 | 
			
		||||
            "state": {
 | 
			
		||||
                "pts": self.state.pts,
 | 
			
		||||
                "qts": self.state.qts,
 | 
			
		||||
                "date": self.state.date,
 | 
			
		||||
                "seq": self.state.seq,
 | 
			
		||||
                "channels": [
 | 
			
		||||
                    {"id": channel.id, "pts": channel.pts}
 | 
			
		||||
                    for channel in self.state.channels
 | 
			
		||||
                ],
 | 
			
		||||
            }
 | 
			
		||||
            if self.state
 | 
			
		||||
            else None,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, dict: Dict[str, Any]) -> Self:
 | 
			
		||||
        version = dict["v"]
 | 
			
		||||
        if version != cls.VERSION:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"cannot parse session format version {version} (expected {cls.VERSION})"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return cls(
 | 
			
		||||
            dcs=[
 | 
			
		||||
                DataCenter(
 | 
			
		||||
                    id=dc["id"],
 | 
			
		||||
                    addr=dc["addr"],
 | 
			
		||||
                    auth=base64.b64decode(dc["auth"])
 | 
			
		||||
                    if dc["auth"] is not None
 | 
			
		||||
                    else None,
 | 
			
		||||
                )
 | 
			
		||||
                for dc in dict["dcs"]
 | 
			
		||||
            ],
 | 
			
		||||
            user=User(
 | 
			
		||||
                id=dict["user"]["id"],
 | 
			
		||||
                dc=dict["user"]["dc"],
 | 
			
		||||
                bot=dict["user"]["bot"],
 | 
			
		||||
            )
 | 
			
		||||
            if dict["user"]
 | 
			
		||||
            else None,
 | 
			
		||||
            state=UpdateState(
 | 
			
		||||
                pts=dict["state"]["pts"],
 | 
			
		||||
                qts=dict["state"]["qts"],
 | 
			
		||||
                date=dict["state"]["date"],
 | 
			
		||||
                seq=dict["state"]["seq"],
 | 
			
		||||
                channels=[
 | 
			
		||||
                    ChannelState(id=channel["id"], pts=channel["pts"])
 | 
			
		||||
                    for channel in dict["state"]["channels"]
 | 
			
		||||
                ],
 | 
			
		||||
            )
 | 
			
		||||
            if dict["state"]
 | 
			
		||||
            else None,
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
										15
									
								
								client/src/telethon/_impl/session/storage/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								client/src/telethon/_impl/session/storage/__init__.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,15 @@
 | 
			
		|||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from .memory import MemorySession
 | 
			
		||||
from .storage import Storage
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from .sqlite import SqliteSession
 | 
			
		||||
except ImportError as e:
 | 
			
		||||
 | 
			
		||||
    class SqliteSession(Storage):  # type: ignore [no-redef]
 | 
			
		||||
        def __init__(self, *args: Any, **kwargs: Any) -> None:
 | 
			
		||||
            raise e from None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = ["MemorySession", "Storage", "SqliteSession"]
 | 
			
		||||
							
								
								
									
										28
									
								
								client/src/telethon/_impl/session/storage/memory.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								client/src/telethon/_impl/session/storage/memory.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,28 @@
 | 
			
		|||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from ..session import Session
 | 
			
		||||
from .storage import Storage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MemorySession(Storage):
 | 
			
		||||
    """
 | 
			
		||||
    Session storage without persistence.
 | 
			
		||||
 | 
			
		||||
    This is the simplest storage and is the one used by default.
 | 
			
		||||
 | 
			
		||||
    Session data is only kept in memory and is not persisted to disk.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    __slots__ = ("session",)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, session: Optional[Session] = None):
 | 
			
		||||
        self.session = session
 | 
			
		||||
 | 
			
		||||
    async def load(self) -> Optional[Session]:
 | 
			
		||||
        return self.session
 | 
			
		||||
 | 
			
		||||
    async def save(self, session: Session) -> None:
 | 
			
		||||
        self.session = session
 | 
			
		||||
 | 
			
		||||
    async def delete(self) -> None:
 | 
			
		||||
        self.session = None
 | 
			
		||||
							
								
								
									
										202
									
								
								client/src/telethon/_impl/session/storage/sqlite.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										202
									
								
								client/src/telethon/_impl/session/storage/sqlite.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,202 @@
 | 
			
		|||
import sqlite3
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
 | 
			
		||||
from ..session import ChannelState, DataCenter, Session, UpdateState, User
 | 
			
		||||
from .storage import Storage
 | 
			
		||||
 | 
			
		||||
EXTENSION = ".session"
 | 
			
		||||
CURRENT_VERSION = 10
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SqliteSession(Storage):
 | 
			
		||||
    """
 | 
			
		||||
    Session storage backed by SQLite.
 | 
			
		||||
 | 
			
		||||
    SQLite is a reliable way to persist data to disk and offers file locking.
 | 
			
		||||
 | 
			
		||||
    Paths without extension will have '.session' appended to them.
 | 
			
		||||
    This is by convention, and to make it harder to commit session files to
 | 
			
		||||
    an VCS by accident (adding `*.session` to `.gitignore` will catch them).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, file: Union[str, Path]):
 | 
			
		||||
        path = Path(file)
 | 
			
		||||
        if not path.suffix:
 | 
			
		||||
            path = path.with_suffix(EXTENSION)
 | 
			
		||||
 | 
			
		||||
        self._path = path
 | 
			
		||||
        self._conn: Optional[sqlite3.Connection] = None
 | 
			
		||||
 | 
			
		||||
    async def load(self) -> Optional[Session]:
 | 
			
		||||
        conn = self._current_conn()
 | 
			
		||||
 | 
			
		||||
        c = conn.cursor()
 | 
			
		||||
        with conn:
 | 
			
		||||
            version = self._get_or_init_version(c)
 | 
			
		||||
            if version < CURRENT_VERSION:
 | 
			
		||||
                if version == 7:
 | 
			
		||||
                    session = self._load_v7(c)
 | 
			
		||||
                else:
 | 
			
		||||
                    raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
                self._reset(c)
 | 
			
		||||
                self._get_or_init_version(c)
 | 
			
		||||
                self._save_v10(c, session)
 | 
			
		||||
 | 
			
		||||
            return self._load_v10(c)
 | 
			
		||||
 | 
			
		||||
    async def save(self, session: Session) -> None:
 | 
			
		||||
        conn = self._current_conn()
 | 
			
		||||
        with conn:
 | 
			
		||||
            self._save_v10(conn.cursor(), session)
 | 
			
		||||
        conn.close()
 | 
			
		||||
        self._conn = None
 | 
			
		||||
 | 
			
		||||
    async def delete(self) -> None:
 | 
			
		||||
        if self._conn:
 | 
			
		||||
            self._conn.close()
 | 
			
		||||
            self._conn = None
 | 
			
		||||
        self._path.unlink()
 | 
			
		||||
 | 
			
		||||
    def _current_conn(self) -> sqlite3.Connection:
 | 
			
		||||
        if self._conn is None:
 | 
			
		||||
            self._conn = sqlite3.connect(self._path)
 | 
			
		||||
 | 
			
		||||
        return self._conn
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _load_v7(c: sqlite3.Cursor) -> Session:
 | 
			
		||||
        # Session v7 format from telethon v1
 | 
			
		||||
        c.execute("select dc_id, server_address, port, auth_key from sessions")
 | 
			
		||||
        sessions = c.fetchall()
 | 
			
		||||
        c.execute("select pts, qts, date, seq from update_state where id = 0")
 | 
			
		||||
        state = c.fetchone()
 | 
			
		||||
        c.execute("select id, pts from update_state where id != 0")
 | 
			
		||||
        channelstate = c.fetchall()
 | 
			
		||||
 | 
			
		||||
        return Session(
 | 
			
		||||
            dcs=[
 | 
			
		||||
                DataCenter(id=id, addr=f"{ip}:{port}", auth=auth)
 | 
			
		||||
                for (id, ip, port, auth) in sessions
 | 
			
		||||
            ],
 | 
			
		||||
            user=None,
 | 
			
		||||
            state=UpdateState(
 | 
			
		||||
                pts=state[0],
 | 
			
		||||
                qts=state[1],
 | 
			
		||||
                date=state[2],
 | 
			
		||||
                seq=state[3],
 | 
			
		||||
                channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate],
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _load_v10(c: sqlite3.Cursor) -> Session:
 | 
			
		||||
        c.execute("select * from datacenter")
 | 
			
		||||
        datacenter = c.fetchall()
 | 
			
		||||
        c.execute("select * from user")
 | 
			
		||||
        user = c.fetchone()
 | 
			
		||||
        c.execute("select * from state")
 | 
			
		||||
        state = c.fetchone()
 | 
			
		||||
        c.execute("select * from channelstate")
 | 
			
		||||
        channelstate = c.fetchall()
 | 
			
		||||
 | 
			
		||||
        return Session(
 | 
			
		||||
            dcs=[
 | 
			
		||||
                DataCenter(id=id, addr=addr, auth=auth)
 | 
			
		||||
                for (id, addr, auth) in datacenter
 | 
			
		||||
            ],
 | 
			
		||||
            user=User(id=user[0], dc=user[1], bot=bool(user[2])) if user else None,
 | 
			
		||||
            state=UpdateState(
 | 
			
		||||
                pts=state[0],
 | 
			
		||||
                qts=state[1],
 | 
			
		||||
                date=state[2],
 | 
			
		||||
                seq=state[3],
 | 
			
		||||
                channels=[ChannelState(id=id, pts=pts) for id, pts in channelstate],
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _save_v10(c: sqlite3.Cursor, session: Session) -> None:
 | 
			
		||||
        c.execute("delete from datacenter")
 | 
			
		||||
        c.execute("delete from user")
 | 
			
		||||
        c.execute("delete from state")
 | 
			
		||||
        c.execute("delete from channelstate")
 | 
			
		||||
        c.executemany(
 | 
			
		||||
            "insert into datacenter values (?, ?, ?)",
 | 
			
		||||
            [(dc.id, dc.addr, dc.auth) for dc in session.dcs],
 | 
			
		||||
        )
 | 
			
		||||
        if user := session.user:
 | 
			
		||||
            c.execute(
 | 
			
		||||
                "insert into user values (?, ?, ?)", (user.id, user.dc, int(user.bot))
 | 
			
		||||
            )
 | 
			
		||||
        if state := session.state:
 | 
			
		||||
            c.execute(
 | 
			
		||||
                "insert into state values (?, ?, ?, ?)",
 | 
			
		||||
                (state.pts, state.qts, state.date, state.seq),
 | 
			
		||||
            )
 | 
			
		||||
            c.executemany(
 | 
			
		||||
                "insert into channelstate values (?, ?)",
 | 
			
		||||
                [(channel.id, channel.pts) for channel in state.channels],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _reset(c: sqlite3.Cursor) -> None:
 | 
			
		||||
        safe_chars = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
 | 
			
		||||
 | 
			
		||||
        c.execute("select name from sqlite_master where type='table'")
 | 
			
		||||
        for (name,) in c.fetchall():
 | 
			
		||||
            # Can't format arguments for table names. Regardless, it shouldn't
 | 
			
		||||
            # be an SQL-injection because names come from `sqlite_master`.
 | 
			
		||||
            # Just to be on the safe-side, check for r'\w+' nevertheless,
 | 
			
		||||
            # avoiding referencing globals which could've been monkey-patched.
 | 
			
		||||
            for char in name:
 | 
			
		||||
                if char not in safe_chars or name.__len__() > 20:
 | 
			
		||||
                    raise ValueError(f"potentially unsafe table name: {name}")
 | 
			
		||||
 | 
			
		||||
            c.execute(f"drop table {name}")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get_or_init_version(c: sqlite3.Cursor) -> int:
 | 
			
		||||
        c.execute(
 | 
			
		||||
            "select name from sqlite_master where type='table' and name='version'"
 | 
			
		||||
        )
 | 
			
		||||
        if c.fetchone():
 | 
			
		||||
            c.execute("select version from version")
 | 
			
		||||
            res = c.fetchone()[0]
 | 
			
		||||
            assert isinstance(res, int)
 | 
			
		||||
            return res
 | 
			
		||||
        else:
 | 
			
		||||
            SqliteSession._create_tables(c)
 | 
			
		||||
            c.execute("insert into version values (?)", (CURRENT_VERSION,))
 | 
			
		||||
            return CURRENT_VERSION
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _create_tables(c: sqlite3.Cursor) -> None:
 | 
			
		||||
        c.executescript(
 | 
			
		||||
            """
 | 
			
		||||
            create table version (
 | 
			
		||||
                version integer primary key
 | 
			
		||||
            );
 | 
			
		||||
            create table datacenter(
 | 
			
		||||
                id integer primary key,
 | 
			
		||||
                addr text not null,
 | 
			
		||||
                auth blob
 | 
			
		||||
            );
 | 
			
		||||
            create table user(
 | 
			
		||||
                id integer primary key,
 | 
			
		||||
                dc integer not null,
 | 
			
		||||
                bot integer not null
 | 
			
		||||
            );
 | 
			
		||||
            create table state(
 | 
			
		||||
                pts integer not null,
 | 
			
		||||
                qts integer not null,
 | 
			
		||||
                date integer not null,
 | 
			
		||||
                seq integer not null
 | 
			
		||||
            );
 | 
			
		||||
            create table channelstate(
 | 
			
		||||
                id integer primary key,
 | 
			
		||||
                pts integer not null
 | 
			
		||||
            );
 | 
			
		||||
            """
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
										30
									
								
								client/src/telethon/_impl/session/storage/storage.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								client/src/telethon/_impl/session/storage/storage.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,30 @@
 | 
			
		|||
import abc
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from ..session import Session
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Storage(abc.ABC):
 | 
			
		||||
    @abc.abstractmethod
 | 
			
		||||
    async def load(self) -> Optional[Session]:
 | 
			
		||||
        """
 | 
			
		||||
        Load the `Session` instance, if any.
 | 
			
		||||
 | 
			
		||||
        This method is called by the library prior to `connect`.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @abc.abstractmethod
 | 
			
		||||
    async def save(self, session: Session) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Save the `Session` instance to persistent storage.
 | 
			
		||||
 | 
			
		||||
        This method is called by the library post `disconnect`.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @abc.abstractmethod
 | 
			
		||||
    async def delete(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Delete the saved `Session`.
 | 
			
		||||
 | 
			
		||||
        This method is called by the library post `log_out`.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -13,13 +13,7 @@ async def test_ping_pong() -> None:
 | 
			
		|||
    api_hash = os.getenv("TG_HASH")
 | 
			
		||||
    assert api_id and api_id.isdigit()
 | 
			
		||||
    assert api_hash
 | 
			
		||||
    client = Client(
 | 
			
		||||
        Config(
 | 
			
		||||
            session=Session(),
 | 
			
		||||
            api_id=int(api_id),
 | 
			
		||||
            api_hash=api_hash,
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    client = Client(None, int(api_id), api_hash)
 | 
			
		||||
    assert not client.connected
 | 
			
		||||
    await client.connect()
 | 
			
		||||
    assert client.connected
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -81,7 +81,6 @@ def test_unpack_two_at_once() -> None:
 | 
			
		|||
    with raises(ValueError) as e:
 | 
			
		||||
        transport.unpack(input, output)
 | 
			
		||||
    e.match("bad seq")
 | 
			
		||||
    assert output == expected_output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_unpack_twice() -> None:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user