Use frozen dataclasses for session types

Now that 3.7 is the minimum version,
we can use dataclasses.
This commit is contained in:
Lonami Exo 2022-01-09 13:01:16 +01:00
parent 7ea30961ae
commit be6508dc5d
2 changed files with 53 additions and 94 deletions

View File

@ -1,4 +1,4 @@
from .types import DataCenter, ChannelState, SessionState, Entity
from .types import DataCenter, ChannelState, SessionState, EntityType, Entity
from abc import ABC, abstractmethod
from typing import List, Optional
@ -59,7 +59,7 @@ class Session(ABC):
raise NotImplementedError
@abstractmethod
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
async def get_entity(self, ty: Optional[EntityType], id: int) -> Optional[Entity]:
"""
Get the `Entity` with matching ``ty`` and ``id``.
@ -67,14 +67,14 @@ class Session(ABC):
``ty`` and ``id``, if the ``ty`` is in a given group, a matching ``access_hash`` with
that ``id`` from within any ``ty`` in that group should be returned.
* ``'U'`` and ``'B'`` (user and bot).
* ``'G'`` (small group chat).
* ``'C'``, ``'M'`` and ``'E'`` (broadcast channel, megagroup channel, and gigagroup channel).
* `EntityType.USER` and `EntityType.BOT`.
* `EntityType.GROUP`.
* `EntityType.CHANNEL`, `EntityType.MEGAGROUP` and `EntityType.GIGAGROUP`.
For example, if a ``ty`` representing a bot is stored but the asking ``ty`` is a user,
the corresponding ``access_hash`` should still be returned.
You may use `types.canonical_entity_type` to find out the canonical type.
You may use ``EntityType.canonical`` to find out the canonical type.
A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID".
"""

View File

@ -1,6 +1,9 @@
from typing import Optional, Tuple
from dataclasses import dataclass
from enum import IntEnum
@dataclass(frozen=True)
class DataCenter:
"""
Stores the information needed to connect to a datacenter.
@ -12,21 +15,14 @@ class DataCenter:
"""
__slots__ = ('id', 'ipv4', 'ipv6', 'port', 'auth')
def __init__(
self,
id: int,
ipv4: int,
ipv6: Optional[int],
port: int,
id: int
ipv4: int
ipv6: Optional[int]
port: int
auth: bytes
):
self.id = id
self.ipv4 = ipv4
self.ipv6 = ipv6
self.port = port
self.auth = auth
@dataclass(frozen=True)
class SessionState:
"""
Stores the information needed to fetch updates and about the current user.
@ -45,27 +41,17 @@ class SessionState:
"""
__slots__ = ('user_id', 'dc_id', 'bot', 'pts', 'qts', 'date', 'seq', 'takeout_id')
def __init__(
self,
user_id: int,
dc_id: int,
bot: bool,
pts: int,
qts: int,
date: int,
seq: int,
takeout_id: Optional[int],
):
self.user_id = user_id
self.dc_id = dc_id
self.bot = bot
self.pts = pts
self.qts = qts
self.date = date
self.seq = seq
self.takeout_id = takeout_id
user_id: int
dc_id: int
bot: bool
pts: int
qts: int
date: int
seq: int
takeout_id: Optional[int]
@dataclass(frozen=True)
class ChannelState:
"""
Stores the information needed to fetch updates from a channel.
@ -75,24 +61,13 @@ class ChannelState:
"""
__slots__ = ('channel_id', 'pts')
def __init__(
self,
channel_id: int,
channel_id: int
pts: int
):
self.channel_id = channel_id
self.pts = pts
class Entity:
class EntityType(IntEnum):
"""
Stores the information needed to use a certain user, chat or channel with the API.
* ty: 8-bit number indicating the type of the entity.
* id: 64-bit number uniquely identifying the entity among those of the same type.
* access_hash: 64-bit number needed to use this entity with the API.
You can rely on the ``ty`` value to be equal to the ASCII character one of:
You can rely on the type value to be equal to the ASCII character one of:
* 'U' (85): this entity belongs to a :tl:`User` who is not a ``bot``.
* 'B' (66): this entity belongs to a :tl:`User` who is a ``bot``.
@ -101,8 +76,6 @@ class Entity:
* 'M' (77): this entity belongs to a megagroup :tl:`Channel`.
* 'E' (69): this entity belongs to an "enormous" "gigagroup" :tl:`Channel`.
"""
__slots__ = ('ty', 'id', 'access_hash')
USER = ord('U')
BOT = ord('B')
GROUP = ord('G')
@ -110,48 +83,34 @@ class Entity:
MEGAGROUP = ord('M')
GIGAGROUP = ord('E')
def __init__(
self,
ty: int,
id: int,
def canonical(self):
"""
Return the canonical version of this type.
"""
return _canon_entity_types[self]
_canon_entity_types = {
EntityType.USER: EntityType.USER,
EntityType.BOT: EntityType.USER,
EntityType.GROUP: EntityType.GROUP,
EntityType.CHANNEL: EntityType.CHANNEL,
EntityType.MEGAGROUP: EntityType.CHANNEL,
EntityType.GIGAGROUP: EntityType.CHANNEL,
}
@dataclass(frozen=True)
class Entity:
"""
Stores the information needed to use a certain user, chat or channel with the API.
* ty: 8-bit number indicating the type of the entity.
* id: 64-bit number uniquely identifying the entity among those of the same type.
* access_hash: 64-bit number needed to use this entity with the API.
"""
__slots__ = ('ty', 'id', 'access_hash')
ty: EntityType
id: int
access_hash: int
):
self.ty = ty
self.id = id
self.access_hash = access_hash
def canonical_entity_type(ty: int, *, _mapping={
Entity.USER: Entity.USER,
Entity.BOT: Entity.USER,
Entity.GROUP: Entity.GROUP,
Entity.CHANNEL: Entity.CHANNEL,
Entity.MEGAGROUP: Entity.CHANNEL,
Entity.GIGAGROUP: Entity.CHANNEL,
}) -> int:
"""
Return the canonical version of an entity type.
"""
try:
return _mapping[ty]
except KeyError:
ty = chr(ty) if isinstance(ty, int) else ty
raise ValueError(f'entity type {ty!r} is not valid')
def get_entity_type_group(ty: int, *, _mapping={
Entity.USER: (Entity.USER, Entity.BOT),
Entity.BOT: (Entity.USER, Entity.BOT),
Entity.GROUP: (Entity.GROUP,),
Entity.CHANNEL: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
Entity.MEGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
Entity.GIGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
}) -> Tuple[int]:
"""
Return the group where an entity type belongs to.
"""
try:
return _mapping[ty]
except KeyError:
ty = chr(ty) if isinstance(ty, int) else ty
raise ValueError(f'entity type {ty!r} is not valid')