Merge branch 'event-reusing'

This commit is contained in:
Lonami Exo 2018-07-11 11:31:46 +02:00
commit 4f5c6f1006
11 changed files with 86 additions and 49 deletions

View File

@ -1,5 +1,6 @@
import abc import abc
import asyncio import asyncio
import collections
import inspect import inspect
import logging import logging
import platform import platform
@ -258,6 +259,11 @@ class TelegramBaseClient(abc.ABC):
self._events_pending_resolve = [] self._events_pending_resolve = []
self._event_resolve_lock = asyncio.Lock() self._event_resolve_lock = asyncio.Lock()
# Keep track of how many event builders there are for
# each type {type: count}. If there's at least one then
# the event will be built, and the same event be reused.
self._event_builders_count = collections.defaultdict(int)
# Default parse mode # Default parse mode
self._parse_mode = markdown self._parse_mode = markdown

View File

@ -90,6 +90,7 @@ class UpdateMethods(UserMethods):
event = events.Raw() event = events.Raw()
self._events_pending_resolve.append(event) self._events_pending_resolve.append(event)
self._event_builders_count[type(event)] += 1
self._event_builders.append((event, callback)) self._event_builders.append((event, callback))
def remove_event_handler(self, callback, event=None): def remove_event_handler(self, callback, event=None):
@ -108,6 +109,11 @@ class UpdateMethods(UserMethods):
i -= 1 i -= 1
ev, cb = self._event_builders[i] ev, cb = self._event_builders[i]
if cb == callback and (not event or isinstance(ev, event)): if cb == callback and (not event or isinstance(ev, event)):
type_ev = type(ev)
self._event_builders_count[type_ev] -= 1
if not self._event_builders_count[type_ev]:
del self._event_builders_count[type_ev]
del self._event_builders[i] del self._event_builders[i]
found += 1 found += 1
@ -257,28 +263,36 @@ class UpdateMethods(UserMethods):
self._events_pending_resolve.clear() self._events_pending_resolve.clear()
for builder, callback in self._event_builders: # TODO We can improve this further
event = builder.build(update) # If we had a way to get all event builders for
if event: # a type instead looping over them all always.
if hasattr(event, '_set_client'): built = {builder: builder.build(update)
event._set_client(self) for builder in self._event_builders_count}
else:
event._client = self
event.original_update = update for builder, callback in self._event_builders:
try: event = built[type(builder)]
await callback(event) if not event or not builder.filter(event):
except events.StopPropagation: continue
__log__.debug(
"Event handler '{}' stopped chain of " if hasattr(event, '_set_client'):
"propagation for event {}." event._set_client(self)
.format(callback.__name__, else:
type(event).__name__) event._client = self
)
break event.original_update = update
except Exception: try:
__log__.exception('Unhandled exception on {}' await callback(event)
.format(callback.__name__)) except events.StopPropagation:
__log__.debug(
"Event handler '{}' stopped chain of "
"propagation for event {}."
.format(callback.__name__,
type(event).__name__)
)
break
except Exception:
__log__.exception('Unhandled exception on {}'
.format(callback.__name__))
async def _handle_auto_reconnect(self): async def _handle_auto_reconnect(self):
# Upon reconnection, we want to send getState # Upon reconnection, we want to send getState

View File

@ -40,7 +40,8 @@ class CallbackQuery(EventBuilder):
else: else:
raise TypeError('Invalid data type given') raise TypeError('Invalid data type given')
def build(self, update): @staticmethod
def build(update):
if isinstance(update, (types.UpdateBotCallbackQuery, if isinstance(update, (types.UpdateBotCallbackQuery,
types.UpdateInlineBotCallbackQuery)): types.UpdateInlineBotCallbackQuery)):
event = CallbackQuery.Event(update) event = CallbackQuery.Event(update)
@ -48,9 +49,9 @@ class CallbackQuery(EventBuilder):
return return
event._entities = update._entities event._entities = update._entities
return self._filter_event(event) return event
def _filter_event(self, event): def filter(self, event):
if self.chats is not None: if self.chats is not None:
inside = event.query.chat_instance in self.chats inside = event.query.chat_instance in self.chats
if event.chat_id: if event.chat_id:

View File

@ -8,7 +8,8 @@ class ChatAction(EventBuilder):
""" """
Represents an action in a chat (such as user joined, left, or new pin). Represents an action in a chat (such as user joined, left, or new pin).
""" """
def build(self, update): @staticmethod
def build(update):
if isinstance(update, types.UpdateChannelPinnedMessage) and update.id == 0: if isinstance(update, types.UpdateChannelPinnedMessage) and update.id == 0:
# Telegram does not always send # Telegram does not always send
# UpdateChannelPinnedMessage for new pins # UpdateChannelPinnedMessage for new pins
@ -78,7 +79,7 @@ class ChatAction(EventBuilder):
return return
event._entities = update._entities event._entities = update._entities
return self._filter_event(event) return event
class Event(EventCommon): class Event(EventCommon):
""" """

View File

@ -52,21 +52,25 @@ class EventBuilder(abc.ABC):
will be handled *except* those specified in ``chats`` will be handled *except* those specified in ``chats``
which will be ignored if ``blacklist_chats=True``. which will be ignored if ``blacklist_chats=True``.
""" """
self_id = None
def __init__(self, chats=None, blacklist_chats=False): def __init__(self, chats=None, blacklist_chats=False):
self.chats = chats self.chats = chats
self.blacklist_chats = blacklist_chats self.blacklist_chats = blacklist_chats
self._self_id = None self._self_id = None
@staticmethod
@abc.abstractmethod @abc.abstractmethod
def build(self, update): def build(update):
"""Builds an event for the given update if possible, or returns None""" """Builds an event for the given update if possible, or returns None"""
async def resolve(self, client): async def resolve(self, client):
"""Helper method to allow event builders to be resolved before usage""" """Helper method to allow event builders to be resolved before usage"""
self.chats = await _into_id_set(client, self.chats) self.chats = await _into_id_set(client, self.chats)
self._self_id = (await client.get_me(input_peer=True)).user_id if not EventBuilder.self_id:
EventBuilder.self_id = await client.get_peer_id('me')
def _filter_event(self, event): def filter(self, event):
""" """
If the ID of ``event._chat_peer`` isn't in the chats set (or it is If the ID of ``event._chat_peer`` isn't in the chats set (or it is
but the set is a blacklist) returns ``None``, otherwise the event. but the set is a blacklist) returns ``None``, otherwise the event.

View File

@ -7,7 +7,8 @@ class MessageDeleted(EventBuilder):
""" """
Event fired when one or more messages are deleted. Event fired when one or more messages are deleted.
""" """
def build(self, update): @staticmethod
def build(update):
if isinstance(update, types.UpdateDeleteMessages): if isinstance(update, types.UpdateDeleteMessages):
event = MessageDeleted.Event( event = MessageDeleted.Event(
deleted_ids=update.messages, deleted_ids=update.messages,
@ -22,7 +23,7 @@ class MessageDeleted(EventBuilder):
return return
event._entities = update._entities event._entities = update._entities
return self._filter_event(event) return event
class Event(EventCommon): class Event(EventCommon):
def __init__(self, deleted_ids, peer): def __init__(self, deleted_ids, peer):

View File

@ -8,7 +8,8 @@ class MessageEdited(NewMessage):
""" """
Event fired when a message has been edited. Event fired when a message has been edited.
""" """
def build(self, update): @staticmethod
def build(update):
if isinstance(update, (types.UpdateEditMessage, if isinstance(update, (types.UpdateEditMessage,
types.UpdateEditChannelMessage)): types.UpdateEditChannelMessage)):
event = MessageEdited.Event(update.message) event = MessageEdited.Event(update.message)
@ -16,7 +17,7 @@ class MessageEdited(NewMessage):
return return
event._entities = update._entities event._entities = update._entities
return self._message_filter_event(event) return event
class Event(NewMessage.Event): class Event(NewMessage.Event):
pass # Required if we want a different name for it pass # Required if we want a different name for it

View File

@ -18,7 +18,8 @@ class MessageRead(EventBuilder):
super().__init__(chats, blacklist_chats) super().__init__(chats, blacklist_chats)
self.inbox = inbox self.inbox = inbox
def build(self, update): @staticmethod
def build(update):
if isinstance(update, types.UpdateReadHistoryInbox): if isinstance(update, types.UpdateReadHistoryInbox):
event = MessageRead.Event(update.peer, update.max_id, False) event = MessageRead.Event(update.peer, update.max_id, False)
elif isinstance(update, types.UpdateReadHistoryOutbox): elif isinstance(update, types.UpdateReadHistoryOutbox):
@ -39,11 +40,14 @@ class MessageRead(EventBuilder):
else: else:
return return
event._entities = update._entities
return event
def filter(self, event):
if self.inbox == event.outbox: if self.inbox == event.outbox:
return return
event._entities = update._entities return super().filter(event)
return self._filter_event(event)
class Event(EventCommon): class Event(EventCommon):
""" """

View File

@ -75,7 +75,8 @@ class NewMessage(EventBuilder):
await super().resolve(client) await super().resolve(client)
self.from_users = await _into_id_set(client, self.from_users) self.from_users = await _into_id_set(client, self.from_users)
def build(self, update): @staticmethod
def build(update):
if isinstance(update, if isinstance(update,
(types.UpdateNewMessage, types.UpdateNewChannelMessage)): (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
if not isinstance(update.message, types.Message): if not isinstance(update.message, types.Message):
@ -91,9 +92,9 @@ class NewMessage(EventBuilder):
# Note that to_id/from_id complement each other in private # Note that to_id/from_id complement each other in private
# messages, depending on whether the message was outgoing. # messages, depending on whether the message was outgoing.
to_id=types.PeerUser( to_id=types.PeerUser(
update.user_id if update.out else self._self_id update.user_id if update.out else EventBuilder.self_id
), ),
from_id=self._self_id if update.out else update.user_id, from_id=EventBuilder.self_id if update.out else update.user_id,
message=update.message, message=update.message,
date=update.date, date=update.date,
fwd_from=update.fwd_from, fwd_from=update.fwd_from,
@ -120,8 +121,6 @@ class NewMessage(EventBuilder):
else: else:
return return
event._entities = update._entities
# Make messages sent to ourselves outgoing unless they're forwarded. # Make messages sent to ourselves outgoing unless they're forwarded.
# This makes it consistent with official client's appearance. # This makes it consistent with official client's appearance.
ori = event.message ori = event.message
@ -129,9 +128,10 @@ class NewMessage(EventBuilder):
if ori.from_id == ori.to_id.user_id and not ori.fwd_from: if ori.from_id == ori.to_id.user_id and not ori.fwd_from:
event.message.out = True event.message.out = True
return self._message_filter_event(event) event._entities = update._entities
return event
def _message_filter_event(self, event): def filter(self, event):
if self._no_check: if self._no_check:
return event return event
@ -153,7 +153,7 @@ class NewMessage(EventBuilder):
return return
event.pattern_match = match event.pattern_match = match
return self._filter_event(event) return super().filter(event)
class Event(EventCommon): class Event(EventCommon):
""" """

View File

@ -25,6 +25,10 @@ class Raw(EventBuilder):
async def resolve(self, client): async def resolve(self, client):
pass pass
def build(self, update): @staticmethod
if not self.types or isinstance(update, self.types): def build(update):
return update return update
def filter(self, event):
if not self.types or isinstance(event, self.types):
return event

View File

@ -9,7 +9,8 @@ class UserUpdate(EventBuilder):
""" """
Represents an user update (gone online, offline, joined Telegram). Represents an user update (gone online, offline, joined Telegram).
""" """
def build(self, update): @staticmethod
def build(update):
if isinstance(update, types.UpdateUserStatus): if isinstance(update, types.UpdateUserStatus):
event = UserUpdate.Event(update.user_id, event = UserUpdate.Event(update.user_id,
status=update.status) status=update.status)
@ -17,7 +18,7 @@ class UserUpdate(EventBuilder):
return return
event._entities = update._entities event._entities = update._entities
return self._filter_event(event) return event
class Event(EventCommon): class Event(EventCommon):
""" """