mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-11-04 01:47:27 +03:00 
			
		
		
		
	Implement iterator over message by IDs
This commit is contained in:
		
							parent
							
								
									60606b9994
								
							
						
					
					
						commit
						6d6c1917bc
					
				| 
						 | 
					@ -228,7 +228,8 @@ class _MessagesIter(RequestIter):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _IDsIter(RequestIter):
 | 
					class _IDsIter(RequestIter):
 | 
				
			||||||
    async def _init(self, entity, from_user, ids):
 | 
					    async def _init(self, entity, ids):
 | 
				
			||||||
 | 
					        # TODO We never actually split IDs in chunks, but maybe we should
 | 
				
			||||||
        if not utils.is_list_like(ids):
 | 
					        if not utils.is_list_like(ids):
 | 
				
			||||||
            self.ids = [ids]
 | 
					            self.ids = [ids]
 | 
				
			||||||
        elif not ids:
 | 
					        elif not ids:
 | 
				
			||||||
| 
						 | 
					@ -238,10 +239,52 @@ class _IDsIter(RequestIter):
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self.ids = ids
 | 
					            self.ids = ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        raise NotImplementedError
 | 
					        if entity:
 | 
				
			||||||
 | 
					            entity = await self.client.get_input_entity(entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.total = len(ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        from_id = None  # By default, no need to validate from_id
 | 
				
			||||||
 | 
					        if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                r = await self.client(
 | 
				
			||||||
 | 
					                    functions.channels.GetMessagesRequest(entity, ids))
 | 
				
			||||||
 | 
					            except errors.MessageIdsEmptyError:
 | 
				
			||||||
 | 
					                # All IDs were invalid, use a dummy result
 | 
				
			||||||
 | 
					                r = types.messages.MessagesNotModified(len(ids))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            r = await self.client(functions.messages.GetMessagesRequest(ids))
 | 
				
			||||||
 | 
					            if entity:
 | 
				
			||||||
 | 
					                from_id = await self.client.get_peer_id(entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(r, types.messages.MessagesNotModified):
 | 
				
			||||||
 | 
					            self.buffer = [None] * len(ids)
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        entities = {utils.get_peer_id(x): x
 | 
				
			||||||
 | 
					                    for x in itertools.chain(r.users, r.chats)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Telegram seems to return the messages in the order in which
 | 
				
			||||||
 | 
					        # we asked them for, so we don't need to check it ourselves,
 | 
				
			||||||
 | 
					        # unless some messages were invalid in which case Telegram
 | 
				
			||||||
 | 
					        # may decide to not send them at all.
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        # The passed message IDs may not belong to the desired entity
 | 
				
			||||||
 | 
					        # since the user can enter arbitrary numbers which can belong to
 | 
				
			||||||
 | 
					        # arbitrary chats. Validate these unless ``from_id is None``.
 | 
				
			||||||
 | 
					        result = []
 | 
				
			||||||
 | 
					        for message in r.messages:
 | 
				
			||||||
 | 
					            if isinstance(message, types.MessageEmpty) or (
 | 
				
			||||||
 | 
					                    from_id and message.chat_id != from_id):
 | 
				
			||||||
 | 
					                result.append(None)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                message._finish_init(self.client, entities, entity)
 | 
				
			||||||
 | 
					                result.append(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.buffer = result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _load_next_chunk(self):
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        return []  # no next chunk, all done in init
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
					class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
				
			||||||
| 
						 | 
					@ -850,51 +893,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
				
			||||||
    # endregion
 | 
					    # endregion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # endregion
 | 
					    # endregion
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # region Private methods
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def _iter_ids(self, entity, ids, total):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Special case for `iter_messages` when it should only fetch some IDs.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if total:
 | 
					 | 
				
			||||||
            total[0] = len(ids)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from_id = None  # By default, no need to validate from_id
 | 
					 | 
				
			||||||
        if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                r = await self(
 | 
					 | 
				
			||||||
                    functions.channels.GetMessagesRequest(entity, ids))
 | 
					 | 
				
			||||||
            except errors.MessageIdsEmptyError:
 | 
					 | 
				
			||||||
                # All IDs were invalid, use a dummy result
 | 
					 | 
				
			||||||
                r = types.messages.MessagesNotModified(len(ids))
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            r = await self(functions.messages.GetMessagesRequest(ids))
 | 
					 | 
				
			||||||
            if entity:
 | 
					 | 
				
			||||||
                from_id = utils.get_peer_id(entity)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(r, types.messages.MessagesNotModified):
 | 
					 | 
				
			||||||
            for _ in ids:
 | 
					 | 
				
			||||||
                await yield_(None)
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        entities = {utils.get_peer_id(x): x
 | 
					 | 
				
			||||||
                    for x in itertools.chain(r.users, r.chats)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Telegram seems to return the messages in the order in which
 | 
					 | 
				
			||||||
        # we asked them for, so we don't need to check it ourselves,
 | 
					 | 
				
			||||||
        # unless some messages were invalid in which case Telegram
 | 
					 | 
				
			||||||
        # may decide to not send them at all.
 | 
					 | 
				
			||||||
        #
 | 
					 | 
				
			||||||
        # The passed message IDs may not belong to the desired entity
 | 
					 | 
				
			||||||
        # since the user can enter arbitrary numbers which can belong to
 | 
					 | 
				
			||||||
        # arbitrary chats. Validate these unless ``from_id is None``.
 | 
					 | 
				
			||||||
        for message in r.messages:
 | 
					 | 
				
			||||||
            if isinstance(message, types.MessageEmpty) or (
 | 
					 | 
				
			||||||
                    from_id and message.chat_id != from_id):
 | 
					 | 
				
			||||||
                await yield_(None)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                message._finish_init(self, entities, entity)
 | 
					 | 
				
			||||||
                await yield_(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # endregion
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,6 +48,8 @@ class RequestIter(abc.ABC):
 | 
				
			||||||
        to avoid forgetting or misspelling any of them.
 | 
					        to avoid forgetting or misspelling any of them.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        This method may ``raise StopAsyncIteration`` if it cannot continue.
 | 
					        This method may ``raise StopAsyncIteration`` if it cannot continue.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This method may actually fill the initial buffer if it needs to.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def __anext__(self):
 | 
					    async def __anext__(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user