mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-22 17:36:34 +03:00
Create a new RequestIter ABC to deal with iter methods
This should make it easier to maintain these methods, increase reusability, and get rid of the async_generator dependency. In the future, people could use this to more easily deal with raw API themselves.
This commit is contained in:
parent
1e4a12d2f7
commit
36eb1b1009
|
@ -2,13 +2,152 @@ import asyncio
|
||||||
import itertools
|
import itertools
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from async_generator import async_generator, yield_
|
|
||||||
|
|
||||||
from .messageparse import MessageParseMethods
|
from .messageparse import MessageParseMethods
|
||||||
from .uploads import UploadMethods
|
from .uploads import UploadMethods
|
||||||
from .buttons import ButtonMethods
|
from .buttons import ButtonMethods
|
||||||
from .. import helpers, utils, errors
|
from .. import helpers, utils, errors
|
||||||
from ..tl import types, functions
|
from ..tl import types, functions
|
||||||
|
from ..requestiter import RequestIter
|
||||||
|
|
||||||
|
|
||||||
|
class _GetHistoryIter(RequestIter):
    """
    Iterator over the message history of a chat, backed by
    :tl:`GetHistoryRequest` and fetched in chunks.

    The keyword arguments given to the constructor are forwarded to
    `_init`, which builds the request; `_load_next_chunk` then re-sends
    that request, moving its offsets forward after every chunk.
    """

    async def _init(self, entity, offset_id, min_id, max_id, from_user, batch_size, offset_date, add_offset):
        """
        Resolve the entities, normalize the ID bounds and build the request.

        Raises ``StopAsyncIteration`` when the ID range is empty, or when
        `limit` is 0 (after making a single request to fill `total`).
        """
        self.entity = await self.client.get_input_entity(entity)

        # Telegram doesn't like min_id/max_id. If these IDs are low enough
        # (starting from last_id - 100), the request will return nothing.
        #
        # We can emulate their behaviour locally by setting offset = max_id
        # and simply stopping once we hit a message with ID <= min_id.
        if self.reverse:
            offset_id = max(offset_id, min_id)
            if offset_id and max_id:
                if max_id - offset_id <= 1:
                    raise StopAsyncIteration

            if not max_id:
                max_id = float('inf')
        else:
            offset_id = max(offset_id, max_id)
            if offset_id and min_id:
                if offset_id - min_id <= 1:
                    raise StopAsyncIteration

        if self.reverse:
            if offset_id:
                offset_id += 1
            else:
                offset_id = 1

        if from_user:
            from_user = await self.client.get_input_entity(from_user)
            if not isinstance(from_user, (
                    types.InputPeerUser, types.InputPeerSelf)):
                from_user = None  # Ignore from_user unless it's a user

        self.from_id = (await self.client.get_peer_id(from_user)) if from_user else None

        self.request = functions.messages.GetHistoryRequest(
            # Use the entity resolved above rather than the raw argument,
            # since this same request object is sent once per chunk.
            peer=self.entity,
            limit=1,
            offset_date=offset_date,
            offset_id=offset_id,
            min_id=0,
            max_id=0,
            add_offset=add_offset,
            hash=0
        )

        if self.limit == 0:
            # No messages, but we still need to know the total message count
            result = await self.client(self.request)
            if isinstance(result, types.messages.MessagesNotModified):
                self.total = result.count
            else:
                self.total = getattr(result, 'count', len(result.messages))
            raise StopAsyncIteration

        if self.wait_time is None:
            self.wait_time = 1 if self.limit > 3000 else 0

        # Telegram has a hard limit of 100.
        # We don't need to fetch 100 if the limit is less.
        self.batch_size = min(max(batch_size, 1), min(100, self.limit))

        # When going in reverse we need an offset of `-limit`, but we
        # also want to respect what the user passed, so add them together.
        # The clamped size must be used here: it is the value that
        # `request.limit` will actually take.
        if self.reverse:
            self.request.add_offset -= self.batch_size

        self.add_offset = add_offset
        self.max_id = max_id
        self.min_id = min_id
        self.last_id = 0 if self.reverse else float('inf')

    async def _load_next_chunk(self):
        """
        Fetch the next chunk of history and return it as a `list`.

        Empty messages, and messages not sent by `from_id` when it is
        set, are skipped. The request offsets are advanced so that the
        next call continues after the last non-empty message received.
        """
        result = []

        self.request.limit = min(self.left, self.batch_size)
        if self.reverse:
            # Remember that we need -limit when going in reverse. This is
            # recomputed on every chunk so that it always matches the
            # limit that is about to be sent.
            self.request.add_offset = self.add_offset - self.request.limit

        r = await self.client(self.request)
        self.total = getattr(r, 'count', len(r.messages))

        entities = {utils.get_peer_id(x): x
                    for x in itertools.chain(r.users, r.chats)}

        messages = reversed(r.messages) if self.reverse else r.messages
        for message in messages:
            if (isinstance(message, types.MessageEmpty)
                    or self.from_id and message.from_id != self.from_id):
                continue

            # TODO We used to yield and return here (stopping the iterator)
            # How should we go around that here?
            if self.reverse:
                if message.id <= self.last_id or message.id >= self.max_id:
                    break
            else:
                if message.id >= self.last_id or message.id <= self.min_id:
                    break

            # There has been reports that on bad connections this method
            # was returning duplicated IDs sometimes. Using ``last_id``
            # is an attempt to avoid these duplicates, since the message
            # IDs are returned in descending order (or asc if reverse).
            self.last_id = message.id
            message._finish_init(self.client, entities, self.entity)
            result.append(message)

        if len(r.messages) < self.request.limit:
            # Fewer messages than requested, so no offsets need updating.
            return result

        # Find the first message that's not empty (in some rare cases
        # it can happen that the last message is :tl:`MessageEmpty`)
        last_message = None
        messages = r.messages if self.reverse else reversed(r.messages)
        for m in messages:
            if not isinstance(m, types.MessageEmpty):
                last_message = m
                break

        # TODO If it's None, we used to break (ending the iterator)
        # Similar case as the return above.
        if last_message is not None:
            # There are some cases where all the messages we get start
            # being empty. This can happen on migrated mega-groups if
            # the history was cleared, and we're using search. Telegram
            # acts incredibly weird sometimes. Messages are returned but
            # only "empty", not their contents. If this is the case we
            # should just give up since there won't be any new Message.
            self.request.offset_id = last_message.id
            self.request.offset_date = last_message.date
            if self.reverse:
                # We want to skip the one we already have
                self.request.offset_id += 1

        return result
|
||||||
|
|
||||||
|
|
||||||
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
|
@ -17,7 +156,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
|
|
||||||
# region Message retrieval
|
# region Message retrieval
|
||||||
|
|
||||||
@async_generator
|
|
||||||
async def iter_messages(
|
async def iter_messages(
|
||||||
self, entity, limit=None, *, offset_date=None, offset_id=0,
|
self, entity, limit=None, *, offset_date=None, offset_id=0,
|
||||||
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
|
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
|
||||||
|
@ -133,6 +271,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
an higher limit, so you're free to set the ``batch_size`` that
|
an higher limit, so you're free to set the ``batch_size`` that
|
||||||
you think may be good.
|
you think may be good.
|
||||||
"""
|
"""
|
||||||
|
# TODO Handle global search
|
||||||
|
# TODO Handle search
|
||||||
|
# TODO Handle yield IDs
|
||||||
|
return _GetHistoryIter(
|
||||||
|
self,
|
||||||
|
limit=limit,
|
||||||
|
wait_time=wait_time,
|
||||||
|
entity=entity,
|
||||||
|
reverse=reverse,
|
||||||
|
offset_id=offset_id,
|
||||||
|
min_id=min_id,
|
||||||
|
max_id=max_id,
|
||||||
|
from_user=from_user,
|
||||||
|
batch_size=batch_size,
|
||||||
|
offset_date=offset_date,
|
||||||
|
add_offset=add_offset
|
||||||
|
)
|
||||||
# Note that entity being ``None`` is intended to get messages by
|
# Note that entity being ``None`` is intended to get messages by
|
||||||
# ID under no specific chat, and also to request a global search.
|
# ID under no specific chat, and also to request a global search.
|
||||||
if entity:
|
if entity:
|
||||||
|
@ -802,7 +957,6 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
|
|
||||||
# region Private methods
|
# region Private methods
|
||||||
|
|
||||||
@async_generator
|
|
||||||
async def _iter_ids(self, entity, ids, total):
|
async def _iter_ids(self, entity, ids, total):
|
||||||
"""
|
"""
|
||||||
Special case for `iter_messages` when it should only fetch some IDs.
|
Special case for `iter_messages` when it should only fetch some IDs.
|
||||||
|
|
99
telethon/requestiter.py
Normal file
99
telethon/requestiter.py
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
import abc
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class RequestIter(abc.ABC):
    """
    Helper class to deal with requests that need offsets to iterate.

    It has some facilities, such as automatically sleeping a desired
    amount of time between requests if needed (but not more).

    Can be used synchronously if the event loop is not running and
    as an asynchronous iterator otherwise.

    `limit` is the total amount of items that the iterator should return.
    This is handled on this base class, and will be always ``>= 0``.

    `left` will be reset every time the iterator is used and will indicate
    the amount of items that should be emitted left, so that subclasses can
    be more efficient and fetch only as many items as they need.

    Iterators may be used with ``reversed``, and their `reverse` flag will
    be set to ``True`` if that's the case. Note that if this flag is set,
    `buffer` should be filled in reverse too.
    """
    def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
        self.client = client
        self.reverse = reverse
        self.wait_time = wait_time
        self.kwargs = kwargs
        # ``None`` means "no limit"; negative limits are clamped to 0.
        self.limit = max(float('inf') if limit is None else limit, 0)
        self.left = None
        self.buffer = None
        self.index = None
        self.total = None
        self.last_load = None
        # Set by `__aiter__`; tells `__anext__` that `_init` has to run.
        self._needs_init = False

    async def _init(self, **kwargs):
        """
        Called when asynchronous initialization is necessary. All keyword
        arguments passed to `__init__` will be forwarded here, and it's
        preferable to use named arguments in the subclasses without defaults
        to avoid forgetting or misspelling any of them.

        This method may ``raise StopAsyncIteration`` if it cannot continue.
        """

    async def __anext__(self):
        """
        Return the next item, loading a new chunk when the buffer is
        exhausted.

        Raises ``StopAsyncIteration`` once `limit` items have been
        returned or a loaded chunk comes back empty.
        """
        if self._needs_init:
            # An explicit flag is used instead of comparing `buffer`
            # against the empty tuple by identity. It is only cleared
            # once `_init` succeeds, so it runs again if it raised.
            await self._init(**self.kwargs)
            self._needs_init = False

        # Enforce `limit` here so that no further request is made once
        # enough items were returned. ``<= 0`` because subclasses may
        # modify `left` themselves.
        if self.left <= 0:
            raise StopAsyncIteration

        if self.index == len(self.buffer):
            # asyncio will handle times <= 0 to sleep 0 seconds
            if self.wait_time:
                # No ``loop=`` argument: it was removed from
                # `asyncio.sleep` in Python 3.10.
                await asyncio.sleep(
                    self.wait_time - (time.time() - self.last_load)
                )
                self.last_load = time.time()

            self.index = 0
            self.buffer = await self._load_next_chunk()

        if not self.buffer:
            raise StopAsyncIteration

        result = self.buffer[self.index]
        self.left -= 1
        self.index += 1
        return result

    def __aiter__(self):
        """
        Reset the iteration state and return ``self``.
        """
        self.buffer = ()
        self.index = 0
        self.last_load = 0
        self.left = self.limit
        self._needs_init = True
        return self

    def __next__(self):
        """
        Synchronous counterpart of `__anext__`, run on ``client.loop``.
        """
        try:
            return self.client.loop.run_until_complete(self.__anext__())
        except StopAsyncIteration:
            raise StopIteration from None

    def __iter__(self):
        """
        Start a synchronous iteration.

        Raises ``RuntimeError`` if the client's event loop is running.
        """
        if self.client.loop.is_running():
            raise RuntimeError(
                'You must use "async for" if the event loop '
                'is running (i.e. you are inside an "async def")'
            )

        return self.__aiter__()

    @abc.abstractmethod
    async def _load_next_chunk(self):
        """
        Called when the next chunk is necessary.
        It should *always* return a `list`.
        """
        raise NotImplementedError

    def __reversed__(self):
        """
        Toggle the `reverse` flag and return ``self``.
        """
        self.reverse = not self.reverse
        return self  # __aiter__ will be called after, too
|
Loading…
Reference in New Issue
Block a user