diff --git a/telethon/client/updates.py b/telethon/client/updates.py index ded17695..871e6a1f 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -213,13 +213,17 @@ class UpdateMethods(UserMethods): self._state_cache.update(update) def _process_update(self, update, entities=None): - update._channel_id = self._state_cache.get_channel_id(update) - update._pts_date = self._state_cache[update._channel_id] update._entities = entities or {} + + # This part is somewhat hot so we don't bother patching + # update with channel ID/its state. Instead we just pass + # arguments which is faster. + channel_id = self._state_cache.get_channel_id(update) + args = (update, channel_id, self._state_cache[channel_id]) if self._updates_queue is None: - self._loop.create_task(self._dispatch_update(update)) + self._loop.create_task(self._dispatch_update(*args)) else: - self._updates_queue.put_nowait(update) + self._updates_queue.put_nowait(args) if not self._dispatching_updates_queue.is_set(): self._dispatching_updates_queue.set() self._loop.create_task(self._dispatch_queue_updates()) @@ -273,25 +277,37 @@ class UpdateMethods(UserMethods): async def _dispatch_queue_updates(self): while not self._updates_queue.empty(): - await self._dispatch_update(self._updates_queue.get_nowait()) + await self._dispatch_update(*self._updates_queue.get_nowait()) self._dispatching_updates_queue.clear() - async def _dispatch_update(self, update): + async def _dispatch_update(self, update, channel_id, pts_date): built = EventBuilderDict(self, update) if self._conversations: for conv in self._conversations.values(): - if await built.get(events.NewMessage): - conv._on_new_message(built[events.NewMessage]) - if await built.get(events.MessageEdited): - conv._on_edit(built[events.MessageEdited]) - if await built.get(events.MessageRead): - conv._on_read(built[events.MessageRead]) + ev = built[events.NewMessage] + if ev: + if not ev._load_entities(): + await ev._get_difference(channel_id, pts_date) + conv._on_new_message(ev) + + ev = built[events.MessageEdited] + if ev: + if not ev._load_entities(): + await ev._get_difference(channel_id, pts_date) + conv._on_edit(ev) + + ev = built[events.MessageRead] + if ev: + if not ev._load_entities(): + await ev._get_difference(channel_id, pts_date) + conv._on_read(ev) + if conv._custom: - await conv._check_custom(built) + await conv._check_custom(built, channel_id, pts_date) for builder, callback in self._event_builders: - event = await built.get(type(builder)) + event = built[type(builder)] if not event: continue @@ -302,6 +318,14 @@ class UpdateMethods(UserMethods): continue try: + # Although needing to do this constantly is annoying and + # error-prone, this part is somewhat hot, and always doing + # `await` for `check_entities_and_get_difference` causes + # unnecessary work. So we need to call a function that + # doesn't cause a task switch. + if not event._load_entities(): + await event._get_difference(channel_id, pts_date) + await callback(event) except errors.AlreadyInConversationError: name = getattr(callback, '__name__', repr(callback)) @@ -367,9 +391,6 @@ class EventBuilderDict: self.update = update def __getitem__(self, builder): - return self.__dict__[builder] - - async def get(self, builder): try: return self.__dict__[builder] except KeyError: @@ -377,54 +398,7 @@ class EventBuilderDict: if isinstance(event, EventCommon): event.original_update = self.update event._set_client(self.client) - if not event._load_entities(): - await self.get_difference() - if not event._load_entities(): - self.client._log[__name__].info( - 'Could not find all entities for update.pts = %s', - getattr(self.update, 'pts', None) - ) elif event: - # Actually a :tl:`Update`, not much processing to do event._client = self.client return event - - async def get_difference(self): - """ - Calls :tl:`updates.getDifference`, which fills the entities cache - (always done by `__call__`) and lets us know about the full entities. - """ - # Fetch since the last known pts/date before this update arrived, - # in order to fetch this update at full, including its entities. - self.client._log[__name__].debug('Getting difference for entities') - if self.update._channel_id: - pts = self.update._pts_date - try: - where = await self.client.get_input_entity(self.update._channel_id) - except ValueError: - return - - result = await self.client(functions.updates.GetChannelDifferenceRequest( - channel=where, - filter=types.ChannelMessagesFilterEmpty(), - pts=pts, - limit=100, - force=True - )) - else: - pts, date = self.update._pts_date - result = await self.client(functions.updates.GetDifferenceRequest( - pts=pts - 1, - date=date, - qts=0 - )) - - if isinstance(result, (types.updates.Difference, - types.updates.DifferenceSlice, - types.updates.ChannelDifference, - types.updates.ChannelDifferenceTooLong)): - self.update._entities.update({ - utils.get_peer_id(x): x for x in - itertools.chain(result.users, result.chats) - }) diff --git a/telethon/events/common.py b/telethon/events/common.py index 8cecb2a6..367c5f56 100644 --- a/telethon/events/common.py +++ b/telethon/events/common.py @@ -1,9 +1,10 @@ import abc import asyncio +import itertools import warnings from .. import utils -from ..tl import TLObject, types +from ..tl import TLObject, types, functions from ..tl.custom.chatgetter import ChatGetter @@ -174,6 +175,51 @@ class EventCommon(ChatGetter, abc.ABC): self._chat, self._input_chat = self._get_entity_pair(self.chat_id) return self._input_chat is not None + async def _get_difference(self, channel_id, pts_date): + """ + Get the difference for this `channel_id` if any, then load entities. + + Calls :tl:`updates.getDifference`, which fills the entities cache + (always done by `__call__`) and lets us know about the full entities. + """ + # Fetch since the last known pts/date before this update arrived, + # in order to fetch this update at full, including its entities. + self.client._log[__name__].debug('Getting difference for entities') + if channel_id: + try: + where = await self.client.get_input_entity(channel_id) + except ValueError: + return + + result = await self.client(functions.updates.GetChannelDifferenceRequest( + channel=where, + filter=types.ChannelMessagesFilterEmpty(), + pts=pts_date, # just pts + limit=100, + force=True + )) + else: + result = await self.client(functions.updates.GetDifferenceRequest( + pts=pts_date[0], + date=pts_date[1], + qts=0 + )) + + if isinstance(result, (types.updates.Difference, + types.updates.DifferenceSlice, + types.updates.ChannelDifference, + types.updates.ChannelDifferenceTooLong)): + self.original_update._entities.update({ + utils.get_peer_id(x): x for x in + itertools.chain(result.users, result.chats) + }) + + if not self._load_entities(): + self.client._log[__name__].info( + 'Could not find all entities for update.pts = %s', + getattr(self.original_update, 'pts', None) + ) + @property def client(self): """ diff --git a/telethon/tl/custom/conversation.py b/telethon/tl/custom/conversation.py index d8d93b13..00d93939 100644 --- a/telethon/tl/custom/conversation.py +++ b/telethon/tl/custom/conversation.py @@ -294,10 +294,13 @@ class Conversation(ChatGetter): self._custom[counter] = (event, future) return await result() - async def _check_custom(self, built): + async def _check_custom(self, built, channel_id, pts_date): for i, (ev, fut) in self._custom.items(): ev_type = type(ev) if built[ev_type] and ev.filter(built[ev_type]): + if not ev._load_entities(): + await ev._get_difference(channel_id, pts_date) + fut.set_result(built[ev_type]) def _on_new_message(self, response):