Add async function support for filtering events

This commit is contained in:
JuniorJPDJ 2020-05-15 04:21:23 +02:00
parent 393da7e57a
commit 57bb64e966
3 changed files with 21 additions and 10 deletions

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import inspect
import itertools import itertools
import random import random
import time import time
@ -424,7 +425,9 @@ class UpdateMethods:
if not builder.resolved: if not builder.resolved:
await builder.resolve(self) await builder.resolve(self)
if not builder.filter(event): filter = builder.filter(event)
filter = (await filter) if inspect.isawaitable(filter) else filter
if not filter:
continue continue
try: try:

View File

@ -55,7 +55,7 @@ class EventBuilder(abc.ABC):
which will be ignored if ``blacklist_chats=True``. which will be ignored if ``blacklist_chats=True``.
func (`callable`, optional): func (`callable`, optional):
A callable function that should accept the event as input A callable (async or not) function that should accept the event as input
parameter, and return a value indicating whether the event parameter, and return a value indicating whether the event
should be dispatched or not (any truthy value will do, it should be dispatched or not (any truthy value will do, it
does not need to be a `bool`). It works like a custom filter: does not need to be a `bool`). It works like a custom filter:
@ -106,12 +106,13 @@ class EventBuilder(abc.ABC):
def filter(self, event): def filter(self, event):
""" """
If the ID of ``event._chat_peer`` isn't in the chats set (or it is If the ID of ``event._chat_peer`` isn't in the chats set (or it is
but the set is a blacklist) returns `None`, otherwise the event. but the set is a blacklist) returns `True`, otherwise `False`.
May also return awaitable which awaits to bool-able value.
The events must have been resolved before this can be called. The events must have been resolved before this can be called.
""" """
if not self.resolved: if not self.resolved:
return None return
if self.chats is not None: if self.chats is not None:
# Note: the `event.chat_id` property checks if it's `None` for us # Note: the `event.chat_id` property checks if it's `None` for us
@ -119,10 +120,12 @@ class EventBuilder(abc.ABC):
if inside == self.blacklist_chats: if inside == self.blacklist_chats:
# If this chat matches but it's a blacklist ignore. # If this chat matches but it's a blacklist ignore.
# If it doesn't match but it's a whitelist ignore. # If it doesn't match but it's a whitelist ignore.
return None return
if not self.func or self.func(event): if not self.func:
return event return True
return self.func(event)
class EventCommon(ChatGetter, abc.ABC): class EventCommon(ChatGetter, abc.ABC):

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import functools import functools
import inspect
import itertools import itertools
import time import time
@ -312,9 +313,13 @@ class Conversation(ChatGetter):
for key, (ev, fut) in list(self._custom.items()): for key, (ev, fut) in list(self._custom.items()):
ev_type = type(ev) ev_type = type(ev)
inst = built[ev_type] inst = built[ev_type]
if inst and ev.filter(inst):
fut.set_result(inst) if inst:
del self._custom[key] filter = ev.filter(inst)
filter = (await filter) if inspect.isawaitable(filter) else filter
if filter:
fut.set_result(inst)
del self._custom[key]
def _on_new_message(self, response): def _on_new_message(self, response):
response = response.message response = response.message