diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 3d1a986c..b5d105a4 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -282,25 +282,22 @@ class UpdateMethods(UserMethods): self._dispatching_updates_queue.clear() async def _dispatch_update(self, update, channel_id, pts_date): + if not self._entity_cache.ensure_cached(update): + await self._get_difference(update, channel_id, pts_date) + built = EventBuilderDict(self, update) if self._conversations: for conv in self._conversations.values(): ev = built[events.NewMessage] if ev: - if not ev._load_entities(): - await ev._get_difference(channel_id, pts_date) conv._on_new_message(ev) ev = built[events.MessageEdited] if ev: - if not ev._load_entities(): - await ev._get_difference(channel_id, pts_date) conv._on_edit(ev) ev = built[events.MessageRead] if ev: - if not ev._load_entities(): - await ev._get_difference(channel_id, pts_date) conv._on_read(ev) if conv._custom: @@ -318,14 +315,6 @@ class UpdateMethods(UserMethods): continue try: - # Although needing to do this constantly is annoying and - # error-prone, this part is somewhat hot, and always doing - # `await` for `check_entities_and_get_difference` causes - # unnecessary work. So we need to call a function that - # doesn't cause a task switch. - if isinstance(event, EventCommon) and not event._load_entities(): - await event._get_difference(channel_id, pts_date) - await callback(event) except errors.AlreadyInConversationError: name = getattr(callback, '__name__', repr(callback)) @@ -344,6 +333,46 @@ class UpdateMethods(UserMethods): self._log[__name__].exception('Unhandled exception on %s', name) + async def _get_difference(self, update, channel_id, pts_date): + """ + Get the difference for this `channel_id` if any, then load entities. + + Calls :tl:`updates.getDifference`, which fills the entities cache + (always done by `__call__`) and lets us know about the full entities. + """ + # Fetch since the last known pts/date before this update arrived, + # in order to fetch this update at full, including its entities. + self._log[__name__].debug('Getting difference for entities ' + 'for %r', update.__class__) + if channel_id: + try: + where = await self.get_input_entity(channel_id) + except ValueError: + return + + result = await self(functions.updates.GetChannelDifferenceRequest( + channel=where, + filter=types.ChannelMessagesFilterEmpty(), + pts=pts_date, # just pts + limit=100, + force=True + )) + else: + result = await self(functions.updates.GetDifferenceRequest( + pts=pts_date[0], + date=pts_date[1], + qts=0 + )) + + if isinstance(result, (types.updates.Difference, + types.updates.DifferenceSlice, + types.updates.ChannelDifference, + types.updates.ChannelDifferenceTooLong)): + update._entities.update({ + utils.get_peer_id(x): x for x in + itertools.chain(result.users, result.chats) + }) + async def _handle_auto_reconnect(self): # TODO Catch-up return @@ -398,6 +427,7 @@ class EventBuilderDict: if isinstance(event, EventCommon): event.original_update = self.update event._set_client(self.client) + event._load_entities() elif event: event._client = self.client diff --git a/telethon/entitycache.py b/telethon/entitycache.py index a87d787b..49cb0f58 100644 --- a/telethon/entitycache.py +++ b/telethon/entitycache.py @@ -1,7 +1,57 @@ import itertools + from . import utils from .tl import types +# Which updates have the following fields? +_has_user_id = [] +_has_chat_id = [] +_has_channel_id = [] +_has_peer = [] +_has_dialog_peer = [] +_has_message = [] + +# Note: We don't bother checking for some rare: +# * `UpdateChatParticipantAdd.inviter_id` integer. +# * `UpdateNotifySettings.peer` dialog peer. +# * `UpdatePinnedDialogs.order` list of dialog peers. +# * `UpdateReadMessagesContents.messages` list of messages. +# * `UpdateChatParticipants.participants` list of participants. +# +# There are also some uninteresting `update.message` of type string. + + +def _fill(): + for name in dir(types): + update = getattr(types, name) + if getattr(update, 'SUBCLASS_OF_ID', None) == 0x9f89304e: + cid = update.CONSTRUCTOR_ID + doc = update.__init__.__doc__ or '' + if ':param int user_id:' in doc: + _has_user_id.append(cid) + if ':param int chat_id:' in doc: + _has_chat_id.append(cid) + if ':param int channel_id:' in doc: + _has_channel_id.append(cid) + if ':param TypePeer peer:' in doc: + _has_peer.append(cid) + if ':param TypeDialogPeer peer:' in doc: + _has_dialog_peer.append(cid) + if ':param TypeMessage message:' in doc: + _has_message.append(cid) + + # Future-proof check: if the documentation format ever changes + # then we won't be able to pick the update types we are interested + # in, so we must make sure we have at least an update for each field + # which likely means we are doing it right. + if not all((_has_user_id, _has_chat_id, _has_channel_id, + _has_peer, _has_dialog_peer)): + raise RuntimeError('FIXME: Did the generated docs or updates change?') + + +# We use a function to avoid cluttering the globals (with name/update/cid/doc) +_fill() + class EntityCache: """ @@ -46,3 +96,51 @@ class EntityCache: return result raise KeyError('No cached entity for the given key') + + def ensure_cached( + self, + update, + has_user_id=frozenset(_has_user_id), + has_channel_id=frozenset(_has_channel_id), + has_peer=frozenset(_has_peer + _has_dialog_peer), + has_message=frozenset(_has_message) + ): + """ + Ensures that all the relevant entities in the given update are cached. + """ + # This method is called pretty often and we want it to have the lowest + # overhead possible. For that, we avoid `isinstance` and constantly + # getting attributes out of `types.` by "caching" the constructor IDs + # in sets inside the arguments, and using local variables. + dct = self.__dict__ + cid = update.CONSTRUCTOR_ID + if cid in has_user_id and \ + update.user_id not in dct: + return False + + if cid in _has_chat_id and \ + utils.get_peer_id(types.PeerChat(update.chat_id)) not in dct: + return False + + if cid in has_channel_id and \ + utils.get_peer_id(types.PeerChannel(update.channel_id)) not in dct: + return False + + if cid in has_peer and \ + utils.get_peer_id(update.peer) not in dct: + return False + + if cid in has_message: + x = update.message + y = getattr(x, 'to_id', None) # handle MessageEmpty + if y and utils.get_peer_id(y) not in dct: + return False + + y = getattr(x, 'from_id', None) + if y and y not in dct: + return False + + # We don't quite worry about entities anywhere else. + # This is enough. + + return True diff --git a/telethon/events/common.py b/telethon/events/common.py index 367c5f56..46345aa4 100644 --- a/telethon/events/common.py +++ b/telethon/events/common.py @@ -175,51 +175,6 @@ class EventCommon(ChatGetter, abc.ABC): self._chat, self._input_chat = self._get_entity_pair(self.chat_id) return self._input_chat is not None - async def _get_difference(self, channel_id, pts_date): - """ - Get the difference for this `channel_id` if any, then load entities. - - Calls :tl:`updates.getDifference`, which fills the entities cache - (always done by `__call__`) and lets us know about the full entities. - """ - # Fetch since the last known pts/date before this update arrived, - # in order to fetch this update at full, including its entities. - self.client._log[__name__].debug('Getting difference for entities') - if channel_id: - try: - where = await self.client.get_input_entity(channel_id) - except ValueError: - return - - result = await self.client(functions.updates.GetChannelDifferenceRequest( - channel=where, - filter=types.ChannelMessagesFilterEmpty(), - pts=pts_date, # just pts - limit=100, - force=True - )) - else: - result = await self.client(functions.updates.GetDifferenceRequest( - pts=pts_date[0], - date=pts_date[1], - qts=0 - )) - - if isinstance(result, (types.updates.Difference, - types.updates.DifferenceSlice, - types.updates.ChannelDifference, - types.updates.ChannelDifferenceTooLong)): - self.original_update._entities.update({ - utils.get_peer_id(x): x for x in - itertools.chain(result.users, result.chats) - }) - - if not self._load_entities(): - self.client._log[__name__].info( - 'Could not find all entities for update.pts = %s', - getattr(self.original_update, 'pts', None) - ) - @property def client(self): """