Merge branch 'event-reusing'

This commit is contained in:
Lonami Exo 2018-07-11 11:31:46 +02:00
commit 4f5c6f1006
11 changed files with 86 additions and 49 deletions

View File

@ -1,5 +1,6 @@
import abc
import asyncio
import collections
import inspect
import logging
import platform
@ -258,6 +259,11 @@ class TelegramBaseClient(abc.ABC):
self._events_pending_resolve = []
self._event_resolve_lock = asyncio.Lock()
# Keep track of how many event builders there are for
# each type {type: count}. If there's at least one then
# the event will be built, and the same event be reused.
self._event_builders_count = collections.defaultdict(int)
# Default parse mode
self._parse_mode = markdown

View File

@ -90,6 +90,7 @@ class UpdateMethods(UserMethods):
event = events.Raw()
self._events_pending_resolve.append(event)
self._event_builders_count[type(event)] += 1
self._event_builders.append((event, callback))
def remove_event_handler(self, callback, event=None):
@ -108,6 +109,11 @@ class UpdateMethods(UserMethods):
i -= 1
ev, cb = self._event_builders[i]
if cb == callback and (not event or isinstance(ev, event)):
type_ev = type(ev)
self._event_builders_count[type_ev] -= 1
if not self._event_builders_count[type_ev]:
del self._event_builders_count[type_ev]
del self._event_builders[i]
found += 1
@ -257,28 +263,36 @@ class UpdateMethods(UserMethods):
self._events_pending_resolve.clear()
for builder, callback in self._event_builders:
event = builder.build(update)
if event:
if hasattr(event, '_set_client'):
event._set_client(self)
else:
event._client = self
# TODO We can improve this further
# If we had a way to get all event builders for
# a type instead looping over them all always.
built = {builder: builder.build(update)
for builder in self._event_builders_count}
event.original_update = update
try:
await callback(event)
except events.StopPropagation:
__log__.debug(
"Event handler '{}' stopped chain of "
"propagation for event {}."
.format(callback.__name__,
type(event).__name__)
)
break
except Exception:
__log__.exception('Unhandled exception on {}'
.format(callback.__name__))
for builder, callback in self._event_builders:
event = built[type(builder)]
if not event or not builder.filter(event):
continue
if hasattr(event, '_set_client'):
event._set_client(self)
else:
event._client = self
event.original_update = update
try:
await callback(event)
except events.StopPropagation:
__log__.debug(
"Event handler '{}' stopped chain of "
"propagation for event {}."
.format(callback.__name__,
type(event).__name__)
)
break
except Exception:
__log__.exception('Unhandled exception on {}'
.format(callback.__name__))
async def _handle_auto_reconnect(self):
# Upon reconnection, we want to send getState

View File

@ -40,7 +40,8 @@ class CallbackQuery(EventBuilder):
else:
raise TypeError('Invalid data type given')
def build(self, update):
@staticmethod
def build(update):
if isinstance(update, (types.UpdateBotCallbackQuery,
types.UpdateInlineBotCallbackQuery)):
event = CallbackQuery.Event(update)
@ -48,9 +49,9 @@ class CallbackQuery(EventBuilder):
return
event._entities = update._entities
return self._filter_event(event)
return event
def _filter_event(self, event):
def filter(self, event):
if self.chats is not None:
inside = event.query.chat_instance in self.chats
if event.chat_id:

View File

@ -8,7 +8,8 @@ class ChatAction(EventBuilder):
"""
Represents an action in a chat (such as user joined, left, or new pin).
"""
def build(self, update):
@staticmethod
def build(update):
if isinstance(update, types.UpdateChannelPinnedMessage) and update.id == 0:
# Telegram does not always send
# UpdateChannelPinnedMessage for new pins
@ -78,7 +79,7 @@ class ChatAction(EventBuilder):
return
event._entities = update._entities
return self._filter_event(event)
return event
class Event(EventCommon):
"""

View File

@ -52,21 +52,25 @@ class EventBuilder(abc.ABC):
will be handled *except* those specified in ``chats``
which will be ignored if ``blacklist_chats=True``.
"""
self_id = None
def __init__(self, chats=None, blacklist_chats=False):
self.chats = chats
self.blacklist_chats = blacklist_chats
self._self_id = None
@staticmethod
@abc.abstractmethod
def build(self, update):
def build(update):
"""Builds an event for the given update if possible, or returns None"""
async def resolve(self, client):
"""Helper method to allow event builders to be resolved before usage"""
self.chats = await _into_id_set(client, self.chats)
self._self_id = (await client.get_me(input_peer=True)).user_id
if not EventBuilder.self_id:
EventBuilder.self_id = await client.get_peer_id('me')
def _filter_event(self, event):
def filter(self, event):
"""
If the ID of ``event._chat_peer`` isn't in the chats set (or it is
but the set is a blacklist) returns ``None``, otherwise the event.

View File

@ -7,7 +7,8 @@ class MessageDeleted(EventBuilder):
"""
Event fired when one or more messages are deleted.
"""
def build(self, update):
@staticmethod
def build(update):
if isinstance(update, types.UpdateDeleteMessages):
event = MessageDeleted.Event(
deleted_ids=update.messages,
@ -22,7 +23,7 @@ class MessageDeleted(EventBuilder):
return
event._entities = update._entities
return self._filter_event(event)
return event
class Event(EventCommon):
def __init__(self, deleted_ids, peer):

View File

@ -8,7 +8,8 @@ class MessageEdited(NewMessage):
"""
Event fired when a message has been edited.
"""
def build(self, update):
@staticmethod
def build(update):
if isinstance(update, (types.UpdateEditMessage,
types.UpdateEditChannelMessage)):
event = MessageEdited.Event(update.message)
@ -16,7 +17,7 @@ class MessageEdited(NewMessage):
return
event._entities = update._entities
return self._message_filter_event(event)
return event
class Event(NewMessage.Event):
pass # Required if we want a different name for it

View File

@ -18,7 +18,8 @@ class MessageRead(EventBuilder):
super().__init__(chats, blacklist_chats)
self.inbox = inbox
def build(self, update):
@staticmethod
def build(update):
if isinstance(update, types.UpdateReadHistoryInbox):
event = MessageRead.Event(update.peer, update.max_id, False)
elif isinstance(update, types.UpdateReadHistoryOutbox):
@ -39,11 +40,14 @@ class MessageRead(EventBuilder):
else:
return
event._entities = update._entities
return event
def filter(self, event):
if self.inbox == event.outbox:
return
event._entities = update._entities
return self._filter_event(event)
return super().filter(event)
class Event(EventCommon):
"""

View File

@ -75,7 +75,8 @@ class NewMessage(EventBuilder):
await super().resolve(client)
self.from_users = await _into_id_set(client, self.from_users)
def build(self, update):
@staticmethod
def build(update):
if isinstance(update,
(types.UpdateNewMessage, types.UpdateNewChannelMessage)):
if not isinstance(update.message, types.Message):
@ -91,9 +92,9 @@ class NewMessage(EventBuilder):
# Note that to_id/from_id complement each other in private
# messages, depending on whether the message was outgoing.
to_id=types.PeerUser(
update.user_id if update.out else self._self_id
update.user_id if update.out else EventBuilder.self_id
),
from_id=self._self_id if update.out else update.user_id,
from_id=EventBuilder.self_id if update.out else update.user_id,
message=update.message,
date=update.date,
fwd_from=update.fwd_from,
@ -120,8 +121,6 @@ class NewMessage(EventBuilder):
else:
return
event._entities = update._entities
# Make messages sent to ourselves outgoing unless they're forwarded.
# This makes it consistent with official client's appearance.
ori = event.message
@ -129,9 +128,10 @@ class NewMessage(EventBuilder):
if ori.from_id == ori.to_id.user_id and not ori.fwd_from:
event.message.out = True
return self._message_filter_event(event)
event._entities = update._entities
return event
def _message_filter_event(self, event):
def filter(self, event):
if self._no_check:
return event
@ -153,7 +153,7 @@ class NewMessage(EventBuilder):
return
event.pattern_match = match
return self._filter_event(event)
return super().filter(event)
class Event(EventCommon):
"""

View File

@ -25,6 +25,10 @@ class Raw(EventBuilder):
async def resolve(self, client):
pass
def build(self, update):
if not self.types or isinstance(update, self.types):
return update
@staticmethod
def build(update):
return update
def filter(self, event):
if not self.types or isinstance(event, self.types):
return event

View File

@ -9,7 +9,8 @@ class UserUpdate(EventBuilder):
"""
Represents an user update (gone online, offline, joined Telegram).
"""
def build(self, update):
@staticmethod
def build(update):
if isinstance(update, types.UpdateUserStatus):
event = UserUpdate.Event(update.user_id,
status=update.status)
@ -17,7 +18,7 @@ class UserUpdate(EventBuilder):
return
event._entities = update._entities
return self._filter_event(event)
return event
class Event(EventCommon):
"""