Fix async filter call in combinators

This commit is contained in:
Jahongir Qurbonov 2024-08-19 12:21:50 +05:00
parent eb7ed5dd31
commit 6626fbcd0c

View File

@ -1,6 +1,7 @@
import abc import abc
import typing import typing
from collections.abc import Callable from collections.abc import Callable
from inspect import isawaitable
from typing import Awaitable, TypeAlias from typing import Awaitable, TypeAlias
from ..event import Event from ..event import Event
@ -41,7 +42,7 @@ class Combinable(abc.ABC):
return self.filter if isinstance(self, Not) else Not(self) # type: ignore [return-value] return self.filter if isinstance(self, Not) else Not(self) # type: ignore [return-value]
@abc.abstractmethod @abc.abstractmethod
def __call__(self, event: Event) -> bool: async def __call__(self, event: Event) -> bool:
pass pass
@ -81,8 +82,10 @@ class Any(Combinable):
""" """
return self._filters return self._filters
def __call__(self, event: Event) -> bool: async def __call__(self, event: Event) -> bool:
return any(f(event) for f in self._filters) return any(
[await r if isawaitable(r := f(event)) else r for f in self._filters]
)
class All(Combinable): class All(Combinable):
@ -121,8 +124,10 @@ class All(Combinable):
""" """
return self._filters return self._filters
def __call__(self, event: Event) -> bool: async def __call__(self, event: Event) -> bool:
return all(f(event) for f in self._filters) return all(
[await r if isawaitable(r := f(event)) else r for f in self._filters]
)
class Not(Combinable): class Not(Combinable):
@ -159,5 +164,5 @@ class Not(Combinable):
""" """
return self._filter return self._filter
def __call__(self, event: Event) -> bool: async def __call__(self, event: Event) -> bool:
return not self._filter(event) return not (await r if isawaitable(r := self._filter(event)) else r)