Implement update dispatching

This commit is contained in:
Lonami Exo 2023-09-03 18:47:47 +02:00
parent 8727b87130
commit 4ef3e63a88
15 changed files with 570 additions and 46 deletions

View File

@ -4,12 +4,17 @@ from collections import deque
from pathlib import Path
from types import TracebackType
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Deque,
Dict,
List,
Literal,
Optional,
Self,
Tuple,
Type,
TypeVar,
Union,
@ -18,8 +23,11 @@ from typing import (
from ...mtsender import Sender
from ...session import ChatHashCache, MessageBox, PackedChat, Session
from ...tl import Request, abcs
from ..events import Event
from ..events.filters import Filter
from ..types import (
AsyncList,
Chat,
ChatLike,
File,
InFileLike,
@ -87,11 +95,10 @@ from .net import (
)
from .updates import (
add_event_handler,
catch_up,
list_event_handlers,
get_handler_filter,
on,
remove_event_handler,
set_receive_updates,
set_handler_filter,
)
from .users import (
get_entity,
@ -105,6 +112,7 @@ from .users import (
)
Return = TypeVar("Return")
T = TypeVar("T")
class Client:
@ -115,9 +123,15 @@ class Client:
self._config = config
self._message_box = MessageBox()
self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn = None
self._updates: Deque[abcs.Update] = deque(maxlen=config.update_queue_limit)
self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[
Tuple[abcs.Update, Dict[int, Union[abcs.User, abcs.Chat]]]
] = asyncio.Queue(maxsize=config.update_queue_limit or 0)
self._dispatcher: Optional[asyncio.Task[None]] = None
self._downloader_map = object()
self._handlers: Dict[
Type[Event], List[Tuple[Callable[[Any], Awaitable[Any]], Optional[Filter]]]
] = {}
if self_user := config.session.user:
self._dc_id = self_user.dc
@ -127,8 +141,13 @@ class Client:
def action(self) -> None:
action(self)
def add_event_handler(self) -> None:
add_event_handler(self)
def add_event_handler(
self,
handler: Callable[[Event], Awaitable[Any]],
event_cls: Type[Event],
filter: Optional[Filter] = None,
) -> None:
add_event_handler(self, handler, event_cls, filter)
async def bot_sign_in(self, token: str) -> User:
return await bot_sign_in(self, token)
@ -136,9 +155,6 @@ class Client:
def build_reply_markup(self) -> None:
build_reply_markup(self)
async def catch_up(self) -> None:
await catch_up(self)
async def check_password(
self, token: PasswordToken, password: Union[str, bytes]
) -> User:
@ -213,6 +229,11 @@ class Client:
async def get_entity(self) -> None:
await get_entity(self)
def get_handler_filter(
self, handler: Callable[[Event], Awaitable[Any]]
) -> Optional[Filter]:
return get_handler_filter(self, handler)
async def get_input_entity(self) -> None:
await get_input_entity(self)
@ -283,17 +304,18 @@ class Client:
async def kick_participant(self) -> None:
await kick_participant(self)
def list_event_handlers(self) -> None:
list_event_handlers(self)
def on(self) -> None:
on(self)
def on(
self, event_cls: Type[Event], filter: Optional[Filter] = None
) -> Callable[
[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]
]:
return on(self, event_cls, filter)
async def pin_message(self, chat: ChatLike, message_id: int) -> Message:
return await pin_message(self, chat, message_id)
def remove_event_handler(self) -> None:
remove_event_handler(self)
def remove_event_handler(self, handler: Callable[[Event], Awaitable[Any]]) -> None:
remove_event_handler(self, handler)
async def request_login_code(self, phone: str) -> LoginToken:
return await request_login_code(self, phone)
@ -524,8 +546,12 @@ class Client:
supports_streaming=supports_streaming,
)
async def set_receive_updates(self) -> None:
await set_receive_updates(self)
def set_handler_filter(
self,
handler: Callable[[Event], Awaitable[Any]],
filter: Optional[Filter] = None,
) -> None:
set_handler_filter(self, handler, filter)
async def sign_in(self, token: LoginToken, code: str) -> Union[User, PasswordToken]:
return await sign_in(self, token, code)

View File

@ -13,6 +13,7 @@ from ...mtsender import connect as connect_without_auth
from ...mtsender import connect_with_auth
from ...session import DataCenter, Session
from ...tl import LAYER, Request, functions
from .updates import dispatcher, process_socket_updates
if TYPE_CHECKING:
from .client import Client
@ -118,6 +119,9 @@ async def connect_sender(dc_id: int, config: Config) -> Sender:
async def connect(self: Client) -> None:
if self._sender:
return
self._sender = await connect_sender(self._dc_id, self._config)
if self._message_box.is_empty() and self._config.session.user:
@ -129,11 +133,18 @@ async def connect(self: Client) -> None:
except Exception as e:
pass
self._dispatcher = asyncio.create_task(dispatcher(self))
async def disconnect(self: Client) -> None:
if not self._sender:
return
assert self._dispatcher
self._dispatcher.cancel()
await self._dispatcher
self._dispatcher = None
await self._sender.disconnect()
self._sender = None
@ -181,7 +192,7 @@ async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> Non
else:
async with lock:
updates = await sender.step()
# client._process_socket_updates(updates)
process_socket_updates(client, updates)
async def run_until_disconnected(self: Client) -> None:

View File

@ -1,36 +1,143 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import asyncio
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
)
from ...session import Gap
from ...tl import abcs
from ..events import Event as EventBase
from ..events.filters import Filter
if TYPE_CHECKING:
from .client import Client
Event = TypeVar("Event", bound=EventBase)
async def set_receive_updates(self: Client) -> None:
self
raise NotImplementedError
UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN = 300
def on(self: Client) -> None:
self
raise NotImplementedError
def on(
self: Client, event_cls: Type[Event], filter: Optional[Filter] = None
) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]:
def wrapper(
handler: Callable[[Event], Awaitable[Any]]
) -> Callable[[Event], Awaitable[Any]]:
add_event_handler(self, handler, event_cls, filter)
return handler
return wrapper
def add_event_handler(self: Client) -> None:
self
raise NotImplementedError
def add_event_handler(
self: Client,
handler: Callable[[Event], Awaitable[Any]],
event_cls: Type[Event],
filter: Optional[Filter] = None,
) -> None:
self._handlers.setdefault(event_cls, []).append((handler, filter))
def remove_event_handler(self: Client) -> None:
self
raise NotImplementedError
def remove_event_handler(
self: Client, handler: Callable[[Event], Awaitable[Any]]
) -> None:
for event_cls, handlers in tuple(self._handlers.items()):
for i in reversed(range(len(handlers))):
if handlers[i][0] == handler:
handlers.pop(i)
if not handlers:
del self._handlers[event_cls]
def list_event_handlers(self: Client) -> None:
self
raise NotImplementedError
def get_handler_filter(
self: Client, handler: Callable[[Event], Awaitable[Any]]
) -> Optional[Filter]:
for handlers in self._handlers.values():
for h, f in handlers:
if h == handler:
return f
return None
async def catch_up(self: Client) -> None:
self
raise NotImplementedError
def set_handler_filter(
self: Client,
handler: Callable[[Event], Awaitable[Any]],
filter: Optional[Filter] = None,
) -> None:
for handlers in self._handlers.values():
for i, (h, _) in enumerate(handlers):
if h == handler:
handlers[i] = (h, filter)
def process_socket_updates(client: Client, all_updates: List[abcs.Updates]) -> None:
if not all_updates:
return
for updates in all_updates:
try:
client._message_box.ensure_known_peer_hashes(updates, client._chat_hashes)
except Gap:
return
try:
result, users, chats = client._message_box.process_updates(
updates, client._chat_hashes
)
except Gap:
return
extend_update_queue(client, result, users, chats)
def extend_update_queue(
client: Client,
updates: List[abcs.Update],
users: List[abcs.User],
chats: List[abcs.Chat],
) -> None:
entities: Dict[int, Union[abcs.User, abcs.Chat]] = {
getattr(u, "id", None) or 0: u for u in users
}
entities.update({getattr(c, "id", None) or 0: c for c in chats})
for update in updates:
try:
client._updates.put_nowait((update, entities))
except asyncio.QueueFull:
now = asyncio.get_running_loop().time()
if client._last_update_limit_warn is None or (
now - client._last_update_limit_warn
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
):
# TODO warn
client._last_update_limit_warn = now
break
async def dispatcher(client: Client) -> None:
while client.connected:
update, entities = await client._updates.get()
for event_cls, handlers in client._handlers.items():
if event := event_cls._try_from_update(client, update):
for handler, filter in handlers:
if not filter or filter(event):
try:
await handler(event)
except asyncio.CancelledError:
raise
except Exception:
# TODO proper logger
name = getattr(handler, "__name__", repr(handler))
logging.exception("Unhandled exception on %s", name)

View File

@ -55,8 +55,8 @@ async def resolve_to_packed(self: Client, chat: ChatLike) -> PackedChat:
raise ValueError("Cannot resolve chat")
return PackedChat(
ty=PackedType.BOT if self._config.session.user.bot else PackedType.USER,
id=self._config.session.user.id,
access_hash=0,
id=self._chat_hashes.self_id,
access_hash=0, # TODO get hash
)
elif isinstance(chat, types.InputPeerChat):
return PackedChat(
@ -94,11 +94,7 @@ def input_to_peer(
elif isinstance(input, types.InputPeerEmpty):
return None
elif isinstance(input, types.InputPeerSelf):
return (
types.PeerUser(user_id=client._config.session.user.id)
if client._config.session.user
else None
)
return types.PeerUser(user_id=client._chat_hashes.self_id)
elif isinstance(input, types.InputPeerChat):
return types.PeerChat(chat_id=input.chat_id)
elif isinstance(input, types.InputPeerUser):

View File

@ -0,0 +1,13 @@
from .event import Event
from .messages import MessageDeleted, MessageEdited, MessageRead, NewMessage
from .queries import CallbackQuery, InlineQuery
__all__ = [
"Event",
"MessageDeleted",
"MessageEdited",
"MessageRead",
"NewMessage",
"CallbackQuery",
"InlineQuery",
]

View File

@ -0,0 +1,17 @@
from __future__ import annotations
import abc
from typing import TYPE_CHECKING, Optional, Self
from ...tl import abcs
from ..types.meta import NoPublicConstructor
if TYPE_CHECKING:
from ..client.client import Client
class Event(metaclass=NoPublicConstructor):
@classmethod
@abc.abstractmethod
def _try_from_update(cls, client: Client, update: abcs.Update) -> Optional[Self]:
pass

View File

@ -0,0 +1,18 @@
from .combinators import All, Any, Not
from .common import Chats, Filter, Senders
from .messages import Command, Forward, Incoming, Outgoing, Reply, Text
__all__ = [
"All",
"Any",
"Not",
"Chats",
"Filter",
"Senders",
"Command",
"Forward",
"Incoming",
"Outgoing",
"Reply",
"Text",
]

View File

@ -0,0 +1,68 @@
from typing import Tuple
from ..event import Event
from .common import Filter
class Any:
"""
Combine multiple filters, returning `True` if any of the filters pass.
"""
__slots__ = ("_filters",)
def __init__(self, filter1: Filter, filter2: Filter, *filters: Filter) -> None:
self._filters = (filter1, filter2, *filters)
@property
def filters(self) -> Tuple[Filter, ...]:
"""
The filters being checked, in order.
"""
return self._filters
def __call__(self, event: Event) -> bool:
return any(f(event) for f in self._filters)
class All:
"""
Combine multiple filters, returning `True` if all of the filters pass.
"""
__slots__ = ("_filters",)
def __init__(self, filter1: Filter, filter2: Filter, *filters: Filter) -> None:
self._filters = (filter1, filter2, *filters)
@property
def filters(self) -> Tuple[Filter, ...]:
"""
The filters being checked, in order.
"""
return self._filters
def __call__(self, event: Event) -> bool:
return all(f(event) for f in self._filters)
class Not:
"""
Negate the output of a single filter, returning `True` if the nested
filter does *not* pass.
"""
__slots__ = ("_filter",)
def __init__(self, filter: Filter) -> None:
self._filter = filter
@property
def filter(self) -> Filter:
"""
The filters being negated.
"""
return self._filter
def __call__(self, event: Event) -> bool:
return not self._filter(event)

View File

@ -0,0 +1,53 @@
from typing import Callable, Sequence, Tuple, Union
from ..event import Event
Filter = Callable[[Event], bool]
class Chats:
"""
Filter by `event.chat.id`.
"""
__slots__ = ("_chats",)
def __init__(self, chat_id: Union[int, Sequence[int]], *chat_ids: int) -> None:
self._chats = {chat_id} if isinstance(chat_id, int) else set(chat_id)
self._chats.update(chat_ids)
@property
def chat_ids(self) -> Tuple[int, ...]:
"""
The chat identifiers this filter is filtering on.
"""
return tuple(self._chats)
def __call__(self, event: Event) -> bool:
chat = getattr(event, "chat", None)
id = getattr(chat, "id", None)
return id in self._chats
class Senders:
"""
Filter by `event.sender.id`.
"""
__slots__ = ("_senders",)
def __init__(self, sender_id: Union[int, Sequence[int]], *sender_ids: int) -> None:
self._senders = {sender_id} if isinstance(sender_id, int) else set(sender_id)
self._senders.update(sender_ids)
@property
def sender_ids(self) -> Tuple[int, ...]:
"""
The sender identifiers this filter is filtering on.
"""
return tuple(self._senders)
def __call__(self, event: Event) -> bool:
sender = getattr(event, "sender", None)
id = getattr(sender, "id", None)
return id in self._senders

View File

@ -0,0 +1,99 @@
import re
from typing import Union
from ..event import Event
class Text:
"""
Filter by `event.text` using a *regular expression* pattern.
The pattern is searched on the text anywhere, not matched at the start.
Use the `'^'` anchor if you want to match the text from the start.
The match, if any, is discarded. If you need to access captured groups,
you need to manually perform the check inside the handler instead.
"""
__slots__ = ("_pattern",)
def __init__(self, regexp: Union[str, re.Pattern[str]]) -> None:
self._pattern = re.compile(regexp) if isinstance(regexp, str) else regexp
def __call__(self, event: Event) -> bool:
text = getattr(event, "text", None)
return re.search(self._pattern, text) is not None if text is not None else False
class Command:
"""
Filter by `event.text` to make sure the first word matches the command or
the command + '@' + username, using the username of the logged-in account.
For example, if the logged-in account has an username of "bot", then the
filter `Command('/help')` will match both "/help" and "/help@bot", but not
"/list" or "/help@other".
Note that the leading forward-slash is not automatically added,
which allows for using a different prefix or no prefix at all.
"""
__slots__ = ("_cmd",)
def __init__(self, command: str) -> None:
self._cmd = command
def __call__(self, event: Event) -> bool:
raise NotImplementedError
class Incoming:
"""
Filter by `event.incoming`, that is, messages sent from others to the
logged-in account.
This is not a reliable way to check that the update was not produced by
the logged-in account.
"""
__slots__ = ()
def __call__(self, event: Event) -> bool:
return getattr(event, "incoming", False)
class Outgoing:
"""
Filter by `event.outgoing`, that is, messages sent from others to the
logged-in account.
This is not a reliable way to check that the update was not produced by
the logged-in account.
"""
__slots__ = ()
def __call__(self, event: Event) -> bool:
return getattr(event, "outgoing", False)
class Forward:
"""
Filter by `event.forward`.
"""
__slots__ = ()
def __call__(self, event: Event) -> bool:
return getattr(event, "forward", None) is not None
class Reply:
"""
Filter by `event.reply`.
"""
__slots__ = ()
def __call__(self, event: Event) -> bool:
return getattr(event, "reply", None) is not None

View File

@ -0,0 +1,46 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Self
from ...session.message_box.adaptor import (
update_short_chat_message,
update_short_message,
)
from ...tl import abcs, types
from ..types import Message
from .event import Event
if TYPE_CHECKING:
from ..client.client import Client
class NewMessage(Event, Message):
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update) -> Optional[Self]:
if isinstance(update, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
if isinstance(update.message, types.Message):
return cls._from_raw(update.message)
elif isinstance(
update, (types.UpdateShortMessage, types.UpdateShortChatMessage)
):
raise RuntimeError("should have been handled by adaptor")
return None
class MessageEdited(Event):
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update) -> Optional[Self]:
raise NotImplementedError()
class MessageDeleted(Event):
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update) -> Optional[Self]:
raise NotImplementedError()
class MessageRead(Event):
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update) -> Optional[Self]:
raise NotImplementedError()

View File

@ -0,0 +1,21 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Self
from ...tl import abcs
from .event import Event
if TYPE_CHECKING:
from ..client.client import Client
class CallbackQuery(Event):
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update) -> Optional[Self]:
raise NotImplementedError()
class InlineQuery(Event):
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update) -> Optional[Self]:
raise NotImplementedError()

View File

@ -240,6 +240,7 @@ class MessageBox:
or pts_info_from_update(updates.update) is not None
)
if can_recover:
self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash")
raise Gap
# https://core.telegram.org/api/updates

View File

@ -0,0 +1,19 @@
from .._impl.client.events import (
CallbackQuery,
Event,
InlineQuery,
MessageDeleted,
MessageEdited,
MessageRead,
NewMessage,
)
__all__ = [
"CallbackQuery",
"Event",
"InlineQuery",
"MessageDeleted",
"MessageEdited",
"MessageRead",
"NewMessage",
]

View File

@ -0,0 +1,29 @@
from .._impl.client.events.filters import (
All,
Any,
Chats,
Command,
Filter,
Forward,
Incoming,
Not,
Outgoing,
Reply,
Senders,
Text,
)
__all__ = [
"All",
"Any",
"Chats",
"Command",
"Filter",
"Forward",
"Incoming",
"Not",
"Outgoing",
"Reply",
"Senders",
"Text",
]