Create a new RequestIter ABC to deal with iter methods

This should make it easier to maintain these methods, increase
reusability, and get rid of the async_generator dependency.

In the future, people could use this to more easily deal with
raw API themselves.
This commit is contained in:
Lonami Exo 2019-02-26 20:26:40 +01:00
parent 1e4a12d2f7
commit 36eb1b1009
2 changed files with 257 additions and 4 deletions

View File

@ -2,13 +2,152 @@ import asyncio
import itertools import itertools
import time import time
from async_generator import async_generator, yield_
from .messageparse import MessageParseMethods from .messageparse import MessageParseMethods
from .uploads import UploadMethods from .uploads import UploadMethods
from .buttons import ButtonMethods from .buttons import ButtonMethods
from .. import helpers, utils, errors from .. import helpers, utils, errors
from ..tl import types, functions from ..tl import types, functions
from ..requestiter import RequestIter
class _GetHistoryIter(RequestIter):
async def _init(self, entity, offset_id, min_id, max_id, from_user, batch_size, offset_date, add_offset):
self.entity = await self.client.get_input_entity(entity)
# Telegram doesn't like min_id/max_id. If these IDs are low enough
# (starting from last_id - 100), the request will return nothing.
#
# We can emulate their behaviour locally by setting offset = max_id
# and simply stopping once we hit a message with ID <= min_id.
if self.reverse:
offset_id = max(offset_id, min_id)
if offset_id and max_id:
if max_id - offset_id <= 1:
raise StopAsyncIteration
if not max_id:
max_id = float('inf')
else:
offset_id = max(offset_id, max_id)
if offset_id and min_id:
if offset_id - min_id <= 1:
raise StopAsyncIteration
if self.reverse:
if offset_id:
offset_id += 1
else:
offset_id = 1
if from_user:
from_user = await self.client.get_input_entity(from_user)
if not isinstance(from_user, (
types.InputPeerUser, types.InputPeerSelf)):
from_user = None # Ignore from_user unless it's a user
self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None
self.request = functions.messages.GetHistoryRequest(
peer=entity,
limit=1,
offset_date=offset_date,
offset_id=offset_id,
min_id=0,
max_id=0,
add_offset=add_offset,
hash=0
)
if self.limit == 0:
# No messages, but we still need to know the total message count
result = await self.client(self.request)
if isinstance(result, types.messages.MessagesNotModified):
self.total = result.count
else:
self.total = getattr(result, 'count', len(result.messages))
raise StopAsyncIteration
# When going in reverse we need an offset of `-limit`, but we
# also want to respect what the user passed, so add them together.
if self.reverse:
self.request.add_offset -= batch_size
if self.wait_time is None:
self.wait_time = 1 if self.limit > 3000 else 0
# Telegram has a hard limit of 100.
# We don't need to fetch 100 if the limit is less.
self.batch_size = min(max(batch_size, 1), min(100, self.limit))
self.add_offset = add_offset
self.max_id = max_id
self.min_id = min_id
self.last_id = 0 if self.reverse else float('inf')
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, self.batch_size)
if self.reverse and self.request.limit != self.batch_size:
# Remember that we need -limit when going in reverse
self.request.add_offset = self.add_offset - self.request.limit
r = await self.client(self.request)
self.total = getattr(r, 'count', len(r.messages))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = reversed(r.messages) if self.reverse else r.messages
for message in messages:
if (isinstance(message, types.MessageEmpty)
or self.from_id and message.from_id != self.from_id):
continue
# TODO We used to yield and return here (stopping the iterator)
# How should we go around that here?
if self.reverse:
if message.id <= self.last_id or message.id >= self.max_id:
break
else:
if message.id >= self.last_id or message.id <= self.min_id:
break
# There has been reports that on bad connections this method
# was returning duplicated IDs sometimes. Using ``last_id``
# is an attempt to avoid these duplicates, since the message
# IDs are returned in descending order (or asc if reverse).
self.last_id = message.id
message._finish_init(self.client, entities, self.entity)
result.append(message)
if len(r.messages) < self.request.limit:
return result
# Find the first message that's not empty (in some rare cases
# it can happen that the last message is :tl:`MessageEmpty`)
last_message = None
messages = r.messages if self.reverse else reversed(r.messages)
for m in messages:
if not isinstance(m, types.MessageEmpty):
last_message = m
break
# TODO If it's None, we used to break (ending the iterator)
# Similar case as the return above.
if last_message is not None:
# There are some cases where all the messages we get start
# being empty. This can happen on migrated mega-groups if
# the history was cleared, and we're using search. Telegram
# acts incredibly weird sometimes. Messages are returned but
# only "empty", not their contents. If this is the case we
# should just give up since there won't be any new Message.
self.request.offset_id = last_message.id
self.request.offset_date = last_message.date
if self.reverse:
# We want to skip the one we already have
self.request.offset_id += 1
return result
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
@ -17,7 +156,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
# region Message retrieval # region Message retrieval
@async_generator
async def iter_messages( async def iter_messages(
self, entity, limit=None, *, offset_date=None, offset_id=0, self, entity, limit=None, *, offset_date=None, offset_id=0,
max_id=0, min_id=0, add_offset=0, search=None, filter=None, max_id=0, min_id=0, add_offset=0, search=None, filter=None,
@ -133,6 +271,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
an higher limit, so you're free to set the ``batch_size`` that an higher limit, so you're free to set the ``batch_size`` that
you think may be good. you think may be good.
""" """
# TODO Handle global search
# TODO Handle search
# TODO Handle yield IDs
return _GetHistoryIter(
self,
limit=limit,
wait_time=wait_time,
entity=entity,
reverse=reverse,
offset_id=offset_id,
min_id=min_id,
max_id=max_id,
from_user=from_user,
batch_size=batch_size,
offset_date=offset_date,
add_offset=add_offset
)
# Note that entity being ``None`` is intended to get messages by # Note that entity being ``None`` is intended to get messages by
# ID under no specific chat, and also to request a global search. # ID under no specific chat, and also to request a global search.
if entity: if entity:
@ -802,7 +957,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
# region Private methods # region Private methods
@async_generator
async def _iter_ids(self, entity, ids, total): async def _iter_ids(self, entity, ids, total):
""" """
Special case for `iter_messages` when it should only fetch some IDs. Special case for `iter_messages` when it should only fetch some IDs.

99
telethon/requestiter.py Normal file
View File

@ -0,0 +1,99 @@
import abc
import asyncio
import time
class RequestIter(abc.ABC):
"""
Helper class to deal with requests that need offsets to iterate.
It has some facilities, such as automatically sleeping a desired
amount of time between requests if needed (but not more).
Can be used synchronously if the event loop is not running and
as an asynchronous iterator otherwise.
`limit` is the total amount of items that the iterator should return.
This is handled on this base class, and will be always ``>= 0``.
`left` will be reset every time the iterator is used and will indicate
the amount of items that should be emitted left, so that subclasses can
be more efficient and fetch only as many items as they need.
Iterators may be used with ``reversed``, and their `reverse` flag will
be set to ``True`` if that's the case. Note that if this flag is set,
`buffer` should be filled in reverse too.
"""
def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
self.client = client
self.reverse = reverse
self.wait_time = wait_time
self.kwargs = kwargs
self.limit = max(float('inf') if limit is None else limit, 0)
self.left = None
self.buffer = None
self.index = None
self.total = None
self.last_load = None
async def _init(self, **kwargs):
"""
Called when asynchronous initialization is necessary. All keyword
arguments passed to `__init__` will be forwarded here, and it's
preferable to use named arguments in the subclasses without defaults
to avoid forgetting or misspelling any of them.
This method may ``raise StopAsyncIteration`` if it cannot continue.
"""
async def __anext__(self):
if self.buffer is ():
await self._init(**self.kwargs)
if self.index == len(self.buffer):
# asyncio will handle times <= 0 to sleep 0 seconds
if self.wait_time:
await asyncio.sleep(
self.wait_time - (time.time() - self.last_load),
loop=self.client.loop
)
self.last_load = time.time()
self.index = 0
self.buffer = await self._load_next_chunk()
if not self.buffer:
raise StopAsyncIteration
result = self.buffer[self.index]
self.left -= 1
self.index += 1
return result
def __aiter__(self):
self.buffer = ()
self.index = 0
self.last_load = 0
self.left = self.limit
return self
def __iter__(self):
if self.client.loop.is_running():
raise RuntimeError(
'You must use "async for" if the event loop '
'is running (i.e. you are inside an "async def")'
)
raise NotImplementedError('lol!')
@abc.abstractmethod
async def _load_next_chunk(self):
"""
Called when the next chunk is necessary.
It should *always* return a `list`.
"""
raise NotImplementedError
def __reversed__(self):
self.reverse = not self.reverse
return self # __aiter__ will be called after, too