mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-31 16:07:44 +03:00 
			
		
		
		
	Refactor code to fetch missing entities once again
This is another attempt at reducing CPU usage similar to:
    1b6b4a57d9
In addition it simplifies some of the code and opens up new
ideas for the state cache as well.
			
			
This commit is contained in:
		
							parent
							
								
									c12c65f728
								
							
						
					
					
						commit
						22124b5ced
					
				|  | @ -282,25 +282,22 @@ class UpdateMethods(UserMethods): | ||||||
|         self._dispatching_updates_queue.clear() |         self._dispatching_updates_queue.clear() | ||||||
| 
 | 
 | ||||||
|     async def _dispatch_update(self, update, channel_id, pts_date): |     async def _dispatch_update(self, update, channel_id, pts_date): | ||||||
|  |         if not self._entity_cache.ensure_cached(update): | ||||||
|  |             await self._get_difference(update, channel_id, pts_date) | ||||||
|  | 
 | ||||||
|         built = EventBuilderDict(self, update) |         built = EventBuilderDict(self, update) | ||||||
|         if self._conversations: |         if self._conversations: | ||||||
|             for conv in self._conversations.values(): |             for conv in self._conversations.values(): | ||||||
|                 ev = built[events.NewMessage] |                 ev = built[events.NewMessage] | ||||||
|                 if ev: |                 if ev: | ||||||
|                     if not ev._load_entities(): |  | ||||||
|                         await ev._get_difference(channel_id, pts_date) |  | ||||||
|                     conv._on_new_message(ev) |                     conv._on_new_message(ev) | ||||||
| 
 | 
 | ||||||
|                 ev = built[events.MessageEdited] |                 ev = built[events.MessageEdited] | ||||||
|                 if ev: |                 if ev: | ||||||
|                     if not ev._load_entities(): |  | ||||||
|                         await ev._get_difference(channel_id, pts_date) |  | ||||||
|                     conv._on_edit(ev) |                     conv._on_edit(ev) | ||||||
| 
 | 
 | ||||||
|                 ev = built[events.MessageRead] |                 ev = built[events.MessageRead] | ||||||
|                 if ev: |                 if ev: | ||||||
|                     if not ev._load_entities(): |  | ||||||
|                         await ev._get_difference(channel_id, pts_date) |  | ||||||
|                     conv._on_read(ev) |                     conv._on_read(ev) | ||||||
| 
 | 
 | ||||||
|                 if conv._custom: |                 if conv._custom: | ||||||
|  | @ -318,14 +315,6 @@ class UpdateMethods(UserMethods): | ||||||
|                 continue |                 continue | ||||||
| 
 | 
 | ||||||
|             try: |             try: | ||||||
|                 # Although needing to do this constantly is annoying and |  | ||||||
|                 # error-prone, this part is somewhat hot, and always doing |  | ||||||
|                 # `await` for `check_entities_and_get_difference` causes |  | ||||||
|                 # unnecessary work. So we need to call a function that |  | ||||||
|                 # doesn't cause a task switch. |  | ||||||
|                 if isinstance(event, EventCommon) and not event._load_entities(): |  | ||||||
|                     await event._get_difference(channel_id, pts_date) |  | ||||||
| 
 |  | ||||||
|                 await callback(event) |                 await callback(event) | ||||||
|             except errors.AlreadyInConversationError: |             except errors.AlreadyInConversationError: | ||||||
|                 name = getattr(callback, '__name__', repr(callback)) |                 name = getattr(callback, '__name__', repr(callback)) | ||||||
|  | @ -344,6 +333,46 @@ class UpdateMethods(UserMethods): | ||||||
|                 self._log[__name__].exception('Unhandled exception on %s', |                 self._log[__name__].exception('Unhandled exception on %s', | ||||||
|                                               name) |                                               name) | ||||||
| 
 | 
 | ||||||
|  |     async def _get_difference(self, update, channel_id, pts_date): | ||||||
|  |         """ | ||||||
|  |         Get the difference for this `channel_id` if any, then load entities. | ||||||
|  | 
 | ||||||
|  |         Calls :tl:`updates.getDifference`, which fills the entities cache | ||||||
|  |         (always done by `__call__`) and lets us know about the full entities. | ||||||
|  |         """ | ||||||
|  |         # Fetch since the last known pts/date before this update arrived, | ||||||
|  |         # in order to fetch this update at full, including its entities. | ||||||
|  |         self._log[__name__].debug('Getting difference for entities ' | ||||||
|  |                                   'for %r', update.__class__) | ||||||
|  |         if channel_id: | ||||||
|  |             try: | ||||||
|  |                 where = await self.get_input_entity(channel_id) | ||||||
|  |             except ValueError: | ||||||
|  |                 return | ||||||
|  | 
 | ||||||
|  |             result = await self(functions.updates.GetChannelDifferenceRequest( | ||||||
|  |                 channel=where, | ||||||
|  |                 filter=types.ChannelMessagesFilterEmpty(), | ||||||
|  |                 pts=pts_date,  # just pts | ||||||
|  |                 limit=100, | ||||||
|  |                 force=True | ||||||
|  |             )) | ||||||
|  |         else: | ||||||
|  |             result = await self(functions.updates.GetDifferenceRequest( | ||||||
|  |                 pts=pts_date[0], | ||||||
|  |                 date=pts_date[1], | ||||||
|  |                 qts=0 | ||||||
|  |             )) | ||||||
|  | 
 | ||||||
|  |         if isinstance(result, (types.updates.Difference, | ||||||
|  |                                types.updates.DifferenceSlice, | ||||||
|  |                                types.updates.ChannelDifference, | ||||||
|  |                                types.updates.ChannelDifferenceTooLong)): | ||||||
|  |             update._entities.update({ | ||||||
|  |                 utils.get_peer_id(x): x for x in | ||||||
|  |                 itertools.chain(result.users, result.chats) | ||||||
|  |             }) | ||||||
|  | 
 | ||||||
|     async def _handle_auto_reconnect(self): |     async def _handle_auto_reconnect(self): | ||||||
|         # TODO Catch-up |         # TODO Catch-up | ||||||
|         return |         return | ||||||
|  | @ -398,6 +427,7 @@ class EventBuilderDict: | ||||||
|             if isinstance(event, EventCommon): |             if isinstance(event, EventCommon): | ||||||
|                 event.original_update = self.update |                 event.original_update = self.update | ||||||
|                 event._set_client(self.client) |                 event._set_client(self.client) | ||||||
|  |                 event._load_entities() | ||||||
|             elif event: |             elif event: | ||||||
|                 event._client = self.client |                 event._client = self.client | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,7 +1,57 @@ | ||||||
| import itertools | import itertools | ||||||
|  | 
 | ||||||
| from . import utils | from . import utils | ||||||
| from .tl import types | from .tl import types | ||||||
| 
 | 
 | ||||||
|  | # Which updates have the following fields? | ||||||
|  | _has_user_id = [] | ||||||
|  | _has_chat_id = [] | ||||||
|  | _has_channel_id = [] | ||||||
|  | _has_peer = [] | ||||||
|  | _has_dialog_peer = [] | ||||||
|  | _has_message = [] | ||||||
|  | 
 | ||||||
|  | # Note: We don't bother checking for some rare: | ||||||
|  | # * `UpdateChatParticipantAdd.inviter_id` integer. | ||||||
|  | # * `UpdateNotifySettings.peer` dialog peer. | ||||||
|  | # * `UpdatePinnedDialogs.order` list of dialog peers. | ||||||
|  | # * `UpdateReadMessagesContents.messages` list of messages. | ||||||
|  | # * `UpdateChatParticipants.participants` list of participants. | ||||||
|  | # | ||||||
|  | # There are also some uninteresting `update.message` of type string. | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _fill(): | ||||||
|  |     for name in dir(types): | ||||||
|  |         update = getattr(types, name) | ||||||
|  |         if getattr(update, 'SUBCLASS_OF_ID', None) == 0x9f89304e: | ||||||
|  |             cid = update.CONSTRUCTOR_ID | ||||||
|  |             doc = update.__init__.__doc__ or '' | ||||||
|  |             if ':param int user_id:' in doc: | ||||||
|  |                 _has_user_id.append(cid) | ||||||
|  |             if ':param int chat_id:' in doc: | ||||||
|  |                 _has_chat_id.append(cid) | ||||||
|  |             if ':param int channel_id:' in doc: | ||||||
|  |                 _has_channel_id.append(cid) | ||||||
|  |             if ':param TypePeer peer:' in doc: | ||||||
|  |                 _has_peer.append(cid) | ||||||
|  |             if ':param TypeDialogPeer peer:' in doc: | ||||||
|  |                 _has_dialog_peer.append(cid) | ||||||
|  |             if ':param TypeMessage message:' in doc: | ||||||
|  |                 _has_message.append(cid) | ||||||
|  | 
 | ||||||
|  |     # Future-proof check: if the documentation format ever changes | ||||||
|  |     # then we won't be able to pick the update types we are interested | ||||||
|  |     # in, so we must make sure we have at least an update for each field | ||||||
|  |     # which likely means we are doing it right. | ||||||
|  |     if not all((_has_user_id, _has_chat_id, _has_channel_id, | ||||||
|  |                 _has_peer, _has_dialog_peer)): | ||||||
|  |         raise RuntimeError('FIXME: Did the generated docs or updates change?') | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # We use a function to avoid cluttering the globals (with name/update/cid/doc) | ||||||
|  | _fill() | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class EntityCache: | class EntityCache: | ||||||
|     """ |     """ | ||||||
|  | @ -46,3 +96,51 @@ class EntityCache: | ||||||
|                 return result |                 return result | ||||||
| 
 | 
 | ||||||
|         raise KeyError('No cached entity for the given key') |         raise KeyError('No cached entity for the given key') | ||||||
|  | 
 | ||||||
|  |     def ensure_cached( | ||||||
|  |             self, | ||||||
|  |             update, | ||||||
|  |             has_user_id=frozenset(_has_user_id), | ||||||
|  |             has_channel_id=frozenset(_has_channel_id), | ||||||
|  |             has_peer=frozenset(_has_peer + _has_dialog_peer), | ||||||
|  |             has_message=frozenset(_has_message) | ||||||
|  |     ): | ||||||
|  |         """ | ||||||
|  |         Ensures that all the relevant entities in the given update are cached. | ||||||
|  |         """ | ||||||
|  |         # This method is called pretty often and we want it to have the lowest | ||||||
|  |         # overhead possible. For that, we avoid `isinstance` and constantly | ||||||
|  |         # getting attributes out of `types.` by "caching" the constructor IDs | ||||||
|  |         # in sets inside the arguments, and using local variables. | ||||||
|  |         dct = self.__dict__ | ||||||
|  |         cid = update.CONSTRUCTOR_ID | ||||||
|  |         if cid in has_user_id and \ | ||||||
|  |                 update.user_id not in dct: | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         if cid in _has_chat_id and \ | ||||||
|  |                 utils.get_peer_id(types.PeerChat(update.chat_id)) not in dct: | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         if cid in has_channel_id and \ | ||||||
|  |                 utils.get_peer_id(types.PeerChannel(update.channel_id)) not in dct: | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         if cid in has_peer and \ | ||||||
|  |                 utils.get_peer_id(update.peer) not in dct: | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         if cid in has_message: | ||||||
|  |             x = update.message | ||||||
|  |             y = getattr(x, 'to_id', None)  # handle MessageEmpty | ||||||
|  |             if y and utils.get_peer_id(y) not in dct: | ||||||
|  |                 return False | ||||||
|  | 
 | ||||||
|  |             y = getattr(x, 'from_id', None) | ||||||
|  |             if y and y not in dct: | ||||||
|  |                 return False | ||||||
|  | 
 | ||||||
|  |             # We don't quite worry about entities anywhere else. | ||||||
|  |             # This is enough. | ||||||
|  | 
 | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  | @ -175,51 +175,6 @@ class EventCommon(ChatGetter, abc.ABC): | ||||||
|         self._chat, self._input_chat = self._get_entity_pair(self.chat_id) |         self._chat, self._input_chat = self._get_entity_pair(self.chat_id) | ||||||
|         return self._input_chat is not None |         return self._input_chat is not None | ||||||
| 
 | 
 | ||||||
|     async def _get_difference(self, channel_id, pts_date): |  | ||||||
|         """ |  | ||||||
|         Get the difference for this `channel_id` if any, then load entities. |  | ||||||
| 
 |  | ||||||
|         Calls :tl:`updates.getDifference`, which fills the entities cache |  | ||||||
|         (always done by `__call__`) and lets us know about the full entities. |  | ||||||
|         """ |  | ||||||
|         # Fetch since the last known pts/date before this update arrived, |  | ||||||
|         # in order to fetch this update at full, including its entities. |  | ||||||
|         self.client._log[__name__].debug('Getting difference for entities') |  | ||||||
|         if channel_id: |  | ||||||
|             try: |  | ||||||
|                 where = await self.client.get_input_entity(channel_id) |  | ||||||
|             except ValueError: |  | ||||||
|                 return |  | ||||||
| 
 |  | ||||||
|             result = await self.client(functions.updates.GetChannelDifferenceRequest( |  | ||||||
|                 channel=where, |  | ||||||
|                 filter=types.ChannelMessagesFilterEmpty(), |  | ||||||
|                 pts=pts_date,  # just pts |  | ||||||
|                 limit=100, |  | ||||||
|                 force=True |  | ||||||
|             )) |  | ||||||
|         else: |  | ||||||
|             result = await self.client(functions.updates.GetDifferenceRequest( |  | ||||||
|                 pts=pts_date[0], |  | ||||||
|                 date=pts_date[1], |  | ||||||
|                 qts=0 |  | ||||||
|             )) |  | ||||||
| 
 |  | ||||||
|         if isinstance(result, (types.updates.Difference, |  | ||||||
|                                types.updates.DifferenceSlice, |  | ||||||
|                                types.updates.ChannelDifference, |  | ||||||
|                                types.updates.ChannelDifferenceTooLong)): |  | ||||||
|             self.original_update._entities.update({ |  | ||||||
|                 utils.get_peer_id(x): x for x in |  | ||||||
|                 itertools.chain(result.users, result.chats) |  | ||||||
|             }) |  | ||||||
| 
 |  | ||||||
|         if not self._load_entities(): |  | ||||||
|             self.client._log[__name__].info( |  | ||||||
|                 'Could not find all entities for update.pts = %s', |  | ||||||
|                 getattr(self.original_update, 'pts', None) |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|     @property |     @property | ||||||
|     def client(self): |     def client(self): | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user