Enable awaitable event builder func

This commit is contained in:
Kirill Sorokin 2019-08-12 09:54:52 +03:00
parent e24dd3ad75
commit c3dee9697f
7 changed files with 23 additions and 16 deletions

View File

@ -415,7 +415,7 @@ class UpdateMethods:
if not builder.resolved: if not builder.resolved:
await builder.resolve(self) await builder.resolve(self)
if not builder.filter(event): if not await builder.filter(event):
continue continue
try: try:

View File

@ -74,7 +74,7 @@ class CallbackQuery(EventBuilder):
peer = types.PeerChannel(-pid) if pid < 0 else types.PeerUser(pid) peer = types.PeerChannel(-pid) if pid < 0 else types.PeerUser(pid)
return cls.Event(update, peer, mid) return cls.Event(update, peer, mid)
def filter(self, event): async def filter(self, event):
# We can't call super().filter(...) because it ignores chat_instance # We can't call super().filter(...) because it ignores chat_instance
if self._no_check: if self._no_check:
return event return event
@ -95,7 +95,7 @@ class CallbackQuery(EventBuilder):
elif event.query.data != self.match: elif event.query.data != self.match:
return return
if not self.func or self.func(event): if await self._check_func(event):
return event return event
class Event(EventCommon, SenderGetter): class Event(EventCommon, SenderGetter):

View File

@ -2,6 +2,7 @@ import abc
import asyncio import asyncio
import itertools import itertools
import warnings import warnings
import inspect
from .. import utils from .. import utils
from ..tl import TLObject, types, functions from ..tl import TLObject, types, functions
@ -103,7 +104,7 @@ class EventBuilder(abc.ABC):
async def _resolve(self, client): async def _resolve(self, client):
self.chats = await _into_id_set(client, self.chats) self.chats = await _into_id_set(client, self.chats)
def filter(self, event): async def filter(self, event):
""" """
If the ID of ``event._chat_peer`` isn't in the chats set (or it is If the ID of ``event._chat_peer`` isn't in the chats set (or it is
but the set is a blacklist) returns `None`, otherwise the event. but the set is a blacklist) returns `None`, otherwise the event.
@ -121,9 +122,15 @@ class EventBuilder(abc.ABC):
# If it doesn't match but it's a whitelist ignore. # If it doesn't match but it's a whitelist ignore.
return None return None
if not self.func or self.func(event): if await self._check_func(event):
return event return event
async def _check_func(self, event):
if not self.func:
return True
r = self.func(event)
return await r if inspect.isawaitable(r) else r
class EventCommon(ChatGetter, abc.ABC): class EventCommon(ChatGetter, abc.ABC):
""" """

View File

@ -50,14 +50,14 @@ class InlineQuery(EventBuilder):
if isinstance(update, types.UpdateBotInlineQuery): if isinstance(update, types.UpdateBotInlineQuery):
return cls.Event(update) return cls.Event(update)
def filter(self, event): async def filter(self, event):
if self.pattern: if self.pattern:
match = self.pattern(event.text) match = self.pattern(event.text)
if not match: if not match:
return return
event.pattern_match = match event.pattern_match = match
return super().filter(event) return await super().filter(event)
class Event(EventCommon, SenderGetter): class Event(EventCommon, SenderGetter):
""" """

View File

@ -39,11 +39,11 @@ class MessageRead(EventBuilder):
message_ids=update.messages, message_ids=update.messages,
contents=True) contents=True)
def filter(self, event): async def filter(self, event):
if self.inbox == event.outbox: if self.inbox == event.outbox:
return return
return super().filter(event) return await super().filter(event)
class Event(EventCommon): class Event(EventCommon):
""" """

View File

@ -129,7 +129,7 @@ class NewMessage(EventBuilder):
return event return event
def filter(self, event): async def filter(self, event):
if self._no_check: if self._no_check:
return event return event
@ -151,7 +151,7 @@ class NewMessage(EventBuilder):
return return
event.pattern_match = match event.pattern_match = match
return super().filter(event) return await super().filter(event)
class Event(EventCommon): class Event(EventCommon):
""" """

View File

@ -35,7 +35,7 @@ class Raw(EventBuilder):
def build(cls, update, others=None, self_id=None): def build(cls, update, others=None, self_id=None):
return update return update
def filter(self, event): async def filter(self, event):
if ((not self.types or isinstance(event, self.types)) if ((not self.types or isinstance(event, self.types))
and (not self.func or self.func(event))): and await self._check_func(event)):
return event return event