From 6626fbcd0cc2c63034429b38a87098e843acd302 Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Mon, 19 Aug 2024 12:21:50 +0500 Subject: [PATCH] Fix async filter call in combinators --- .../client/events/filters/combinators.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/client/src/telethon/_impl/client/events/filters/combinators.py b/client/src/telethon/_impl/client/events/filters/combinators.py index 10b59187..96b8a188 100644 --- a/client/src/telethon/_impl/client/events/filters/combinators.py +++ b/client/src/telethon/_impl/client/events/filters/combinators.py @@ -1,6 +1,7 @@ import abc import typing from collections.abc import Callable +from inspect import isawaitable from typing import Awaitable, TypeAlias from ..event import Event @@ -41,7 +42,7 @@ class Combinable(abc.ABC): return self.filter if isinstance(self, Not) else Not(self) # type: ignore [return-value] @abc.abstractmethod - def __call__(self, event: Event) -> bool: + async def __call__(self, event: Event) -> bool: pass @@ -81,8 +82,10 @@ class Any(Combinable): """ return self._filters - def __call__(self, event: Event) -> bool: - return any(f(event) for f in self._filters) + async def __call__(self, event: Event) -> bool: + return any( + [await r if isawaitable(r := f(event)) else r for f in self._filters] + ) class All(Combinable): @@ -121,8 +124,10 @@ class All(Combinable): """ return self._filters - def __call__(self, event: Event) -> bool: - return all(f(event) for f in self._filters) + async def __call__(self, event: Event) -> bool: + return all( + [await r if isawaitable(r := f(event)) else r for f in self._filters] + ) class Not(Combinable): @@ -159,5 +164,5 @@ class Not(Combinable): """ return self._filter - def __call__(self, event: Event) -> bool: - return not self._filter(event) + async def __call__(self, event: Event) -> bool: + return not (await r if isawaitable(r := self._filter(event)) else r)