diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 0dff677a..57cd4962 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -89,7 +89,7 @@ class UpdateMethods(UserMethods): elif not event: event = events.Raw() - self._loop.create_task(event.resolve(self)) + event.ensure_resolve(self) self._event_builders.append((event, callback)) def remove_event_handler(self, callback, event=None): @@ -266,10 +266,8 @@ class UpdateMethods(UserMethods): if not event: continue - # TODO Lock until it's resolved; the task for resolving - # was already created when adding the event handler. - if not builder.resolved: - await builder.resolve() + if not builder.resolved.is_set(): + await builder.resolved.wait() if not builder.filter(event): continue diff --git a/telethon/events/common.py b/telethon/events/common.py index e1eb2ee6..7c706e66 100644 --- a/telethon/events/common.py +++ b/telethon/events/common.py @@ -1,4 +1,5 @@ import abc +import asyncio import warnings from .. import utils @@ -57,26 +58,48 @@ class EventBuilder(abc.ABC): def __init__(self, chats=None, blacklist_chats=False): self.chats = chats self.blacklist_chats = blacklist_chats - self.resolved = False + self.resolved = None @classmethod @abc.abstractmethod def build(cls, update): """Builds an event for the given update if possible, or returns None""" + def ensure_resolve(self, client): + """ + Sets the event loop so that self.resolved can be used. + + The expected workflow is: + 1. Creating the event builder. + 2a. Calling `ensure_resolve`. + 2b. Awaiting `resolved.wait`. + OR + 2a. Awaiting `resolve`. + 3. Using `filter`. + """ + if not self.resolved: + self.resolved = asyncio.Event(loop=client.loop) + client.loop.create_task(self.resolve(client)) + async def resolve(self, client): """Helper method to allow event builders to be resolved before usage""" - if not self.resolved: - self.resolved = True + if not self.resolved.is_set(): self.chats = await _into_id_set(client, self.chats) if not EventBuilder.self_id: EventBuilder.self_id = await client.get_peer_id('me') + self.resolved.set() + def filter(self, event): """ If the ID of ``event._chat_peer`` isn't in the chats set (or it is but the set is a blacklist) returns ``None``, otherwise the event. + + The events must have been resolved before this can be called. """ + if not self.resolved: + return None + if self.chats is not None: inside = utils.get_peer_id(event._chat_peer) in self.chats if inside == self.blacklist_chats: diff --git a/telethon/events/newmessage.py b/telethon/events/newmessage.py index 902a6f73..b2ee33e1 100644 --- a/telethon/events/newmessage.py +++ b/telethon/events/newmessage.py @@ -1,7 +1,8 @@ +import asyncio import re from .common import EventBuilder, EventCommon, name_inner_event, _into_id_set -from ..tl import types, custom +from ..tl import types @name_inner_event @@ -71,9 +72,9 @@ class NewMessage(EventBuilder): )) async def resolve(self, client): - if not self.resolved: - await super().resolve(client) + if not self.resolved.is_set(): self.from_users = await _into_id_set(client, self.from_users) + await super().resolve(client) @classmethod def build(cls, update): diff --git a/telethon/tl/custom/conversation.py b/telethon/tl/custom/conversation.py index bc32633a..c6c00bf9 100644 --- a/telethon/tl/custom/conversation.py +++ b/telethon/tl/custom/conversation.py @@ -260,6 +260,9 @@ class Conversation(ChatGetter): if isinstance(event, type): event = event() + # Since we await resolve here we don't need to await resolved. + # We know it has already been resolved, unlike when normally + # adding an event handler, for which a task is created to resolve. await event.resolve() counter = Conversation._custom_counter @@ -276,9 +279,6 @@ class Conversation(ChatGetter): return await result() async def _check_custom(self, built): - # TODO This code is quite much a copy paste of registering events - # in the client, resolving them and setting the client; perhaps - # there is a better way? for i, (ev, fut) in self._custom.items(): ev_type = type(ev) if built[ev_type] and ev.filter(built[ev_type]):