Implement iterator over message by IDs

This commit is contained in:
Lonami Exo 2019-02-27 10:04:12 +01:00
parent 60606b9994
commit 6d6c1917bc
2 changed files with 48 additions and 51 deletions

View File

@ -228,7 +228,8 @@ class _MessagesIter(RequestIter):
class _IDsIter(RequestIter): class _IDsIter(RequestIter):
async def _init(self, entity, from_user, ids): async def _init(self, entity, ids):
# TODO We never actually split IDs in chunks, but maybe we should
if not utils.is_list_like(ids): if not utils.is_list_like(ids):
self.ids = [ids] self.ids = [ids]
elif not ids: elif not ids:
@ -238,10 +239,52 @@ class _IDsIter(RequestIter):
else: else:
self.ids = ids self.ids = ids
raise NotImplementedError if entity:
entity = await self.client.get_input_entity(entity)
self.total = len(ids)
from_id = None # By default, no need to validate from_id
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
try:
r = await self.client(
functions.channels.GetMessagesRequest(entity, ids))
except errors.MessageIdsEmptyError:
# All IDs were invalid, use a dummy result
r = types.messages.MessagesNotModified(len(ids))
else:
r = await self.client(functions.messages.GetMessagesRequest(ids))
if entity:
from_id = await self.client.get_peer_id(entity)
if isinstance(r, types.messages.MessagesNotModified):
self.buffer = [None] * len(ids)
return
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
# Telegram seems to return the messages in the order in which
# we asked them for, so we don't need to check it ourselves,
# unless some messages were invalid in which case Telegram
# may decide to not send them at all.
#
# The passed message IDs may not belong to the desired entity
# since the user can enter arbitrary numbers which can belong to
# arbitrary chats. Validate these unless ``from_id is None``.
result = []
for message in r.messages:
if isinstance(message, types.MessageEmpty) or (
from_id and message.chat_id != from_id):
result.append(None)
else:
message._finish_init(self.client, entities, entity)
result.append(message)
self.buffer = result
async def _load_next_chunk(self): async def _load_next_chunk(self):
raise NotImplementedError return [] # no next chunk, all done in init
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
@ -850,51 +893,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
# endregion # endregion
# endregion # endregion
# region Private methods
async def _iter_ids(self, entity, ids, total):
"""
Special case for `iter_messages` when it should only fetch some IDs.
"""
if total:
total[0] = len(ids)
from_id = None # By default, no need to validate from_id
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
try:
r = await self(
functions.channels.GetMessagesRequest(entity, ids))
except errors.MessageIdsEmptyError:
# All IDs were invalid, use a dummy result
r = types.messages.MessagesNotModified(len(ids))
else:
r = await self(functions.messages.GetMessagesRequest(ids))
if entity:
from_id = utils.get_peer_id(entity)
if isinstance(r, types.messages.MessagesNotModified):
for _ in ids:
await yield_(None)
return
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
# Telegram seems to return the messages in the order in which
# we asked them for, so we don't need to check it ourselves,
# unless some messages were invalid in which case Telegram
# may decide to not send them at all.
#
# The passed message IDs may not belong to the desired entity
# since the user can enter arbitrary numbers which can belong to
# arbitrary chats. Validate these unless ``from_id is None``.
for message in r.messages:
if isinstance(message, types.MessageEmpty) or (
from_id and message.chat_id != from_id):
await yield_(None)
else:
message._finish_init(self, entities, entity)
await yield_(message)
# endregion

View File

@ -48,6 +48,8 @@ class RequestIter(abc.ABC):
to avoid forgetting or misspelling any of them. to avoid forgetting or misspelling any of them.
This method may ``raise StopAsyncIteration`` if it cannot continue. This method may ``raise StopAsyncIteration`` if it cannot continue.
This method may actually fill the initial buffer if it needs to.
""" """
async def __anext__(self): async def __anext__(self):