diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 854860fd..7019b465 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -1,4 +1,5 @@ import asyncio +import inspect import itertools import random import time @@ -424,7 +425,9 @@ class UpdateMethods: if not builder.resolved: await builder.resolve(self) - if not builder.filter(event): + filter = builder.filter(event) + filter = (await filter) if inspect.isawaitable(filter) else filter + if not filter: continue try: diff --git a/telethon/events/common.py b/telethon/events/common.py index 42586608..fd3315f5 100644 --- a/telethon/events/common.py +++ b/telethon/events/common.py @@ -55,7 +55,7 @@ class EventBuilder(abc.ABC): which will be ignored if ``blacklist_chats=True``. func (`callable`, optional): - A callable function that should accept the event as input + A callable (async or not) function that should accept the event as input parameter, and return a value indicating whether the event should be dispatched or not (any truthy value will do, it does not need to be a `bool`). It works like a custom filter: @@ -106,12 +106,13 @@ class EventBuilder(abc.ABC): def filter(self, event): """ If the ID of ``event._chat_peer`` isn't in the chats set (or it is - but the set is a blacklist) returns `None`, otherwise the event. + but the set is a blacklist) returns `True`, otherwise `False`. + May also return awaitable which awaits to bool-able value. The events must have been resolved before this can be called. """ if not self.resolved: - return None + return if self.chats is not None: # Note: the `event.chat_id` property checks if it's `None` for us @@ -119,10 +120,12 @@ class EventBuilder(abc.ABC): if inside == self.blacklist_chats: # If this chat matches but it's a blacklist ignore. # If it doesn't match but it's a whitelist ignore. - return None + return - if not self.func or self.func(event): - return event + if not self.func: + return True + + return self.func(event) class EventCommon(ChatGetter, abc.ABC): diff --git a/telethon/tl/custom/conversation.py b/telethon/tl/custom/conversation.py index c0d59e03..663acc08 100644 --- a/telethon/tl/custom/conversation.py +++ b/telethon/tl/custom/conversation.py @@ -1,5 +1,6 @@ import asyncio import functools +import inspect import itertools import time @@ -312,9 +313,13 @@ class Conversation(ChatGetter): for key, (ev, fut) in list(self._custom.items()): ev_type = type(ev) inst = built[ev_type] - if inst and ev.filter(inst): - fut.set_result(inst) - del self._custom[key] + + if inst: + filter = ev.filter(inst) + filter = (await filter) if inspect.isawaitable(filter) else filter + if filter: + fut.set_result(inst) + del self._custom[key] def _on_new_message(self, response): response = response.message