Add type-hinting (ChatGetter, Events) and fix bugs

This commit is contained in:
M. Hosseyn Najafi 2020-12-14 19:54:47 +03:30
parent b754d4fbd5
commit 2073da83d4
8 changed files with 74 additions and 35 deletions

View File

@ -1,14 +1,17 @@
import abc
import asyncio
import warnings
from typing import Optional, Sequence, Callable
from typing import Optional, Sequence, Callable, TYPE_CHECKING
from .. import utils, TelegramClient, hints
from .. import utils, hints
from ..tl import TLObject, types
from ..tl.custom.chatgetter import ChatGetter
if TYPE_CHECKING:
from .. import TelegramClient
async def _into_id_set(client, chats):
async def _into_id_set(client: 'TelegramClient', chats):
"""Helper util to turn the input chat or chats into a set of IDs."""
if chats is None:
return None
@ -66,8 +69,11 @@ class EventBuilder(abc.ABC):
async def handler(event):
pass # code here
"""
def __init__(self, chats: Optional[Sequence[hints.Entity]] = None, *,
blacklist_chats: bool = False, func: Optional[Callable[['EventCommon'], None]] = None):
def __init__(self,
chats: Optional[Sequence[hints.Entity]] = None,
*,
blacklist_chats: bool = False,
func: Optional[Callable[['EventCommon'], None]] = None):
self.chats = chats
self.blacklist_chats = bool(blacklist_chats)
self.resolved = False
@ -76,7 +82,7 @@ class EventBuilder(abc.ABC):
@classmethod
@abc.abstractmethod
def build(cls, update, others=None, self_id=None):
def build(cls, update: types.TypeUpdate, others=None, self_id=None):
"""
Builds an event for the given update if possible, or returns None.
@ -88,7 +94,7 @@ class EventBuilder(abc.ABC):
"""
# TODO So many parameters specific to only some update types seems dirty
async def resolve(self, client):
async def resolve(self, client: 'TelegramClient'):
"""Helper method to allow event builders to be resolved before usage"""
if self.resolved:
return
@ -101,7 +107,7 @@ class EventBuilder(abc.ABC):
await self._resolve(client)
self.resolved = True
async def _resolve(self, client):
async def _resolve(self, client: 'TelegramClient'):
self.chats = await _into_id_set(client, self.chats)
def filter(self, event: 'EventCommon'):
@ -142,14 +148,17 @@ class EventCommon(ChatGetter, abc.ABC):
"""
_event_name = 'Event'
def __init__(self, chat_peer=None, msg_id=None, broadcast=None):
def __init__(self,
chat_peer: Optional[types.TypePeer] = None,
msg_id: Optional[int] = None,
broadcast: Optional[bool] = None):
super().__init__(chat_peer, broadcast=broadcast)
self._entities = {}
self._client = None
self._message_id = msg_id
self.original_update = None # type: Optional[types.TypeUpdate]
def _set_client(self, client: TelegramClient):
def _set_client(self, client: 'TelegramClient'):
"""
Setter so subclasses can act accordingly when the client is set.
"""
@ -161,7 +170,7 @@ class EventCommon(ChatGetter, abc.ABC):
self._chat = self._input_chat = None
@property
def client(self) -> TelegramClient:
def client(self) -> 'TelegramClient':
"""
The `telethon.TelegramClient` that created this event.
"""

View File

@ -49,9 +49,11 @@ class InlineQuery(EventBuilder):
])
"""
def __init__(
self, users: Optional[Sequence[hints.EntityLike]] = None, *,
blacklist_users: bool = False, func: Optional[Callable[['InlineQuery.Event'], None]] = None,
pattern: Union[str, Callable, Pattern, Optional] = None):
self,
users: Optional[Sequence[hints.EntityLike]] = None, *,
blacklist_users: bool = False,
func: Optional[Callable[['InlineQuery.Event'], None]] = None,
pattern: Union[str, Callable, Pattern, None] = None):
super().__init__(users, blacklist_chats=blacklist_users, func=func)
if isinstance(pattern, str):

View File

@ -1,3 +1,5 @@
from typing import Sequence, Optional
from .common import EventBuilder, EventCommon, name_inner_event
from ..tl import types
@ -36,7 +38,7 @@ class MessageDeleted(EventBuilder):
print('Message', msg_id, 'was deleted in', event.chat_id)
"""
@classmethod
def build(cls, update, others=None, self_id=None):
def build(cls, update: types.TypeUpdate, others=None, self_id=None):
if isinstance(update, types.UpdateDeleteMessages):
return cls.Event(
deleted_ids=update.messages,
@ -49,7 +51,9 @@ class MessageDeleted(EventBuilder):
)
class Event(EventCommon):
def __init__(self, deleted_ids, peer):
def __init__(self,
deleted_ids: Optional[Sequence[int]],
peer: Optional[types.TypePeer]):
super().__init__(
chat_peer=peer, msg_id=(deleted_ids or [0])[0]
)

View File

@ -43,7 +43,7 @@ class MessageEdited(NewMessage):
print('Message', event.id, 'changed at', event.date)
"""
@classmethod
def build(cls, update, others=None, self_id=None):
def build(cls, update: types.TypeUpdate, others=None, self_id=None):
if isinstance(update, (types.UpdateEditMessage,
types.UpdateEditChannelMessage)):
return cls.Event(update.message)

View File

@ -1,5 +1,7 @@
from typing import Optional, Sequence, Callable
from .common import EventBuilder, EventCommon, name_inner_event
from .. import utils
from .. import utils, hints
from ..tl import types
@ -29,13 +31,16 @@ class MessageRead(EventBuilder):
# Log when you read message in a chat (from your "inbox")
print('You have read messages until', event.max_id)
"""
def __init__(
self, chats=None, *, blacklist_chats=False, func=None, inbox=False):
def __init__(self,
chats: Optional[Sequence[hints.Entity]] = None, *,
blacklist_chats: bool = False,
func: Optional[Callable[['MessageRead.Event'], None]] = None,
inbox: bool = False):
super().__init__(chats, blacklist_chats=blacklist_chats, func=func)
self.inbox = inbox
@classmethod
def build(cls, update, others=None, self_id=None):
def build(cls, update: types.TypeUpdate, others=None, self_id=None):
if isinstance(update, types.UpdateReadHistoryInbox):
return cls.Event(update.peer, update.max_id, False)
elif isinstance(update, types.UpdateReadHistoryOutbox):
@ -54,7 +59,7 @@ class MessageRead(EventBuilder):
message_ids=update.messages,
contents=True)
def filter(self, event):
def filter(self, event: 'MessageRead.Event'):
if self.inbox == event.outbox:
return
@ -77,8 +82,12 @@ class MessageRead(EventBuilder):
This will be the case when e.g. you play a voice note.
It may only be set on ``inbox`` events.
"""
def __init__(self, peer=None, max_id=None, out=False, contents=False,
message_ids=None):
def __init__(self,
peer: Optional[types.TypePeer] = None,
max_id: Optional[int] = None,
out: bool = False,
contents: bool = False,
message_ids: Optional[Sequence[int]] = None):
self.outbox = out
self.contents = contents
self._message_ids = message_ids or []

View File

@ -61,7 +61,7 @@ class NewMessage(EventBuilder):
func: Optional[Callable[['NewMessage.Event'], None]] = None,
incoming: Optional[bool] = None, outgoing: Optional[bool] = None,
from_users: Optional[hints.Entity] = None, forwards: Optional[bool] = None,
pattern: Union[str, Callable, Pattern, Optional] = None):
pattern: Union[str, Callable, Pattern, None] = None):
if incoming and outgoing:
incoming = outgoing = None # Same as no filter
elif incoming is not None and outgoing is None:
@ -170,7 +170,7 @@ class NewMessage(EventBuilder):
return super().filter(event)
class Event(EventCommon, types.TypeMessage):
class Event(EventCommon):
"""
Represents the event of a new message. This event can be treated
to all effects as a `Message <telethon.tl.custom.message.Message>`,

View File

@ -0,0 +1,6 @@
from telethon.events.common import EventCommon
from telethon.tl.custom import Message
class NewMessage:
class Event(EventCommon, Message): ...

View File

@ -1,8 +1,12 @@
import abc
from typing import Optional, TYPE_CHECKING
from ... import errors, utils
from ...tl import types
if TYPE_CHECKING:
from ... import hints
class ChatGetter(abc.ABC):
"""
@ -10,7 +14,12 @@ class ChatGetter(abc.ABC):
and `chat_id` properties and `get_chat` and `get_input_chat`
methods.
"""
def __init__(self, chat_peer=None, *, input_chat=None, chat=None, broadcast=None):
def __init__(self,
chat_peer: Optional[types.TypePeer] = None,
*,
input_chat: Optional[types.Chat] = None,
chat: Optional[types.Chat] = None,
broadcast: Optional[bool] = None):
self._chat_peer = chat_peer
self._input_chat = input_chat
self._chat = chat
@ -18,7 +27,7 @@ class ChatGetter(abc.ABC):
self._client = None
@property
def chat(self):
def chat(self) -> Optional['hints.Entity']:
"""
Returns the :tl:`User`, :tl:`Chat` or :tl:`Channel` where this object
belongs to. It may be `None` if Telegram didn't send the chat.
@ -32,7 +41,7 @@ class ChatGetter(abc.ABC):
"""
return self._chat
async def get_chat(self):
async def get_chat(self) -> Optional['hints.Entity']:
"""
Returns `chat`, but will make an API call to find the
chat unless it's already cached.
@ -53,7 +62,7 @@ class ChatGetter(abc.ABC):
return self._chat
@property
def input_chat(self):
def input_chat(self) -> Optional[types.TypeInputPeer]:
"""
This :tl:`InputPeer` is the input version of the chat where the
message was sent. Similarly to `input_sender
@ -72,7 +81,7 @@ class ChatGetter(abc.ABC):
return self._input_chat
async def get_input_chat(self):
async def get_input_chat(self) -> Optional[types.TypeInputPeer]:
"""
Returns `input_chat`, but will make an API call to find the
input chat unless it's already cached.
@ -92,7 +101,7 @@ class ChatGetter(abc.ABC):
return self._input_chat
@property
def chat_id(self):
def chat_id(self) -> Optional[int]:
"""
Returns the marked chat integer ID. Note that this value **will
be different** from ``peer_id`` for incoming private messages, since
@ -107,7 +116,7 @@ class ChatGetter(abc.ABC):
return utils.get_peer_id(self._chat_peer) if self._chat_peer else None
@property
def is_private(self):
def is_private(self) -> Optional[bool]:
"""
`True` if the message was sent as a private message.
@ -117,7 +126,7 @@ class ChatGetter(abc.ABC):
return isinstance(self._chat_peer, types.PeerUser) if self._chat_peer else None
@property
def is_group(self):
def is_group(self) -> Optional[bool]:
"""
True if the message was sent on a group or megagroup.
@ -137,7 +146,7 @@ class ChatGetter(abc.ABC):
return isinstance(self._chat_peer, types.PeerChat)
@property
def is_channel(self):
def is_channel(self) -> bool:
"""`True` if the message was sent on a megagroup or channel."""
# The only case where chat peer could be none is in MessageDeleted,
# however those always have the peer in channels.