Properly resolve events

This commit is contained in:
Lonami Exo 2018-08-21 12:14:32 +02:00
parent 47190d7d55
commit d3a6822fc9
4 changed files with 36 additions and 14 deletions

View File

@ -89,7 +89,7 @@ class UpdateMethods(UserMethods):
elif not event: elif not event:
event = events.Raw() event = events.Raw()
self._loop.create_task(event.resolve(self)) event.ensure_resolve(self)
self._event_builders.append((event, callback)) self._event_builders.append((event, callback))
def remove_event_handler(self, callback, event=None): def remove_event_handler(self, callback, event=None):
@ -266,10 +266,8 @@ class UpdateMethods(UserMethods):
if not event: if not event:
continue continue
# TODO Lock until it's resolved; the task for resolving if not builder.resolved.is_set():
# was already created when adding the event handler. await builder.resolved.wait()
if not builder.resolved:
await builder.resolve()
if not builder.filter(event): if not builder.filter(event):
continue continue

View File

@ -1,4 +1,5 @@
import abc import abc
import asyncio
import warnings import warnings
from .. import utils from .. import utils
@ -57,26 +58,48 @@ class EventBuilder(abc.ABC):
def __init__(self, chats=None, blacklist_chats=False): def __init__(self, chats=None, blacklist_chats=False):
self.chats = chats self.chats = chats
self.blacklist_chats = blacklist_chats self.blacklist_chats = blacklist_chats
self.resolved = False self.resolved = None
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def build(cls, update): def build(cls, update):
"""Builds an event for the given update if possible, or returns None""" """Builds an event for the given update if possible, or returns None"""
def ensure_resolve(self, client):
"""
Sets the event loop so that self.resolved can be used.
The expected workflow is:
1. Creating the event builder.
2a. Calling `ensure_resolve`.
2b. Awaiting `resolved.wait`.
OR
2a. Awaiting `resolve`.
3. Using `filter`.
"""
if not self.resolved:
self.resolved = asyncio.Event(loop=client.loop)
client.loop.create_task(self.resolve(client))
async def resolve(self, client): async def resolve(self, client):
"""Helper method to allow event builders to be resolved before usage""" """Helper method to allow event builders to be resolved before usage"""
if not self.resolved: if not self.resolved.is_set():
self.resolved = True
self.chats = await _into_id_set(client, self.chats) self.chats = await _into_id_set(client, self.chats)
if not EventBuilder.self_id: if not EventBuilder.self_id:
EventBuilder.self_id = await client.get_peer_id('me') EventBuilder.self_id = await client.get_peer_id('me')
self.resolved.set()
def filter(self, event): def filter(self, event):
""" """
If the ID of ``event._chat_peer`` isn't in the chats set (or it is If the ID of ``event._chat_peer`` isn't in the chats set (or it is
but the set is a blacklist) returns ``None``, otherwise the event. but the set is a blacklist) returns ``None``, otherwise the event.
The events must have been resolved before this can be called.
""" """
if not self.resolved:
return None
if self.chats is not None: if self.chats is not None:
inside = utils.get_peer_id(event._chat_peer) in self.chats inside = utils.get_peer_id(event._chat_peer) in self.chats
if inside == self.blacklist_chats: if inside == self.blacklist_chats:

View File

@ -1,7 +1,8 @@
import asyncio
import re import re
from .common import EventBuilder, EventCommon, name_inner_event, _into_id_set from .common import EventBuilder, EventCommon, name_inner_event, _into_id_set
from ..tl import types, custom from ..tl import types
@name_inner_event @name_inner_event
@ -71,9 +72,9 @@ class NewMessage(EventBuilder):
)) ))
async def resolve(self, client): async def resolve(self, client):
if not self.resolved: if not self.resolved.is_set():
await super().resolve(client)
self.from_users = await _into_id_set(client, self.from_users) self.from_users = await _into_id_set(client, self.from_users)
await super().resolve(client)
@classmethod @classmethod
def build(cls, update): def build(cls, update):

View File

@ -260,6 +260,9 @@ class Conversation(ChatGetter):
if isinstance(event, type): if isinstance(event, type):
event = event() event = event()
# Since we await resolve here we don't need to await resolved.
# We know it has already been resolved, unlike when normally
# adding an event handler, for which a task is created to resolve.
await event.resolve() await event.resolve()
counter = Conversation._custom_counter counter = Conversation._custom_counter
@ -276,9 +279,6 @@ class Conversation(ChatGetter):
return await result() return await result()
async def _check_custom(self, built): async def _check_custom(self, built):
# TODO This code is quite much a copy paste of registering events
# in the client, resolving them and setting the client; perhaps
# there is a better way?
for i, (ev, fut) in self._custom.items(): for i, (ev, fut) in self._custom.items():
ev_type = type(ev) ev_type = type(ev)
if built[ev_type] and ev.filter(built[ev_type]): if built[ev_type] and ev.filter(built[ev_type]):