	Create a new RequestIter ABC to deal with iter methods
This should make it easier to maintain these methods, increase reusability, and get rid of the async_generator dependency. In the future, people could use this to more easily deal with raw API themselves.
This commit is contained in:
    parent 1e4a12d2f7
    commit 36eb1b1009
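To make the commit's intent concrete: a subclass implements `_init` (which receives the keyword arguments forwarded by `__init__`) and `_load_next_chunk` (which returns a list, an empty one ending the iteration), and is then driven with `async for`. The following is a hypothetical sketch against the `RequestIter` interface added below in `telethon/requestiter.py`; `_ContactsIter` and its `_fetched` flag are illustrative and not part of this commit.

    from telethon.requestiter import RequestIter
    from telethon.tl import functions, types


    class _ContactsIter(RequestIter):
        async def _init(self, hash):
            # Extra keyword arguments given to __init__ end up here.
            self._fetched = False
            self.request = functions.contacts.GetContactsRequest(hash=hash)

        async def _load_next_chunk(self):
            if self._fetched:
                return []  # an empty (falsy) buffer ends the iteration

            self._fetched = True
            result = await self.client(self.request)
            if isinstance(result, types.contacts.ContactsNotModified):
                return []

            self.total = len(result.users)
            return result.users

    # Driven like any other asynchronous iterator:
    #     async for user in _ContactsIter(client, limit=None, hash=0):
    #         ...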
@@ -2,13 +2,152 @@ import asyncio
 import itertools
 import time
 
-from async_generator import async_generator, yield_
-
 from .messageparse import MessageParseMethods
 from .uploads import UploadMethods
 from .buttons import ButtonMethods
 from .. import helpers, utils, errors
 from ..tl import types, functions
+from ..requestiter import RequestIter
+
+
+class _GetHistoryIter(RequestIter):
+    async def _init(self, entity, offset_id, min_id, max_id, from_user, batch_size, offset_date, add_offset):
+        self.entity = await self.client.get_input_entity(entity)
+
+        # Telegram doesn't like min_id/max_id. If these IDs are low enough
+        # (starting from last_id - 100), the request will return nothing.
+        #
+        # We can emulate their behaviour locally by setting offset = max_id
+        # and simply stopping once we hit a message with ID <= min_id.
+        if self.reverse:
+            offset_id = max(offset_id, min_id)
+            if offset_id and max_id:
+                if max_id - offset_id <= 1:
+                    raise StopAsyncIteration
+
+            if not max_id:
+                max_id = float('inf')
+        else:
+            offset_id = max(offset_id, max_id)
+            if offset_id and min_id:
+                if offset_id - min_id <= 1:
+                    raise StopAsyncIteration
+
+        if self.reverse:
+            if offset_id:
+                offset_id += 1
+            else:
+                offset_id = 1
+
+        if from_user:
+            from_user = await self.client.get_input_entity(from_user)
+            if not isinstance(from_user, (
+                    types.InputPeerUser, types.InputPeerSelf)):
+                from_user = None  # Ignore from_user unless it's a user
+
+        self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None
+
+        self.request = functions.messages.GetHistoryRequest(
+            peer=entity,
+            limit=1,
+            offset_date=offset_date,
+            offset_id=offset_id,
+            min_id=0,
+            max_id=0,
+            add_offset=add_offset,
+            hash=0
+        )
+
+        if self.limit == 0:
+            # No messages, but we still need to know the total message count
+            result = await self.client(self.request)
+            if isinstance(result, types.messages.MessagesNotModified):
+                self.total = result.count
+            else:
+                self.total = getattr(result, 'count', len(result.messages))
+            raise StopAsyncIteration
+
+        # When going in reverse we need an offset of `-limit`, but we
+        # also want to respect what the user passed, so add them together.
+        if self.reverse:
+            self.request.add_offset -= batch_size
+
+        if self.wait_time is None:
+            self.wait_time = 1 if self.limit > 3000 else 0
+
+        # Telegram has a hard limit of 100.
+        # We don't need to fetch 100 if the limit is less.
+        self.batch_size = min(max(batch_size, 1), min(100, self.limit))
+        self.add_offset = add_offset
+        self.max_id = max_id
+        self.min_id = min_id
+        self.last_id = 0 if self.reverse else float('inf')
+
+    async def _load_next_chunk(self):
+        result = []
+
+        self.request.limit = min(self.left, self.batch_size)
+        if self.reverse and self.request.limit != self.batch_size:
+            # Remember that we need -limit when going in reverse
+            self.request.add_offset = self.add_offset - self.request.limit
+
+        r = await self.client(self.request)
+        self.total = getattr(r, 'count', len(r.messages))
+
+        entities = {utils.get_peer_id(x): x
+                    for x in itertools.chain(r.users, r.chats)}
+
+        messages = reversed(r.messages) if self.reverse else r.messages
+        for message in messages:
+            if (isinstance(message, types.MessageEmpty)
+                    or self.from_id and message.from_id != self.from_id):
+                continue
+
+            # TODO We used to yield and return here (stopping the iterator)
+            # How should we go around that here?
+            if self.reverse:
+                if message.id <= self.last_id or message.id >= self.max_id:
+                    break
+            else:
+                if message.id >= self.last_id or message.id <= self.min_id:
+                    break
+
+            # There has been reports that on bad connections this method
+            # was returning duplicated IDs sometimes. Using ``last_id``
+            # is an attempt to avoid these duplicates, since the message
+            # IDs are returned in descending order (or asc if reverse).
+            self.last_id = message.id
+            message._finish_init(self.client, entities, self.entity)
+            result.append(message)
+
+        if len(r.messages) < self.request.limit:
+            return result
+
+        # Find the first message that's not empty (in some rare cases
+        # it can happen that the last message is :tl:`MessageEmpty`)
+        last_message = None
+        messages = r.messages if self.reverse else reversed(r.messages)
+        for m in messages:
+            if not isinstance(m, types.MessageEmpty):
+                last_message = m
+                break
+
+        # TODO If it's None, we used to break (ending the iterator)
+        # Similar case as the return above.
+        if last_message is not None:
+            # There are some cases where all the messages we get start
+            # being empty. This can happen on migrated mega-groups if
+            # the history was cleared, and we're using search. Telegram
+            # acts incredibly weird sometimes. Messages are returned but
+            # only "empty", not their contents. If this is the case we
+            # should just give up since there won't be any new Message.
+            self.request.offset_id = last_message.id
+            self.request.offset_date = last_message.date
+            if self.reverse:
+                # We want to skip the one we already have
+                self.request.offset_id += 1
+
+        return result
 
 
 class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
@@ -17,7 +156,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 
     # region Message retrieval
 
-    @async_generator
     async def iter_messages(
             self, entity, limit=None, *, offset_date=None, offset_id=0,
             max_id=0, min_id=0, add_offset=0, search=None, filter=None,
@@ -133,6 +271,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
             an higher limit, so you're free to set the ``batch_size`` that
             you think may be good.
         """
+        # TODO Handle global search
+        # TODO Handle search
+        # TODO Handle yield IDs
+        return _GetHistoryIter(
+            self,
+            limit=limit,
+            wait_time=wait_time,
+            entity=entity,
+            reverse=reverse,
+            offset_id=offset_id,
+            min_id=min_id,
+            max_id=max_id,
+            from_user=from_user,
+            batch_size=batch_size,
+            offset_date=offset_date,
+            add_offset=add_offset
+        )
         # Note that entity being ``None`` is intended to get messages by
         # ID under no specific chat, and also to request a global search.
         if entity:
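With this hunk, `iter_messages` builds a `_GetHistoryIter` and hands it back instead of yielding through `@async_generator`. Since the method is still an `async def` at this point, a caller would first await it to obtain the iterator and then drive it. A hedged sketch; the `'telethon'` entity and the fields printed are illustrative, not taken from the commit:

    async def dump_history(client):
        # Awaiting the (still async) method yields the _GetHistoryIter
        # instance; the actual requests only happen as the loop advances.
        it = await client.iter_messages('telethon', limit=10, reverse=True)
        async for message in it:
            print(message.id, message.date)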
@@ -802,7 +957,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 
     # region Private methods
 
-    @async_generator
     async def _iter_ids(self, entity, ids, total):
         """
         Special case for `iter_messages` when it should only fetch some IDs.
telethon/requestiter.py | 99 (new file)

@@ -0,0 +1,99 @@
+import abc
+import asyncio
+import time
+
+
+class RequestIter(abc.ABC):
+    """
+    Helper class to deal with requests that need offsets to iterate.
+
+    It has some facilities, such as automatically sleeping a desired
+    amount of time between requests if needed (but not more).
+
+    Can be used synchronously if the event loop is not running and
+    as an asynchronous iterator otherwise.
+
+    `limit` is the total amount of items that the iterator should return.
+    This is handled on this base class, and will be always ``>= 0``.
+
+    `left` will be reset every time the iterator is used and will indicate
+    the amount of items that should be emitted left, so that subclasses can
+    be more efficient and fetch only as many items as they need.
+
+    Iterators may be used with ``reversed``, and their `reverse` flag will
+    be set to ``True`` if that's the case. Note that if this flag is set,
+    `buffer` should be filled in reverse too.
+    """
+    def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
+        self.client = client
+        self.reverse = reverse
+        self.wait_time = wait_time
+        self.kwargs = kwargs
+        self.limit = max(float('inf') if limit is None else limit, 0)
+        self.left = None
+        self.buffer = None
+        self.index = None
+        self.total = None
+        self.last_load = None
+
+    async def _init(self, **kwargs):
+        """
+        Called when asynchronous initialization is necessary. All keyword
+        arguments passed to `__init__` will be forwarded here, and it's
+        preferable to use named arguments in the subclasses without defaults
+        to avoid forgetting or misspelling any of them.
+
+        This method may ``raise StopAsyncIteration`` if it cannot continue.
+        """
+
+    async def __anext__(self):
+        if self.buffer is ():
+            await self._init(**self.kwargs)
+
+        if self.index == len(self.buffer):
+            # asyncio will handle times <= 0 to sleep 0 seconds
+            if self.wait_time:
+                await asyncio.sleep(
+                    self.wait_time - (time.time() - self.last_load),
+                    loop=self.client.loop
+                )
+                self.last_load = time.time()
+
+            self.index = 0
+            self.buffer = await self._load_next_chunk()
+
+        if not self.buffer:
+            raise StopAsyncIteration
+
+        result = self.buffer[self.index]
+        self.left -= 1
+        self.index += 1
+        return result
+
+    def __aiter__(self):
+        self.buffer = ()
+        self.index = 0
+        self.last_load = 0
+        self.left = self.limit
+        return self
+
+    def __iter__(self):
+        if self.client.loop.is_running():
+            raise RuntimeError(
+                'You must use "async for" if the event loop '
+                'is running (i.e. you are inside an "async def")'
+            )
+
+        raise NotImplementedError('lol!')
+
+    @abc.abstractmethod
+    async def _load_next_chunk(self):
+        """
+        Called when the next chunk is necessary.
+        It should *always* return a `list`.
+        """
+        raise NotImplementedError
+
+    def __reversed__(self):
+        self.reverse = not self.reverse
+        return self  # __aiter__ will be called after, too
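A hedged sketch of how the ABC above is meant to be driven; `SomeIter` stands in for any concrete subclass (such as the `_GetHistoryIter` added in this commit) and is not part of the commit itself. Plain `for` loops are not usable yet: inside a running event loop `__iter__` raises `RuntimeError`, and otherwise it hits the `NotImplementedError('lol!')` stub.

    # Hypothetical driver for the ABC above; SomeIter is a placeholder for
    # any concrete subclass that implements _load_next_chunk (and _init).
    async def consume(client):
        it = SomeIter(client, limit=50, wait_time=1)

        # __aiter__ resets buffer/index/left; __anext__ refills the buffer
        # via _load_next_chunk, sleeping up to wait_time between chunk
        # requests and decrementing `left` so the subclass can size the
        # next request accordingly.
        async for item in it:
            ...

        # reversed() only flips the `reverse` flag; the subclass must honour
        # it (e.g. by filling the buffer backwards), and __aiter__ runs again.
        async for item in reversed(it):
            ...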