Simplify filling RequestIter's buffer

Lonami Exo 2019-02-27 11:24:47 +01:00
parent 202ce1f494
commit c73b8eda26
4 changed files with 39 additions and 53 deletions
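In short: before this commit, every `_load_next_chunk` implementation built a local `result` list, set `self.left` itself when it detected the end, and returned the list for the base class to store (`self.buffer = await self._load_next_chunk()`). After it, implementations append to `self.buffer` directly and return `True` once the final chunk has been loaded, at which point the base class derives `self.left` from `len(self.buffer)`. A rough sketch of the two method shapes; `_fetch_page` is a hypothetical coroutine standing in for whatever request a concrete iterator actually sends:

# Sketch only. _fetch_page is a made-up helper, not part of Telethon.

class _OldContractSketch:
    async def _load_next_chunk(self):
        result = []
        for item in await self._fetch_page():
            result.append(item)
        # on the last chunk, implementations also set self.left = len(result)
        return result                     # the base class did: self.buffer = result

class _NewContractSketch:
    async def _load_next_chunk(self):
        page = await self._fetch_page()
        self.buffer.extend(page)          # fill the shared buffer in place
        return not page                   # True tells the base class this was the last chunk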


@@ -68,7 +68,6 @@ class _ParticipantsIter(RequestIter):
self.total = len(full.full_chat.participants.participants)
result = []
users = {user.id: user for user in full.users}
for participant in full.full_chat.participants.participants:
user = users[participant.user_id]
@@ -77,26 +76,22 @@ class _ParticipantsIter(RequestIter):
user = users[participant.user_id]
user.participant = participant
result.append(user)
self.buffer.append(user)
self.left = len(result)
self.buffer = result
return True
else:
result = []
self.total = 1
if self.limit != 0:
user = await self.client.get_entity(entity)
if self.filter_entity(user):
user.participant = None
result.append(user)
self.buffer.append(user)
self.left = len(result)
self.buffer = result
return True
async def _load_next_chunk(self):
result = []
if not self.requests:
return result
return True
# Only care about the limit for the first request
# (small number of people, won't be aggressive).
@@ -106,7 +101,7 @@ class _ParticipantsIter(RequestIter):
# precise with being out of the offset/limit here.
self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
if self.requests[0].offset > self.limit:
return result
return True
results = await self.client(self.requests)
for i in reversed(range(len(self.requests))):
@@ -125,9 +120,7 @@ class _ParticipantsIter(RequestIter):
self.seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
result.append(user)
return result
self.buffer.append(user)
class _AdminLogIter(RequestIter):
@@ -163,7 +156,6 @@ class _AdminLogIter(RequestIter):
)
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
entities = {utils.get_peer_id(x): x
@@ -184,12 +176,10 @@ class _AdminLogIter(RequestIter):
ev.action.message._finish_init(
self.client, entities, self.entity)
result.append(custom.AdminLogEvent(ev, entities))
self.buffer.append(custom.AdminLogEvent(ev, entities))
if len(r.events) < self.request.limit:
self.left = len(result)
return result
return True
class ChatMethods(UserMethods):
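The `_AdminLogIter` hunk above also shows how end-of-data detection changes: instead of setting `self.left = len(result)` and returning the list, the method now just returns `True` when the server hands back fewer events than were requested. Roughly, with the entity mapping and filtering elided (`self.request`, `self.client` and `self.left` are the iterator attributes already visible in the diff, while the bare `append` stands in for building `custom.AdminLogEvent`):

class _AdminLogChunkSketch:
    async def _load_next_chunk(self):
        self.request.limit = min(self.left, 100)
        r = await self.client(self.request)
        for ev in r.events:
            self.buffer.append(ev)        # real code wraps this in custom.AdminLogEvent
        if len(r.events) < self.request.limit:
            return True                   # short page: nothing left to fetch
        # falling through (returning None) means another chunk will be requested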


@@ -29,8 +29,6 @@ class _DialogsIter(RequestIter):
self.ignore_migrated = ignore_migrated
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
@@ -62,34 +60,32 @@ class _DialogsIter(RequestIter):
if not self.ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None:
result.append(cd)
self.buffer.append(cd)
if len(r.dialogs) < self.request.limit\
or not isinstance(r, types.messages.DialogsSlice):
# Less than we requested means we reached the end, or
# we didn't get a DialogsSlice which means we got all.
self.left = len(result)
self.request.offset_date = r.messages[-1].date
self.request.offset_peer =\
entities[utils.get_peer_id(r.dialogs[-1].peer)]
return True
if self.request.offset_id == r.messages[-1].id:
# In some very rare cases this will get stuck in an infinite
# loop, where the offsets will get reused over and over. If
# the new offset is the same as the one before, break already.
self.left = len(result)
return True
self.request.offset_id = r.messages[-1].id
self.request.exclude_pinned = True
return result
self.request.offset_date = r.messages[-1].date
self.request.offset_peer =\
entities[utils.get_peer_id(r.dialogs[-1].peer)]
class _DraftsIter(RequestIter):
async def _init(self, **kwargs):
r = await self.client(functions.messages.GetAllDraftsRequest())
self.buffer = [custom.Draft._from_update(self.client, u)
for u in r.updates]
self.buffer.extend(custom.Draft._from_update(self.client, u)
for u in r.updates)
async def _load_next_chunk(self):
return []
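`_DraftsIter` above is the degenerate case: a single request returns everything, so `_init` now extends `self.buffer` up front and `_load_next_chunk` has nothing to add. The diff leaves its `return []` in place, which still terminates correctly because an empty buffer ends the iteration, while `_IDsIter` in the next file switches to `return True` for the same situation. A sketch of this one-shot shape; `_build_request` and `_wrap` are placeholders for the real request and the `custom.Draft._from_update` wrapping:

class _OneShotIterSketch:
    async def _init(self, **kwargs):
        r = await self.client(self._build_request())          # hypothetical request
        self.buffer.extend(self._wrap(u) for u in r.updates)  # everything arrives at once

    async def _load_next_chunk(self):
        return True   # nothing more to fetch; _init already filled the buffer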


@@ -139,8 +139,6 @@ class _MessagesIter(RequestIter):
self.last_id = 0 if self.reverse else float('inf')
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, self.batch_size)
if self.reverse and self.request.limit != self.batch_size:
# Remember that we need -limit when going in reverse
@@ -159,8 +157,7 @@ class _MessagesIter(RequestIter):
continue
if not self._message_in_range(message):
self.left = len(result)
break
return True
# There have been reports that on bad connections this method
# was returning duplicated IDs sometimes. Using ``last_id``
@@ -168,15 +165,15 @@ class _MessagesIter(RequestIter):
# IDs are returned in descending order (or asc if reverse).
self.last_id = message.id
message._finish_init(self.client, entities, self.entity)
result.append(message)
self.buffer.append(message)
if len(r.messages) < self.request.limit:
self.left = len(result)
return True
# Get the last message that's not empty (in some rare cases
# it can happen that the last message is :tl:`MessageEmpty`)
if result:
self._update_offset(result[-1])
if self.buffer:
self._update_offset(self.buffer[-1])
else:
# There are some cases where all the messages we get start
# being empty. This can happen on migrated mega-groups if
@@ -184,9 +181,7 @@ class _MessagesIter(RequestIter):
# acts incredibly weird sometimes. Messages are returned but
# only "empty", not their contents. If this is the case we
# should just give up since there won't be any new Message.
self.left = len(result)
return result
return True
def _message_in_range(self, message):
"""
@@ -258,7 +253,7 @@ class _IDsIter(RequestIter):
from_id = await self.client.get_peer_id(entity)
if isinstance(r, types.messages.MessagesNotModified):
self.buffer = [None] * len(ids)
self.buffer.extend(None for _ in ids)
return
entities = {utils.get_peer_id(x): x
@@ -272,19 +267,16 @@ class _IDsIter(RequestIter):
# The passed message IDs may not belong to the desired entity
# since the user can enter arbitrary numbers which can belong to
# arbitrary chats. Validate these unless ``from_id is None``.
result = []
for message in r.messages:
if isinstance(message, types.MessageEmpty) or (
from_id and message.chat_id != from_id):
result.append(None)
self.buffer.append(None)
else:
message._finish_init(self.client, entities, entity)
result.append(message)
self.buffer = result
self.buffer.append(message)
async def _load_next_chunk(self):
return [] # no next chunk, all done in init
return True # no next chunk, all done in init
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
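For `_MessagesIter`, the bookkeeping at the end of each round trip now reads straight from `self.buffer`: a short page means this was the last chunk, and otherwise the last buffered message drives the next offset (previously `result[-1]`). A compressed sketch of that tail, with the entity handling and the range/duplicate checks from the diff elided:

class _MessagesTailSketch:
    async def _load_next_chunk(self):
        self.request.limit = min(self.left, self.batch_size)
        r = await self.client(self.request)
        for message in r.messages:
            ...                                      # skip empties, range and duplicate checks
            self.buffer.append(message)

        if len(r.messages) < self.request.limit:
            return True                              # short page: this was the last chunk

        if self.buffer:
            self._update_offset(self.buffer[-1])     # used to be result[-1]
        else:
            return True                              # only MessageEmpty came back; give up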


@@ -55,7 +55,8 @@ class RequestIter(abc.ABC):
"""
async def __anext__(self):
if self.buffer is ():
if self.buffer is None:
self.buffer = []
await self._init(**self.kwargs)
if self.left <= 0: # <= 0 because subclasses may change it
@@ -71,7 +72,9 @@ class RequestIter(abc.ABC):
self.last_load = time.time()
self.index = 0
self.buffer = await self._load_next_chunk()
self.buffer = []
if await self._load_next_chunk():
self.left = len(self.buffer)
if not self.buffer:
raise StopAsyncIteration
@@ -82,7 +85,7 @@ class RequestIter(abc.ABC):
return result
def __aiter__(self):
self.buffer = ()
self.buffer = None
self.index = 0
self.last_load = 0
self.left = self.limit
@@ -113,7 +116,12 @@ class RequestIter(abc.ABC):
async def _load_next_chunk(self):
"""
Called when the next chunk is necessary.
It should *always* return a `list`.
It should extend the `buffer` with new items.
It should return ``True`` if it's the last chunk,
after which the method won't be called again
during the same iteration.
"""
raise NotImplementedError
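Putting the pieces together, below is a self-contained toy model of the new contract. It is not Telethon's actual `RequestIter` (the wait-time handling and the client argument are omitted); it only mirrors the buffer-filling flow shown in this diff, plus a small paging subclass and a usage example:

import asyncio


class ToyRequestIter:
    """Simplified stand-in mirroring the new buffer-filling contract."""

    def __init__(self, limit=float('inf'), **kwargs):
        self.limit = limit
        self.kwargs = kwargs
        self.buffer = None
        self.index = 0
        self.left = limit

    async def _init(self, **kwargs):
        pass  # subclasses may pre-fill self.buffer here (the _DraftsIter case)

    async def _load_next_chunk(self):
        raise NotImplementedError

    def __aiter__(self):
        self.buffer = None
        self.index = 0
        self.left = self.limit
        return self

    async def __anext__(self):
        if self.buffer is None:
            self.buffer = []
            await self._init(**self.kwargs)

        if self.left <= 0:  # <= 0 because subclasses may change it
            raise StopAsyncIteration

        if self.index == len(self.buffer):
            self.index = 0
            self.buffer = []
            if await self._load_next_chunk():
                # Last chunk: only what is buffered now remains to be yielded.
                self.left = len(self.buffer)

        if not self.buffer:
            raise StopAsyncIteration

        result = self.buffer[self.index]
        self.left -= 1
        self.index += 1
        return result


class NumbersIter(ToyRequestIter):
    """Yields 0..total-1 in pages of page_size, to exercise the contract."""

    async def _init(self, total=10, page_size=3):
        self._values = list(range(total))
        self._page_size = page_size
        self._offset = 0

    async def _load_next_chunk(self):
        page = self._values[self._offset:self._offset + self._page_size]
        self._offset += len(page)
        self.buffer.extend(page)                    # fill in place, no local result list
        return self._offset >= len(self._values)    # True on the final page


async def main():
    async for n in NumbersIter(limit=7, total=10, page_size=3):
        print(n)  # prints 0 through 6; the limit caps the iteration


asyncio.run(main())

The design point mirrored here is that after a truthy return from `_load_next_chunk`, only the items already buffered will be yielded before the iteration stops.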