Allow event's func to be async (#1461)

Fixes #1344.
This commit is contained in:
JuniorJPDJ 2020-05-16 09:58:37 +02:00 committed by GitHub
parent c45f2e7c39
commit 634bc3a8bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 33 additions and 15 deletions

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import inspect
import itertools import itertools
import random import random
import time import time
@ -424,7 +425,10 @@ class UpdateMethods:
if not builder.resolved: if not builder.resolved:
await builder.resolve(self) await builder.resolve(self)
if not builder.filter(event): filter = builder.filter(event)
if inspect.isawaitable(filter):
filter = await filter
if not filter:
continue continue
try: try:

View File

@ -118,8 +118,10 @@ class CallbackQuery(EventBuilder):
elif event.query.data != self.match: elif event.query.data != self.match:
return return
if not self.func or self.func(event): if self.func:
return event # Return the result of func directly as it may need to be awaited
return self.func(event)
return True
class Event(EventCommon, SenderGetter): class Event(EventCommon, SenderGetter):
""" """

View File

@ -55,7 +55,7 @@ class EventBuilder(abc.ABC):
which will be ignored if ``blacklist_chats=True``. which will be ignored if ``blacklist_chats=True``.
func (`callable`, optional): func (`callable`, optional):
A callable function that should accept the event as input A callable (async or not) function that should accept the event as input
parameter, and return a value indicating whether the event parameter, and return a value indicating whether the event
should be dispatched or not (any truthy value will do, it should be dispatched or not (any truthy value will do, it
does not need to be a `bool`). It works like a custom filter: does not need to be a `bool`). It works like a custom filter:
@ -105,13 +105,13 @@ class EventBuilder(abc.ABC):
def filter(self, event): def filter(self, event):
""" """
If the ID of ``event._chat_peer`` isn't in the chats set (or it is Returns a truthy value if the event passed the filter and should be
but the set is a blacklist) returns `None`, otherwise the event. used, or falsy otherwise. The return value may need to be awaited.
The events must have been resolved before this can be called. The events must have been resolved before this can be called.
""" """
if not self.resolved: if not self.resolved:
return None return
if self.chats is not None: if self.chats is not None:
# Note: the `event.chat_id` property checks if it's `None` for us # Note: the `event.chat_id` property checks if it's `None` for us
@ -119,10 +119,13 @@ class EventBuilder(abc.ABC):
if inside == self.blacklist_chats: if inside == self.blacklist_chats:
# If this chat matches but it's a blacklist ignore. # If this chat matches but it's a blacklist ignore.
# If it doesn't match but it's a whitelist ignore. # If it doesn't match but it's a whitelist ignore.
return None return
if not self.func or self.func(event): if not self.func:
return event return True
# Return the result of func directly as it may need to be awaited
return self.func(event)
class EventCommon(ChatGetter, abc.ABC): class EventCommon(ChatGetter, abc.ABC):

View File

@ -46,6 +46,8 @@ class Raw(EventBuilder):
return update return update
def filter(self, event): def filter(self, event):
if ((not self.types or isinstance(event, self.types)) if not self.types or isinstance(event, self.types):
and (not self.func or self.func(event))): if self.func:
# Return the result of func directly as it may need to be awaited
return self.func(event)
return event return event

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import functools import functools
import inspect
import itertools import itertools
import time import time
@ -312,9 +313,15 @@ class Conversation(ChatGetter):
for key, (ev, fut) in list(self._custom.items()): for key, (ev, fut) in list(self._custom.items()):
ev_type = type(ev) ev_type = type(ev)
inst = built[ev_type] inst = built[ev_type]
if inst and ev.filter(inst):
fut.set_result(inst) if inst:
del self._custom[key] filter = ev.filter(inst)
if inspect.isawaitable(filter):
filter = await filter
if filter:
fut.set_result(inst)
del self._custom[key]
def _on_new_message(self, response): def _on_new_message(self, response):
response = response.message response = response.message