Support getting more than 100 messages by ID

This commit is contained in:
Lonami Exo 2019-08-26 12:16:16 +02:00
parent 61bc8f7fa3
commit d5faf5e8aa

View File

@ -234,33 +234,34 @@ class _MessagesIter(RequestIter):
class _IDsIter(RequestIter): class _IDsIter(RequestIter):
async def _init(self, entity, ids): async def _init(self, entity, ids):
# TODO We never actually split IDs in chunks, but maybe we should
if not utils.is_list_like(ids):
ids = [ids]
elif not ids:
raise StopAsyncIteration
elif self.reverse:
ids = list(reversed(ids))
else:
ids = ids
if entity:
entity = await self.client.get_input_entity(entity)
self.total = len(ids) self.total = len(ids)
self._ids = list(reversed(ids)) if self.reverse else ids
self._offset = 0
self._entity = (await self.client.get_input_entity(entity)) if entity else None
# 30s flood wait every 300 messages (3 requests of 100 each, 30 of 10, etc.)
if self.wait_time is None:
self.wait_time = 10 if self.limit > 300 else 0
async def _load_next_chunk(self):
ids = self._ids[self._offset:self._offset + _MAX_CHUNK_SIZE]
if not ids:
raise StopAsyncIteration
self._offset += _MAX_CHUNK_SIZE
from_id = None # By default, no need to validate from_id from_id = None # By default, no need to validate from_id
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)): if isinstance(self._entity, (types.InputChannel, types.InputPeerChannel)):
try: try:
r = await self.client( r = await self.client(
functions.channels.GetMessagesRequest(entity, ids)) functions.channels.GetMessagesRequest(self._entity, ids))
except errors.MessageIdsEmptyError: except errors.MessageIdsEmptyError:
# All IDs were invalid, use a dummy result # All IDs were invalid, use a dummy result
r = types.messages.MessagesNotModified(len(ids)) r = types.messages.MessagesNotModified(len(ids))
else: else:
r = await self.client(functions.messages.GetMessagesRequest(ids)) r = await self.client(functions.messages.GetMessagesRequest(ids))
if entity: if self._entity:
from_id = await self.client.get_peer_id(entity) from_id = await self.client.get_peer_id(self._entity)
if isinstance(r, types.messages.MessagesNotModified): if isinstance(r, types.messages.MessagesNotModified):
self.buffer.extend(None for _ in ids) self.buffer.extend(None for _ in ids)
@ -282,12 +283,9 @@ class _IDsIter(RequestIter):
from_id and message.chat_id != from_id): from_id and message.chat_id != from_id):
self.buffer.append(None) self.buffer.append(None)
else: else:
message._finish_init(self.client, entities, entity) message._finish_init(self.client, entities, self._entity)
self.buffer.append(message) self.buffer.append(message)
async def _load_next_chunk(self):
return True # no next chunk, all done in init
class MessageMethods: class MessageMethods:
@ -385,6 +383,9 @@ class MessageMethods:
the ``FloodWaitError`` as needed. If left to `None`, it will the ``FloodWaitError`` as needed. If left to `None`, it will
default to 1 second only if the limit is higher than 3000. default to 1 second only if the limit is higher than 3000.
If the ``ids`` parameter is used, this time will default
to 10 seconds only if the amount of IDs is higher than 300.
ids (`int`, `list`): ids (`int`, `list`):
A single integer ID (or several IDs) for the message that A single integer ID (or several IDs) for the message that
should be returned. This parameter takes precedence over should be returned. This parameter takes precedence over
@ -442,7 +443,17 @@ class MessageMethods:
print(message.photo) print(message.photo)
""" """
if ids is not None: if ids is not None:
return _IDsIter(self, reverse=reverse, limit=limit, entity=entity, ids=ids) if not utils.is_list_like(ids):
ids = [ids]
return _IDsIter(
client=self,
reverse=reverse,
wait_time=wait_time,
limit=len(ids),
entity=entity,
ids=ids
)
return _MessagesIter( return _MessagesIter(
client=self, client=self,