From be6508dc5d0d28d75a550cc910805af3e92d74a0 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sun, 9 Jan 2022 13:01:16 +0100 Subject: [PATCH] Use frozen dataclasses for session types Now that 3.7 is the minimum version, we can use dataclasses. --- telethon/_sessions/abstract.py | 12 +-- telethon/_sessions/types.py | 135 ++++++++++++--------------------- 2 files changed, 53 insertions(+), 94 deletions(-) diff --git a/telethon/_sessions/abstract.py b/telethon/_sessions/abstract.py index 2b28ae76..cdb747a4 100644 --- a/telethon/_sessions/abstract.py +++ b/telethon/_sessions/abstract.py @@ -1,4 +1,4 @@ -from .types import DataCenter, ChannelState, SessionState, Entity +from .types import DataCenter, ChannelState, SessionState, EntityType, Entity from abc import ABC, abstractmethod from typing import List, Optional @@ -59,7 +59,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]: + async def get_entity(self, ty: Optional[EntityType], id: int) -> Optional[Entity]: """ Get the `Entity` with matching ``ty`` and ``id``. @@ -67,14 +67,14 @@ class Session(ABC): ``ty`` and ``id``, if the ``ty`` is in a given group, a matching ``access_hash`` with that ``id`` from within any ``ty`` in that group should be returned. - * ``'U'`` and ``'B'`` (user and bot). - * ``'G'`` (small group chat). - * ``'C'``, ``'M'`` and ``'E'`` (broadcast channel, megagroup channel, and gigagroup channel). + * `EntityType.USER` and `EntityType.BOT`. + * `EntityType.GROUP`. + * `EntityType.CHANNEL`, `EntityType.MEGAGROUP` and `EntityType.GIGAGROUP`. For example, if a ``ty`` representing a bot is stored but the asking ``ty`` is a user, the corresponding ``access_hash`` should still be returned. - You may use `types.canonical_entity_type` to find out the canonical type. + You may use ``EntityType.canonical`` to find out the canonical type. A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID". """ diff --git a/telethon/_sessions/types.py b/telethon/_sessions/types.py index 51d4ecb5..a9738709 100644 --- a/telethon/_sessions/types.py +++ b/telethon/_sessions/types.py @@ -1,6 +1,9 @@ from typing import Optional, Tuple +from dataclasses import dataclass +from enum import IntEnum +@dataclass(frozen=True) class DataCenter: """ Stores the information needed to connect to a datacenter. @@ -12,21 +15,14 @@ class DataCenter: """ __slots__ = ('id', 'ipv4', 'ipv6', 'port', 'auth') - def __init__( - self, - id: int, - ipv4: int, - ipv6: Optional[int], - port: int, - auth: bytes - ): - self.id = id - self.ipv4 = ipv4 - self.ipv6 = ipv6 - self.port = port - self.auth = auth + id: int + ipv4: int + ipv6: Optional[int] + port: int + auth: bytes +@dataclass(frozen=True) class SessionState: """ Stores the information needed to fetch updates and about the current user. @@ -45,27 +41,17 @@ class SessionState: """ __slots__ = ('user_id', 'dc_id', 'bot', 'pts', 'qts', 'date', 'seq', 'takeout_id') - def __init__( - self, - user_id: int, - dc_id: int, - bot: bool, - pts: int, - qts: int, - date: int, - seq: int, - takeout_id: Optional[int], - ): - self.user_id = user_id - self.dc_id = dc_id - self.bot = bot - self.pts = pts - self.qts = qts - self.date = date - self.seq = seq - self.takeout_id = takeout_id + user_id: int + dc_id: int + bot: bool + pts: int + qts: int + date: int + seq: int + takeout_id: Optional[int] +@dataclass(frozen=True) class ChannelState: """ Stores the information needed to fetch updates from a channel. @@ -75,24 +61,13 @@ class ChannelState: """ __slots__ = ('channel_id', 'pts') - def __init__( - self, - channel_id: int, - pts: int - ): - self.channel_id = channel_id - self.pts = pts + channel_id: int + pts: int -class Entity: +class EntityType(IntEnum): """ - Stores the information needed to use a certain user, chat or channel with the API. - - * ty: 8-bit number indicating the type of the entity. - * id: 64-bit number uniquely identifying the entity among those of the same type. - * access_hash: 64-bit number needed to use this entity with the API. - - You can rely on the ``ty`` value to be equal to the ASCII character one of: + You can rely on the type value to be equal to the ASCII character one of: * 'U' (85): this entity belongs to a :tl:`User` who is not a ``bot``. * 'B' (66): this entity belongs to a :tl:`User` who is a ``bot``. @@ -101,8 +76,6 @@ class Entity: * 'M' (77): this entity belongs to a megagroup :tl:`Channel`. * 'E' (69): this entity belongs to an "enormous" "gigagroup" :tl:`Channel`. """ - __slots__ = ('ty', 'id', 'access_hash') - USER = ord('U') BOT = ord('B') GROUP = ord('G') @@ -110,48 +83,34 @@ class Entity: MEGAGROUP = ord('M') GIGAGROUP = ord('E') - def __init__( - self, - ty: int, - id: int, - access_hash: int - ): - self.ty = ty - self.id = id - self.access_hash = access_hash + def canonical(self): + """ + Return the canonical version of this type. + """ + return _canon_entity_types[self] -def canonical_entity_type(ty: int, *, _mapping={ - Entity.USER: Entity.USER, - Entity.BOT: Entity.USER, - Entity.GROUP: Entity.GROUP, - Entity.CHANNEL: Entity.CHANNEL, - Entity.MEGAGROUP: Entity.CHANNEL, - Entity.GIGAGROUP: Entity.CHANNEL, -}) -> int: - """ - Return the canonical version of an entity type. - """ - try: - return _mapping[ty] - except KeyError: - ty = chr(ty) if isinstance(ty, int) else ty - raise ValueError(f'entity type {ty!r} is not valid') +_canon_entity_types = { + EntityType.USER: EntityType.USER, + EntityType.BOT: EntityType.USER, + EntityType.GROUP: EntityType.GROUP, + EntityType.CHANNEL: EntityType.CHANNEL, + EntityType.MEGAGROUP: EntityType.CHANNEL, + EntityType.GIGAGROUP: EntityType.CHANNEL, +} -def get_entity_type_group(ty: int, *, _mapping={ - Entity.USER: (Entity.USER, Entity.BOT), - Entity.BOT: (Entity.USER, Entity.BOT), - Entity.GROUP: (Entity.GROUP,), - Entity.CHANNEL: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP), - Entity.MEGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP), - Entity.GIGAGROUP: (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP), -}) -> Tuple[int]: +@dataclass(frozen=True) +class Entity: """ - Return the group where an entity type belongs to. + Stores the information needed to use a certain user, chat or channel with the API. + + * ty: 8-bit number indicating the type of the entity. + * id: 64-bit number uniquely identifying the entity among those of the same type. + * access_hash: 64-bit number needed to use this entity with the API. """ - try: - return _mapping[ty] - except KeyError: - ty = chr(ty) if isinstance(ty, int) else ty - raise ValueError(f'entity type {ty!r} is not valid') + __slots__ = ('ty', 'id', 'access_hash') + + ty: EntityType + id: int + access_hash: int