Enable awaitable event builder func

This commit is contained in:
Kirill Sorokin 2019-08-12 09:54:52 +03:00
parent e24dd3ad75
commit c3dee9697f
7 changed files with 23 additions and 16 deletions

View File

@ -415,7 +415,7 @@ class UpdateMethods:
if not builder.resolved:
await builder.resolve(self)
if not builder.filter(event):
if not await builder.filter(event):
continue
try:

View File

@ -74,7 +74,7 @@ class CallbackQuery(EventBuilder):
peer = types.PeerChannel(-pid) if pid < 0 else types.PeerUser(pid)
return cls.Event(update, peer, mid)
def filter(self, event):
async def filter(self, event):
# We can't call super().filter(...) because it ignores chat_instance
if self._no_check:
return event
@ -95,7 +95,7 @@ class CallbackQuery(EventBuilder):
elif event.query.data != self.match:
return
if not self.func or self.func(event):
if await self._check_func(event):
return event
class Event(EventCommon, SenderGetter):
@ -110,7 +110,7 @@ class CallbackQuery(EventBuilder):
The object returned by the ``data=`` parameter
when creating the event builder, if any. Similar
to ``pattern_match`` for the new message event.
pattern_match (`obj`, optional):
Alias for ``data_match``.
"""

View File

@ -2,6 +2,7 @@ import abc
import asyncio
import itertools
import warnings
import inspect
from .. import utils
from ..tl import TLObject, types, functions
@ -103,7 +104,7 @@ class EventBuilder(abc.ABC):
async def _resolve(self, client):
self.chats = await _into_id_set(client, self.chats)
def filter(self, event):
async def filter(self, event):
"""
If the ID of ``event._chat_peer`` isn't in the chats set (or it is
but the set is a blacklist) returns `None`, otherwise the event.
@ -121,9 +122,15 @@ class EventBuilder(abc.ABC):
# If it doesn't match but it's a whitelist ignore.
return None
if not self.func or self.func(event):
if await self._check_func(event):
return event
async def _check_func(self, event):
if not self.func:
return True
r = self.func(event)
return await r if inspect.isawaitable(r) else r
class EventCommon(ChatGetter, abc.ABC):
"""

View File

@ -50,14 +50,14 @@ class InlineQuery(EventBuilder):
if isinstance(update, types.UpdateBotInlineQuery):
return cls.Event(update)
def filter(self, event):
async def filter(self, event):
if self.pattern:
match = self.pattern(event.text)
if not match:
return
event.pattern_match = match
return super().filter(event)
return await super().filter(event)
class Event(EventCommon, SenderGetter):
"""
@ -156,9 +156,9 @@ class InlineQuery(EventBuilder):
gallery (`bool`, optional):
Whether the results should show as a gallery (grid) or not.
next_offset (`str`, optional):
The offset the client will send when the user scrolls the
The offset the client will send when the user scrolls the
results and it repeats the request.
private (`bool`, optional):

View File

@ -39,11 +39,11 @@ class MessageRead(EventBuilder):
message_ids=update.messages,
contents=True)
def filter(self, event):
async def filter(self, event):
if self.inbox == event.outbox:
return
return super().filter(event)
return await super().filter(event)
class Event(EventCommon):
"""

View File

@ -129,7 +129,7 @@ class NewMessage(EventBuilder):
return event
def filter(self, event):
async def filter(self, event):
if self._no_check:
return event
@ -151,7 +151,7 @@ class NewMessage(EventBuilder):
return
event.pattern_match = match
return super().filter(event)
return await super().filter(event)
class Event(EventCommon):
"""

View File

@ -35,7 +35,7 @@ class Raw(EventBuilder):
def build(cls, update, others=None, self_id=None):
return update
def filter(self, event):
async def filter(self, event):
if ((not self.types or isinstance(event, self.types))
and (not self.func or self.func(event))):
and await self._check_func(event)):
return event