diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 9750003d..87d49be6 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -446,24 +446,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): a single `Message ` will be returned for convenience instead of a list. """ - # TODO Make RequestIter have a .collect() or similar - total = [0] - kwargs['_total'] = total if len(args) == 1 and 'limit' not in kwargs: if 'min_id' in kwargs and 'max_id' in kwargs: kwargs['limit'] = None else: kwargs['limit'] = 1 - msgs = helpers.TotalList() - async for x in self.iter_messages(*args, **kwargs): - msgs.append(x) - msgs.total = total[0] - if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']): - # Check for empty list to handle InputMessageReplyTo - return msgs[0] if msgs else None + it = self.iter_messages(*args, **kwargs) - return msgs + ids = kwargs.get('ids') + if ids and not utils.is_list_like(ids): + async for message in it: + return message + else: + # Iterator exhausted = empty, to handle InputMessageReplyTo + return None + + return await it.collect() # endregion diff --git a/telethon/requestiter.py b/telethon/requestiter.py index 98a1cfb6..af632389 100644 --- a/telethon/requestiter.py +++ b/telethon/requestiter.py @@ -2,6 +2,8 @@ import abc import asyncio import time +from . import helpers + # TODO There are two types of iterators for requests. # One has a limit of items to retrieve, and the @@ -95,6 +97,17 @@ class RequestIter(abc.ABC): return self.__aiter__() + async def collect(self): + """ + Create a `self` iterator and collect it into a `TotalList` + (a normal list with a `.total` attribute). + """ + result = helpers.TotalList() + async for message in self: + result.append(message) + + return result + @abc.abstractmethod async def _load_next_chunk(self): """