Mirror of https://github.com/LonamiWebs/Telethon.git (synced 2025-02-18 04:20:57 +03:00)
Simplify filling RequestIter's buffer
This commit is contained in:
parent 202ce1f494
commit c73b8eda26
@@ -68,7 +68,6 @@ class _ParticipantsIter(RequestIter):
             self.total = len(full.full_chat.participants.participants)

-            result = []
             users = {user.id: user for user in full.users}
             for participant in full.full_chat.participants.participants:
                 user = users[participant.user_id]
@@ -77,26 +76,22 @@ class _ParticipantsIter(RequestIter):

                 user = users[participant.user_id]
                 user.participant = participant
-                result.append(user)
+                self.buffer.append(user)

-            self.left = len(result)
-            self.buffer = result
+            return True
         else:
-            result = []
             self.total = 1
             if self.limit != 0:
                 user = await self.client.get_entity(entity)
                 if self.filter_entity(user):
                     user.participant = None
-                    result.append(user)
+                    self.buffer.append(user)

-            self.left = len(result)
-            self.buffer = result
+            return True

     async def _load_next_chunk(self):
-        result = []
         if not self.requests:
-            return result
+            return True

         # Only care about the limit for the first request
         # (small amount of people, won't be aggressive).
@@ -106,7 +101,7 @@ class _ParticipantsIter(RequestIter):
         # precise with being out of the offset/limit here.
         self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
         if self.requests[0].offset > self.limit:
-            return result
+            return True

         results = await self.client(self.requests)
         for i in reversed(range(len(self.requests))):
@@ -125,9 +120,7 @@ class _ParticipantsIter(RequestIter):
                 self.seen.add(participant.user_id)
                 user = users[participant.user_id]
                 user.participant = participant
-                result.append(user)
-
-        return result
+                self.buffer.append(user)


 class _AdminLogIter(RequestIter):
@@ -163,7 +156,6 @@ class _AdminLogIter(RequestIter):
         )

     async def _load_next_chunk(self):
-        result = []
         self.request.limit = min(self.left, 100)
         r = await self.client(self.request)
         entities = {utils.get_peer_id(x): x
@@ -184,12 +176,10 @@ class _AdminLogIter(RequestIter):
                 ev.action.message._finish_init(
                     self.client, entities, self.entity)

-            result.append(custom.AdminLogEvent(ev, entities))
+            self.buffer.append(custom.AdminLogEvent(ev, entities))

         if len(r.events) < self.request.limit:
-            self.left = len(result)
-
-        return result
+            return True


 class ChatMethods(UserMethods):
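Every hunk above applies the same mechanical change: drop the local `result` list, append straight into `self.buffer`, and return `True` once there is nothing left to fetch instead of assigning `self.left` and `self.buffer` by hand. A minimal sketch of the before/after shape of `_load_next_chunk` (the classes below are illustrative stand-ins, not Telethon's real iterators):

# Illustrative stand-ins only; they mimic the old and new shapes of
# _load_next_chunk, not Telethon's actual iterator classes.

class OldStyleIter:
    def __init__(self, items):
        self._items = list(items)
        self.buffer = []
        self.left = 0

    async def _load_next_chunk(self):
        result = []                   # local accumulator
        for item in self._items:
            result.append(item)
        self.left = len(result)       # bookkeeping every subclass repeated
        return result                 # caller copied this into self.buffer


class NewStyleIter:
    def __init__(self, items):
        self._items = list(items)
        self.buffer = []              # now created by RequestIter.__anext__

    async def _load_next_chunk(self):
        for item in self._items:
            self.buffer.append(item)  # fill the shared buffer directly
        return True                   # "last chunk"; RequestIter handles self.left

Returning a boolean instead of a list lets the base class own all of the `left`/`buffer` bookkeeping, which is what the `RequestIter.__anext__` hunks further down do.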
@@ -29,8 +29,6 @@ class _DialogsIter(RequestIter):
         self.ignore_migrated = ignore_migrated

     async def _load_next_chunk(self):
-        result = []
-
         self.request.limit = min(self.left, 100)
         r = await self.client(self.request)

@@ -62,34 +60,32 @@ class _DialogsIter(RequestIter):

             if not self.ignore_migrated or getattr(
                     cd.entity, 'migrated_to', None) is None:
-                result.append(cd)
+                self.buffer.append(cd)

         if len(r.dialogs) < self.request.limit\
                 or not isinstance(r, types.messages.DialogsSlice):
             # Less than we requested means we reached the end, or
             # we didn't get a DialogsSlice which means we got all.
-            self.left = len(result)
+            return True

-        self.request.offset_date = r.messages[-1].date
-        self.request.offset_peer =\
-            entities[utils.get_peer_id(r.dialogs[-1].peer)]
-
         if self.request.offset_id == r.messages[-1].id:
             # In some very rare cases this will get stuck in an infinite
             # loop, where the offsets will get reused over and over. If
             # the new offset is the same as the one before, break already.
-            self.left = len(result)
+            return True

         self.request.offset_id = r.messages[-1].id
         self.request.exclude_pinned = True
-        return result
+        self.request.offset_date = r.messages[-1].date
+        self.request.offset_peer =\
+            entities[utils.get_peer_id(r.dialogs[-1].peer)]


 class _DraftsIter(RequestIter):
     async def _init(self, **kwargs):
         r = await self.client(functions.messages.GetAllDraftsRequest())
-        self.buffer = [custom.Draft._from_update(self.client, u)
-                       for u in r.updates]
+        self.buffer.extend(custom.Draft._from_update(self.client, u)
+                           for u in r.updates)

     async def _load_next_chunk(self):
         return []
@@ -139,8 +139,6 @@ class _MessagesIter(RequestIter):
         self.last_id = 0 if self.reverse else float('inf')

     async def _load_next_chunk(self):
-        result = []
-
         self.request.limit = min(self.left, self.batch_size)
         if self.reverse and self.request.limit != self.batch_size:
             # Remember that we need -limit when going in reverse
@@ -159,8 +157,7 @@ class _MessagesIter(RequestIter):
                 continue

             if not self._message_in_range(message):
-                self.left = len(result)
-                break
+                return True

             # There has been reports that on bad connections this method
             # was returning duplicated IDs sometimes. Using ``last_id``
@@ -168,15 +165,15 @@ class _MessagesIter(RequestIter):
             # IDs are returned in descending order (or asc if reverse).
             self.last_id = message.id
             message._finish_init(self.client, entities, self.entity)
-            result.append(message)
+            self.buffer.append(message)

         if len(r.messages) < self.request.limit:
-            self.left = len(result)
+            return True

         # Get the last message that's not empty (in some rare cases
         # it can happen that the last message is :tl:`MessageEmpty`)
-        if result:
-            self._update_offset(result[-1])
+        if self.buffer:
+            self._update_offset(self.buffer[-1])
         else:
             # There are some cases where all the messages we get start
             # being empty. This can happen on migrated mega-groups if
@@ -184,9 +181,7 @@ class _MessagesIter(RequestIter):
             # acts incredibly weird sometimes. Messages are returned but
             # only "empty", not their contents. If this is the case we
             # should just give up since there won't be any new Message.
-            self.left = len(result)
-
-        return result
+            return True

     def _message_in_range(self, message):
         """
@@ -258,7 +253,7 @@ class _IDsIter(RequestIter):
             from_id = await self.client.get_peer_id(entity)

         if isinstance(r, types.messages.MessagesNotModified):
-            self.buffer = [None] * len(ids)
+            self.buffer.extend(None for _ in ids)
             return

         entities = {utils.get_peer_id(x): x
@@ -272,19 +267,16 @@ class _IDsIter(RequestIter):
         # The passed message IDs may not belong to the desired entity
         # since the user can enter arbitrary numbers which can belong to
         # arbitrary chats. Validate these unless ``from_id is None``.
-        result = []
         for message in r.messages:
             if isinstance(message, types.MessageEmpty) or (
                     from_id and message.chat_id != from_id):
-                result.append(None)
+                self.buffer.append(None)
             else:
                 message._finish_init(self.client, entities, entity)
-                result.append(message)
+                self.buffer.append(message)

-        self.buffer = result
-
     async def _load_next_chunk(self):
-        return []  # no next chunk, all done in init
+        return True  # no next chunk, all done in init


 class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
@@ -55,7 +55,8 @@ class RequestIter(abc.ABC):
         """

     async def __anext__(self):
-        if self.buffer is ():
+        if self.buffer is None:
+            self.buffer = []
             await self._init(**self.kwargs)

         if self.left <= 0:  # <= 0 because subclasses may change it
@@ -71,7 +72,9 @@ class RequestIter(abc.ABC):
             self.last_load = time.time()

         self.index = 0
-        self.buffer = await self._load_next_chunk()
+        self.buffer = []
+        if await self._load_next_chunk():
+            self.left = len(self.buffer)

         if not self.buffer:
             raise StopAsyncIteration
@@ -82,7 +85,7 @@ class RequestIter(abc.ABC):
         return result

     def __aiter__(self):
-        self.buffer = ()
+        self.buffer = None
         self.index = 0
         self.last_load = 0
         self.left = self.limit
@@ -113,7 +116,12 @@ class RequestIter(abc.ABC):
     async def _load_next_chunk(self):
         """
         Called when the next chunk is necessary.
-        It should *always* return a `list`.
+
+        It should extend the `buffer` with new items.
+
+        It should return ``True`` if it's the last chunk,
+        after which moment the method won't be called again
+        during the same iteration.
         """
         raise NotImplementedError

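Read together, the new contract is: `__aiter__` resets `buffer` to `None`, `__anext__` lazily creates the empty list, asks `_load_next_chunk()` to extend it, and treats a `True` return as the final chunk by clamping `left` to whatever was buffered. The sketch below is a self-contained, simplified mirror of that flow (no `limit`, `wait` or flood-sleep handling, and `SimpleRequestIter`/`PagedIter` are illustrative names, not Telethon's API):

import asyncio


class SimpleRequestIter:
    """Stripped-down mirror of the new RequestIter buffer contract."""

    def __aiter__(self):
        self.buffer = None   # created lazily on the first __anext__
        self.index = 0
        self.left = float('inf')
        return self

    async def __anext__(self):
        if self.buffer is None:
            self.buffer = []
            await self._init()

        if self.left <= 0:
            raise StopAsyncIteration

        if self.index == len(self.buffer):
            self.index = 0
            self.buffer = []
            if await self._load_next_chunk():
                # Last chunk: whatever got buffered is all that's left.
                self.left = len(self.buffer)

            if not self.buffer:
                raise StopAsyncIteration

        result = self.buffer[self.index]
        self.left -= 1
        self.index += 1
        return result

    async def _init(self):
        pass

    async def _load_next_chunk(self):
        """Extend self.buffer with new items; return True on the last chunk."""
        raise NotImplementedError


class PagedIter(SimpleRequestIter):
    """Yields `total` integers in pages of `page_size`."""

    def __init__(self, total, page_size):
        self._total = total
        self._page_size = page_size
        self._offset = 0

    async def _load_next_chunk(self):
        end = min(self._offset + self._page_size, self._total)
        self.buffer.extend(range(self._offset, end))
        self._offset = end
        return self._offset >= self._total   # True -> no more chunks


async def main():
    async for n in PagedIter(total=7, page_size=3):
        print(n)   # prints 0..6, fetched in chunks of 3


if __name__ == '__main__':
    asyncio.run(main())

Running the sketch prints 0 through 6, fetched in pages of three, with the last page signalled by `_load_next_chunk` returning `True`.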