Add method to collect RequestIter into TotalList

This commit is contained in:
Lonami Exo 2019-02-27 10:15:32 +01:00
parent 6d6c1917bc
commit 5b8e6531fa
2 changed files with 23 additions and 11 deletions

View File

@ -446,24 +446,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
a single `Message <telethon.tl.custom.message.Message>` will be a single `Message <telethon.tl.custom.message.Message>` will be
returned for convenience instead of a list. returned for convenience instead of a list.
""" """
# TODO Make RequestIter have a .collect() or similar
total = [0]
kwargs['_total'] = total
if len(args) == 1 and 'limit' not in kwargs: if len(args) == 1 and 'limit' not in kwargs:
if 'min_id' in kwargs and 'max_id' in kwargs: if 'min_id' in kwargs and 'max_id' in kwargs:
kwargs['limit'] = None kwargs['limit'] = None
else: else:
kwargs['limit'] = 1 kwargs['limit'] = 1
msgs = helpers.TotalList() it = self.iter_messages(*args, **kwargs)
async for x in self.iter_messages(*args, **kwargs):
msgs.append(x)
msgs.total = total[0]
if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']):
# Check for empty list to handle InputMessageReplyTo
return msgs[0] if msgs else None
return msgs ids = kwargs.get('ids')
if ids and not utils.is_list_like(ids):
async for message in it:
return message
else:
# Iterator exhausted = empty, to handle InputMessageReplyTo
return None
return await it.collect()
# endregion # endregion

View File

@ -2,6 +2,8 @@ import abc
import asyncio import asyncio
import time import time
from . import helpers
# TODO There are two types of iterators for requests. # TODO There are two types of iterators for requests.
# One has a limit of items to retrieve, and the # One has a limit of items to retrieve, and the
@ -95,6 +97,17 @@ class RequestIter(abc.ABC):
return self.__aiter__() return self.__aiter__()
async def collect(self):
"""
Create a `self` iterator and collect it into a `TotalList`
(a normal list with a `.total` attribute).
"""
result = helpers.TotalList()
async for message in self:
result.append(message)
return result
@abc.abstractmethod @abc.abstractmethod
async def _load_next_chunk(self): async def _load_next_chunk(self):
""" """