Make Session more convenient to use

This commit is contained in:
Lonami Exo 2023-09-02 00:48:26 +02:00
parent 3863cf0972
commit c46387f7bf
5 changed files with 93 additions and 10 deletions

2
.gitignore vendored
View File

@ -15,3 +15,5 @@ build/
**/mtproto/abcs/ **/mtproto/abcs/
**/mtproto/functions/ **/mtproto/functions/
**/mtproto/types/ **/mtproto/types/
**/testbed.py
*.session

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
from ...mtproto.mtp.types import RpcError from ...mtproto.mtp.types import RpcError
from ...session.message_box.defs import Session
from ...session.message_box.defs import User as SessionUser from ...session.message_box.defs import User as SessionUser
from ...tl import abcs, functions, types from ...tl import abcs, functions, types
from ..types.chat.user import User from ..types.chat.user import User
@ -141,3 +142,8 @@ async def check_password(
async def sign_out(self: Client) -> None: async def sign_out(self: Client) -> None:
await self(functions.auth.log_out()) await self(functions.auth.log_out())
def session(self: Client) -> Session:
self._config.session.state = self._message_box.session_state()
return self._config.session

View File

@ -6,6 +6,7 @@ from typing import Deque, Optional, Self, Type, TypeVar, Union
from ...mtsender.sender import Sender from ...mtsender.sender import Sender
from ...session.chat.hash_cache import ChatHashCache from ...session.chat.hash_cache import ChatHashCache
from ...session.chat.packed import PackedChat from ...session.chat.packed import PackedChat
from ...session.message_box.defs import Session
from ...session.message_box.messagebox import MessageBox from ...session.message_box.messagebox import MessageBox
from ...tl import abcs from ...tl import abcs
from ...tl.core.request import Request from ...tl.core.request import Request
@ -20,6 +21,7 @@ from .auth import (
check_password, check_password,
is_authorized, is_authorized,
request_login_code, request_login_code,
session,
sign_in, sign_in,
sign_out, sign_out,
) )
@ -126,6 +128,15 @@ class Client:
async def sign_out(self) -> None: async def sign_out(self) -> None:
await sign_out(self) await sign_out(self)
@property
def session(self) -> Session:
"""
Up-to-date session state, useful for persisting it to storage.
Mutating the returned object may cause the library to misbehave.
"""
return session(self)
async def inline_query( async def inline_query(
self, bot: ChatLike, query: str, *, chat: Optional[ChatLike] = None self, bot: ChatLike, query: str, *, chat: Optional[ChatLike] = None
) -> None: ) -> None:

View File

@ -1,6 +1,7 @@
import base64
import logging import logging
from enum import Enum from enum import Enum
from typing import List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Self, Union
from ...tl import abcs from ...tl import abcs
@ -62,14 +63,81 @@ class Session:
def __init__( def __init__(
self, self,
*, *,
dcs: List[DataCenter], dcs: Optional[List[DataCenter]] = None,
user: Optional[User], user: Optional[User] = None,
state: Optional[UpdateState], state: Optional[UpdateState] = None,
): ):
self.dcs = dcs self.dcs = dcs or []
self.user = user self.user = user
self.state = state self.state = state
def to_dict(self) -> Dict[str, Any]:
return {
"dcs": [
{
"id": dc.id,
"addr": dc.addr,
"auth": base64.b64encode(dc.auth).decode("ascii")
if dc.auth
else None,
}
for dc in self.dcs
],
"user": {
"id": self.user.id,
"dc": self.user.dc,
"bot": self.user.bot,
}
if self.user
else None,
"state": {
"pts": self.state.pts,
"qts": self.state.qts,
"date": self.state.date,
"seq": self.state.seq,
"channels": [
{"id": channel.id, "pts": channel.pts}
for channel in self.state.channels
],
}
if self.state
else None,
}
@classmethod
def from_dict(cls, dict: Dict[str, Any]) -> Self:
return cls(
dcs=[
DataCenter(
id=dc["id"],
addr=dc["addr"],
auth=base64.b64decode(dc["auth"])
if dc["auth"] is not None
else None,
)
for dc in dict["dcs"]
],
user=User(
id=dict["user"]["id"],
dc=dict["user"]["dc"],
bot=dict["user"]["bot"],
)
if dict["user"]
else None,
state=UpdateState(
pts=dict["state"]["pts"],
qts=dict["state"]["qts"],
date=dict["state"]["date"],
seq=dict["state"]["seq"],
channels=[
ChannelState(id=channel["id"], pts=channel["pts"])
for channel in dict["state"]["channels"]
],
)
if dict["state"]
else None,
)
class PtsInfo: class PtsInfo:
__slots__ = ("pts", "pts_count", "entry") __slots__ = ("pts", "pts_count", "entry")

View File

@ -17,11 +17,7 @@ async def test_ping_pong() -> None:
assert api_hash assert api_hash
client = Client( client = Client(
Config( Config(
session=Session( session=Session(),
dcs=[],
user=None,
state=None,
),
api_id=int(api_id), api_id=int(api_id),
api_hash=api_hash, api_hash=api_hash,
) )