Use frozen dataclasses for session types

Now that 3.7 is the minimum version,
we can use dataclasses.
This commit is contained in:
Lonami Exo 2022-01-09 13:01:16 +01:00
parent 7ea30961ae
commit be6508dc5d
2 changed files with 53 additions and 94 deletions

View File

@ -1,4 +1,4 @@
from .types import DataCenter, ChannelState, SessionState, Entity from .types import DataCenter, ChannelState, SessionState, EntityType, Entity
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional
@ -59,7 +59,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]: async def get_entity(self, ty: Optional[EntityType], id: int) -> Optional[Entity]:
""" """
Get the `Entity` with matching ``ty`` and ``id``. Get the `Entity` with matching ``ty`` and ``id``.
@ -67,14 +67,14 @@ class Session(ABC):
``ty`` and ``id``, if the ``ty`` is in a given group, a matching ``access_hash`` with ``ty`` and ``id``, if the ``ty`` is in a given group, a matching ``access_hash`` with
that ``id`` from within any ``ty`` in that group should be returned. that ``id`` from within any ``ty`` in that group should be returned.
* ``'U'`` and ``'B'`` (user and bot). * `EntityType.USER` and `EntityType.BOT`.
* ``'G'`` (small group chat). * `EntityType.GROUP`.
* ``'C'``, ``'M'`` and ``'E'`` (broadcast channel, megagroup channel, and gigagroup channel). * `EntityType.CHANNEL`, `EntityType.MEGAGROUP` and `EntityType.GIGAGROUP`.
For example, if a ``ty`` representing a bot is stored but the asking ``ty`` is a user, For example, if a ``ty`` representing a bot is stored but the asking ``ty`` is a user,
the corresponding ``access_hash`` should still be returned. the corresponding ``access_hash`` should still be returned.
You may use `types.canonical_entity_type` to find out the canonical type. You may use ``EntityType.canonical`` to find out the canonical type.
A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID". A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID".
""" """

View File

@ -1,6 +1,9 @@
from typing import Optional, Tuple from typing import Optional, Tuple
from dataclasses import dataclass
from enum import IntEnum
@dataclass(frozen=True)
class DataCenter: class DataCenter:
""" """
Stores the information needed to connect to a datacenter. Stores the information needed to connect to a datacenter.
@ -12,21 +15,14 @@ class DataCenter:
""" """
__slots__ = ('id', 'ipv4', 'ipv6', 'port', 'auth') __slots__ = ('id', 'ipv4', 'ipv6', 'port', 'auth')
def __init__( id: int
self, ipv4: int
id: int, ipv6: Optional[int]
ipv4: int, port: int
ipv6: Optional[int], auth: bytes
port: int,
auth: bytes
):
self.id = id
self.ipv4 = ipv4
self.ipv6 = ipv6
self.port = port
self.auth = auth
@dataclass(frozen=True)
class SessionState: class SessionState:
""" """
Stores the information needed to fetch updates and about the current user. Stores the information needed to fetch updates and about the current user.
@ -45,27 +41,17 @@ class SessionState:
""" """
__slots__ = ('user_id', 'dc_id', 'bot', 'pts', 'qts', 'date', 'seq', 'takeout_id') __slots__ = ('user_id', 'dc_id', 'bot', 'pts', 'qts', 'date', 'seq', 'takeout_id')
def __init__( user_id: int
self, dc_id: int
user_id: int, bot: bool
dc_id: int, pts: int
bot: bool, qts: int
pts: int, date: int
qts: int, seq: int
date: int, takeout_id: Optional[int]
seq: int,
takeout_id: Optional[int],
):
self.user_id = user_id
self.dc_id = dc_id
self.bot = bot
self.pts = pts
self.qts = qts
self.date = date
self.seq = seq
self.takeout_id = takeout_id
@dataclass(frozen=True)
class ChannelState: class ChannelState:
""" """
Stores the information needed to fetch updates from a channel. Stores the information needed to fetch updates from a channel.
@ -75,24 +61,13 @@ class ChannelState:
""" """
__slots__ = ('channel_id', 'pts') __slots__ = ('channel_id', 'pts')
def __init__( channel_id: int
self, pts: int
channel_id: int,
pts: int
):
self.channel_id = channel_id
self.pts = pts
class Entity: class EntityType(IntEnum):
""" """
Stores the information needed to use a certain user, chat or channel with the API. You can rely on the type value to be equal to the ASCII character one of:
* ty: 8-bit number indicating the type of the entity.
* id: 64-bit number uniquely identifying the entity among those of the same type.
* access_hash: 64-bit number needed to use this entity with the API.
You can rely on the ``ty`` value to be equal to the ASCII character one of:
* 'U' (85): this entity belongs to a :tl:`User` who is not a ``bot``. * 'U' (85): this entity belongs to a :tl:`User` who is not a ``bot``.
* 'B' (66): this entity belongs to a :tl:`User` who is a ``bot``. * 'B' (66): this entity belongs to a :tl:`User` who is a ``bot``.
@ -101,8 +76,6 @@ class Entity:
* 'M' (77): this entity belongs to a megagroup :tl:`Channel`. * 'M' (77): this entity belongs to a megagroup :tl:`Channel`.
* 'E' (69): this entity belongs to an "enormous" "gigagroup" :tl:`Channel`. * 'E' (69): this entity belongs to an "enormous" "gigagroup" :tl:`Channel`.
""" """
__slots__ = ('ty', 'id', 'access_hash')
USER = ord('U') USER = ord('U')
BOT = ord('B') BOT = ord('B')
GROUP = ord('G') GROUP = ord('G')
@ -110,48 +83,34 @@ class Entity:
MEGAGROUP = ord('M') MEGAGROUP = ord('M')
GIGAGROUP = ord('E') GIGAGROUP = ord('E')
def __init__( def canonical(self):
self, """
ty: int, Return the canonical version of this type.
id: int, """
access_hash: int return _canon_entity_types[self]
):
self.ty = ty
self.id = id
self.access_hash = access_hash
def canonical_entity_type(ty: int, *, _mapping={ _canon_entity_types = {
Entity.USER: Entity.USER, EntityType.USER: EntityType.USER,
Entity.BOT: Entity.USER, EntityType.BOT: EntityType.USER,
Entity.GROUP: Entity.GROUP, EntityType.GROUP: EntityType.GROUP,
Entity.CHANNEL: Entity.CHANNEL, EntityType.CHANNEL: EntityType.CHANNEL,
Entity.MEGAGROUP: Entity.CHANNEL, EntityType.MEGAGROUP: EntityType.CHANNEL,
Entity.GIGAGROUP: Entity.CHANNEL, EntityType.GIGAGROUP: EntityType.CHANNEL,
}) -> int: }
"""
Return the canonical version of an entity type.
"""
try:
return _mapping[ty]
except KeyError:
ty = chr(ty) if isinstance(ty, int) else ty
raise ValueError(f'entity type {ty!r} is not valid')
def get_entity_type_group(ty: int, *, _mapping={ @dataclass(frozen=True)
Entity.USER: (Entity.USER, Entity.BOT), class Entity:
Entity.BOT: (Entity.USER, Entity.BOT),
Entity.GROUP: (Entity.GROUP,),
Entity.CHANNEL: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
Entity.MEGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
Entity.GIGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP),
}) -> Tuple[int]:
""" """
Return the group where an entity type belongs to. Stores the information needed to use a certain user, chat or channel with the API.
* ty: 8-bit number indicating the type of the entity.
* id: 64-bit number uniquely identifying the entity among those of the same type.
* access_hash: 64-bit number needed to use this entity with the API.
""" """
try: __slots__ = ('ty', 'id', 'access_hash')
return _mapping[ty]
except KeyError: ty: EntityType
ty = chr(ty) if isinstance(ty, int) else ty id: int
raise ValueError(f'entity type {ty!r} is not valid') access_hash: int