Pass all Updates when building events

This commit is contained in:
Lonami Exo 2019-06-30 16:32:18 +02:00
parent 71979e7b23
commit a7a7c4add2
12 changed files with 20 additions and 20 deletions

View File

@ -291,22 +291,22 @@ class UpdateMethods:
entities = {utils.get_peer_id(x): x for x in entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)} itertools.chain(update.users, update.chats)}
for u in update.updates: for u in update.updates:
self._process_update(u, entities) self._process_update(u, update.updates, entities=entities)
elif isinstance(update, types.UpdateShort): elif isinstance(update, types.UpdateShort):
self._process_update(update.update) self._process_update(update.update, None)
else: else:
self._process_update(update) self._process_update(update, None)
self._state_cache.update(update) self._state_cache.update(update)
def _process_update(self: 'TelegramClient', update, entities=None): def _process_update(self: 'TelegramClient', update, others, entities=None):
update._entities = entities or {} update._entities = entities or {}
# This part is somewhat hot so we don't bother patching # This part is somewhat hot so we don't bother patching
# update with channel ID/its state. Instead we just pass # update with channel ID/its state. Instead we just pass
# arguments which is faster. # arguments which is faster.
channel_id = self._state_cache.get_channel_id(update) channel_id = self._state_cache.get_channel_id(update)
args = (update, channel_id, self._state_cache[channel_id]) args = (update, others, channel_id, self._state_cache[channel_id])
if self._dispatching_updates_queue is None: if self._dispatching_updates_queue is None:
task = self._loop.create_task(self._dispatch_update(*args)) task = self._loop.create_task(self._dispatch_update(*args))
self._updates_queue.add(task) self._updates_queue.add(task)
@ -370,7 +370,7 @@ class UpdateMethods:
self._dispatching_updates_queue.clear() self._dispatching_updates_queue.clear()
async def _dispatch_update(self: 'TelegramClient', update, channel_id, pts_date): async def _dispatch_update(self: 'TelegramClient', update, others, channel_id, pts_date):
if not self._entity_cache.ensure_cached(update): if not self._entity_cache.ensure_cached(update):
# We could add a lock to not fetch the same pts twice if we are # We could add a lock to not fetch the same pts twice if we are
# already fetching it. However this does not happen in practice, # already fetching it. However this does not happen in practice,
@ -380,7 +380,7 @@ class UpdateMethods:
# For example, UpdateUserStatus or UpdateChatUserTyping. # For example, UpdateUserStatus or UpdateChatUserTyping.
await self._get_difference(update, channel_id, pts_date) await self._get_difference(update, channel_id, pts_date)
built = EventBuilderDict(self, update) built = EventBuilderDict(self, update, others)
for conv_set in self._conversations.values(): for conv_set in self._conversations.values():
for conv in conv_set: for conv in conv_set:
ev = built[events.NewMessage] ev = built[events.NewMessage]
@ -527,15 +527,16 @@ class EventBuilderDict:
""" """
Helper "dictionary" to return events from types and cache them. Helper "dictionary" to return events from types and cache them.
""" """
def __init__(self, client: 'TelegramClient', update): def __init__(self, client: 'TelegramClient', update, others):
self.client = client self.client = client
self.update = update self.update = update
self.others = others
def __getitem__(self, builder): def __getitem__(self, builder):
try: try:
return self.__dict__[builder] return self.__dict__[builder]
except KeyError: except KeyError:
event = self.__dict__[builder] = builder.build(self.update) event = self.__dict__[builder] = builder.build(self.update, self.others)
if isinstance(event, EventCommon): if isinstance(event, EventCommon):
event.original_update = self.update event.original_update = self.update
event._entities = self.update._entities event._entities = self.update._entities

View File

@ -45,7 +45,7 @@ class CallbackQuery(EventBuilder):
raise TypeError('Invalid data type given') raise TypeError('Invalid data type given')
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, types.UpdateBotCallbackQuery): if isinstance(update, types.UpdateBotCallbackQuery):
return cls.Event(update, update.peer, update.msg_id) return cls.Event(update, update.peer, update.msg_id)
elif isinstance(update, types.UpdateInlineBotCallbackQuery): elif isinstance(update, types.UpdateInlineBotCallbackQuery):

View File

@ -9,7 +9,7 @@ class ChatAction(EventBuilder):
Occurs whenever a user joins or leaves a chat, or a message is pinned. Occurs whenever a user joins or leaves a chat, or a message is pinned.
""" """
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, types.UpdateChannelPinnedMessage) and update.id == 0: if isinstance(update, types.UpdateChannelPinnedMessage) and update.id == 0:
# Telegram does not always send # Telegram does not always send
# UpdateChannelPinnedMessage for new pins # UpdateChannelPinnedMessage for new pins

View File

@ -77,7 +77,7 @@ class EventBuilder(abc.ABC):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def build(cls, update): def build(cls, update, others=None):
""" """
Builds an event for the given update if possible, or returns None. Builds an event for the given update if possible, or returns None.

View File

@ -46,7 +46,7 @@ class InlineQuery(EventBuilder):
raise TypeError('Invalid pattern type given') raise TypeError('Invalid pattern type given')
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, types.UpdateBotInlineQuery): if isinstance(update, types.UpdateBotInlineQuery):
return cls.Event(update) return cls.Event(update)

View File

@ -25,7 +25,7 @@ class MessageDeleted(EventBuilder):
unless you intend on working with channels and super-groups only. unless you intend on working with channels and super-groups only.
""" """
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, types.UpdateDeleteMessages): if isinstance(update, types.UpdateDeleteMessages):
return cls.Event( return cls.Event(
deleted_ids=update.messages, deleted_ids=update.messages,

View File

@ -33,7 +33,7 @@ class MessageEdited(NewMessage):
not you). not you).
""" """
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, (types.UpdateEditMessage, if isinstance(update, (types.UpdateEditMessage,
types.UpdateEditChannelMessage)): types.UpdateEditChannelMessage)):
return cls.Event(update.message) return cls.Event(update.message)

View File

@ -20,7 +20,7 @@ class MessageRead(EventBuilder):
self.inbox = inbox self.inbox = inbox
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, types.UpdateReadHistoryInbox): if isinstance(update, types.UpdateReadHistoryInbox):
return cls.Event(update.peer, update.max_id, False) return cls.Event(update.peer, update.max_id, False)
elif isinstance(update, types.UpdateReadHistoryOutbox): elif isinstance(update, types.UpdateReadHistoryOutbox):

View File

@ -76,7 +76,7 @@ class NewMessage(EventBuilder):
self.from_users = await _into_id_set(client, self.from_users) self.from_users = await _into_id_set(client, self.from_users)
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, if isinstance(update,
(types.UpdateNewMessage, types.UpdateNewChannelMessage)): (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
if not isinstance(update.message, types.Message): if not isinstance(update.message, types.Message):

View File

@ -32,7 +32,7 @@ class Raw(EventBuilder):
self.resolved = True self.resolved = True
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
return update return update
def filter(self, event): def filter(self, event):

View File

@ -12,7 +12,7 @@ class UserUpdate(EventBuilder):
Occurs whenever a user goes online, starts typing, etc. Occurs whenever a user goes online, starts typing, etc.
""" """
@classmethod @classmethod
def build(cls, update): def build(cls, update, others=None):
if isinstance(update, types.UpdateUserStatus): if isinstance(update, types.UpdateUserStatus):
return cls.Event(update.user_id, return cls.Event(update.user_id,
status=update.status) status=update.status)

View File

@ -551,7 +551,6 @@ class MTProtoSender:
self._log.debug('Handling update %s', message.obj.__class__.__name__) self._log.debug('Handling update %s', message.obj.__class__.__name__)
if self._update_callback: if self._update_callback:
print(message.obj.stringify())
self._update_callback(message.obj) self._update_callback(message.obj)
async def _handle_pong(self, message): async def _handle_pong(self, message):