class _GetHistoryIter(RequestIter):
    """
    Iterator over the message history of an entity, backed by
    :tl:`messages.GetHistoryRequest`.

    Emits ``Message`` objects in descending ID order, or ascending order
    when the iterator's `reverse` flag is set. Stops early when a message
    falls outside the ``min_id``/``max_id`` bounds.
    """

    async def _init(self, entity, offset_id, min_id, max_id,
                    from_user, batch_size, offset_date, add_offset):
        """
        Resolve the input entities and build the initial request.

        May ``raise StopAsyncIteration`` immediately when the ID bounds
        leave no room for any message, or when ``limit == 0`` (in which
        case only the total count is fetched).
        """
        self.entity = await self.client.get_input_entity(entity)

        # Telegram doesn't like min_id/max_id. If these IDs are low enough
        # (starting from last_id - 100), the request will return nothing.
        #
        # We can emulate their behaviour locally by setting offset = max_id
        # and simply stopping once we hit a message with ID <= min_id.
        if self.reverse:
            offset_id = max(offset_id, min_id)
            if offset_id and max_id:
                if max_id - offset_id <= 1:
                    raise StopAsyncIteration

            if not max_id:
                max_id = float('inf')
        else:
            offset_id = max(offset_id, max_id)
            if offset_id and min_id:
                if offset_id - min_id <= 1:
                    raise StopAsyncIteration

        if self.reverse:
            if offset_id:
                offset_id += 1
            else:
                offset_id = 1

        if from_user:
            from_user = await self.client.get_input_entity(from_user)
            if not isinstance(from_user, (
                    types.InputPeerUser, types.InputPeerSelf)):
                from_user = None  # Ignore from_user unless it's a user

        self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None

        self.request = functions.messages.GetHistoryRequest(
            # Use the *resolved* input entity; the raw `entity` argument may
            # be a username or similar that the request cannot serialize.
            peer=self.entity,
            limit=1,
            offset_date=offset_date,
            offset_id=offset_id,
            min_id=0,
            max_id=0,
            add_offset=add_offset,
            hash=0
        )

        if self.limit == 0:
            # No messages, but we still need to know the total message count
            result = await self.client(self.request)
            if isinstance(result, types.messages.MessagesNotModified):
                self.total = result.count
            else:
                self.total = getattr(result, 'count', len(result.messages))
            raise StopAsyncIteration

        # When going in reverse we need an offset of `-limit`, but we
        # also want to respect what the user passed, so add them together.
        if self.reverse:
            self.request.add_offset -= batch_size

        if self.wait_time is None:
            self.wait_time = 1 if self.limit > 3000 else 0

        # Telegram has a hard limit of 100.
        # We don't need to fetch 100 if the limit is less.
        self.batch_size = min(max(batch_size, 1), min(100, self.limit))
        self.add_offset = add_offset
        self.max_id = max_id
        self.min_id = min_id
        self.last_id = 0 if self.reverse else float('inf')

    async def _load_next_chunk(self):
        """
        Fetch the next batch of messages, honouring `left`, the ID bounds
        and the `reverse` flag, and advance the request offsets for the
        following call. Always returns a `list`.
        """
        result = []

        self.request.limit = min(self.left, self.batch_size)
        if self.reverse and self.request.limit != self.batch_size:
            # Remember that we need -limit when going in reverse
            self.request.add_offset = self.add_offset - self.request.limit

        r = await self.client(self.request)
        self.total = getattr(r, 'count', len(r.messages))

        entities = {utils.get_peer_id(x): x
                    for x in itertools.chain(r.users, r.chats)}

        messages = reversed(r.messages) if self.reverse else r.messages
        for message in messages:
            if (isinstance(message, types.MessageEmpty)
                    or self.from_id and message.from_id != self.from_id):
                continue

            # TODO We used to yield and return here (stopping the iterator)
            # How should we go around that here?
            if self.reverse:
                if message.id <= self.last_id or message.id >= self.max_id:
                    break
            else:
                if message.id >= self.last_id or message.id <= self.min_id:
                    break

            # There have been reports that on bad connections this method
            # was returning duplicated IDs sometimes. Using ``last_id``
            # is an attempt to avoid these duplicates, since the message
            # IDs are returned in descending order (or asc if reverse).
            self.last_id = message.id
            message._finish_init(self.client, entities, self.entity)
            result.append(message)

        if len(r.messages) < self.request.limit:
            return result

        # Find the first message that's not empty (in some rare cases
        # it can happen that the last message is :tl:`MessageEmpty`)
        last_message = None
        messages = r.messages if self.reverse else reversed(r.messages)
        for m in messages:
            if not isinstance(m, types.MessageEmpty):
                last_message = m
                break

        # TODO If it's None, we used to break (ending the iterator)
        # Similar case as the return above.
        if last_message is not None:
            # There are some cases where all the messages we get start
            # being empty. This can happen on migrated mega-groups if
            # the history was cleared, and we're using search. Telegram
            # acts incredibly weird sometimes. Messages are returned but
            # only "empty", not their contents. If this is the case we
            # should just give up since there won't be any new Message.
            self.request.offset_id = last_message.id
            self.request.offset_date = last_message.date
            if self.reverse:
                # We want to skip the one we already have
                self.request.offset_id += 1

        return result
import abc
import asyncio
import time


class RequestIter(abc.ABC):
    """
    Helper class to deal with requests that need offsets to iterate.

    It has some facilities, such as automatically sleeping a desired
    amount of time between requests if needed (but not more).

    Can be used synchronously if the event loop is not running and
    as an asynchronous iterator otherwise.

    `limit` is the total amount of items that the iterator should return.
    This is handled on this base class, and will be always ``>= 0``.

    `left` will be reset every time the iterator is used and will indicate
    the amount of items that should be emitted left, so that subclasses can
    be more efficient and fetch only as many items as they need.

    Iterators may be used with ``reversed``, and their `reverse` flag will
    be set to ``True`` if that's the case. Note that if this flag is set,
    `buffer` should be filled in reverse too.
    """
    def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
        self.client = client
        self.reverse = reverse
        self.wait_time = wait_time
        # Extra keyword arguments are forwarded to `_init` on first use.
        self.kwargs = kwargs
        # ``None`` means "no limit"; negative limits are clamped to 0.
        self.limit = max(float('inf') if limit is None else limit, 0)
        self.left = None
        self.buffer = None
        self.index = None
        self.total = None
        self.last_load = None

    async def _init(self, **kwargs):
        """
        Called when asynchronous initialization is necessary. All keyword
        arguments passed to `__init__` will be forwarded here, and it's
        preferable to use named arguments in the subclasses without defaults
        to avoid forgetting or misspelling any of them.

        This method may ``raise StopAsyncIteration`` if it cannot continue.
        """

    async def __anext__(self):
        if self.buffer is None:
            # First step after `__aiter__`; run the deferred async setup.
            # A `None` sentinel (not `()`) avoids an identity comparison
            # against a literal, which CPython warns about.
            self.buffer = []
            await self._init(**self.kwargs)

        # Enforce the user-provided limit; `_init` runs first so that
        # subclasses can still fetch e.g. the total count on limit == 0.
        if self.left <= 0:
            raise StopAsyncIteration

        if self.index == len(self.buffer):
            # asyncio will handle times <= 0 to sleep 0 seconds
            if self.wait_time:
                await asyncio.sleep(
                    self.wait_time - (time.time() - self.last_load))
            self.last_load = time.time()

            self.index = 0
            self.buffer = await self._load_next_chunk()

        if not self.buffer:
            raise StopAsyncIteration

        result = self.buffer[self.index]
        self.left -= 1
        self.index += 1
        return result

    def __aiter__(self):
        # Reset all iteration state so the iterator is reusable.
        self.buffer = None
        self.index = 0
        self.last_load = 0
        self.left = self.limit
        return self

    def __iter__(self):
        if self.client.loop.is_running():
            raise RuntimeError(
                'You must use "async for" if the event loop '
                'is running (i.e. you are inside an "async def")'
            )

        raise NotImplementedError('Synchronous iteration is not implemented yet')

    @abc.abstractmethod
    async def _load_next_chunk(self):
        """
        Called when the next chunk is necessary.
        It should *always* return a `list`.
        """
        raise NotImplementedError

    def __reversed__(self):
        self.reverse = not self.reverse
        return self  # __aiter__ will be called after, too