mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-30 23:47:33 +03:00 
			
		
		
		
	Simplify filling RequestIter's buffer
This commit is contained in:
		
							parent
							
								
									202ce1f494
								
							
						
					
					
						commit
						c73b8eda26
					
				|  | @ -68,7 +68,6 @@ class _ParticipantsIter(RequestIter): | |||
| 
 | ||||
|             self.total = len(full.full_chat.participants.participants) | ||||
| 
 | ||||
|             result = [] | ||||
|             users = {user.id: user for user in full.users} | ||||
|             for participant in full.full_chat.participants.participants: | ||||
|                 user = users[participant.user_id] | ||||
|  | @ -77,26 +76,22 @@ class _ParticipantsIter(RequestIter): | |||
| 
 | ||||
|                 user = users[participant.user_id] | ||||
|                 user.participant = participant | ||||
|                 result.append(user) | ||||
|                 self.buffer.append(user) | ||||
| 
 | ||||
|             self.left = len(result) | ||||
|             self.buffer = result | ||||
|             return True | ||||
|         else: | ||||
|             result = [] | ||||
|             self.total = 1 | ||||
|             if self.limit != 0: | ||||
|                 user = await self.client.get_entity(entity) | ||||
|                 if self.filter_entity(user): | ||||
|                     user.participant = None | ||||
|                     result.append(user) | ||||
|                     self.buffer.append(user) | ||||
| 
 | ||||
|             self.left = len(result) | ||||
|             self.buffer = result | ||||
|             return True | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         result = [] | ||||
|         if not self.requests: | ||||
|             return result | ||||
|             return True | ||||
| 
 | ||||
|         # Only care about the limit for the first request | ||||
|         # (small amount of people, won't be aggressive). | ||||
|  | @ -106,7 +101,7 @@ class _ParticipantsIter(RequestIter): | |||
|         # precise with being out of the offset/limit here. | ||||
|         self.requests[0].limit = min(self.limit - self.requests[0].offset, 200) | ||||
|         if self.requests[0].offset > self.limit: | ||||
|             return result | ||||
|             return True | ||||
| 
 | ||||
|         results = await self.client(self.requests) | ||||
|         for i in reversed(range(len(self.requests))): | ||||
|  | @ -125,9 +120,7 @@ class _ParticipantsIter(RequestIter): | |||
|                 self.seen.add(participant.user_id) | ||||
|                 user = users[participant.user_id] | ||||
|                 user.participant = participant | ||||
|                 result.append(user) | ||||
| 
 | ||||
|         return result | ||||
|                 self.buffer.append(user) | ||||
| 
 | ||||
| 
 | ||||
| class _AdminLogIter(RequestIter): | ||||
|  | @ -163,7 +156,6 @@ class _AdminLogIter(RequestIter): | |||
|         ) | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         result = [] | ||||
|         self.request.limit = min(self.left, 100) | ||||
|         r = await self.client(self.request) | ||||
|         entities = {utils.get_peer_id(x): x | ||||
|  | @ -184,12 +176,10 @@ class _AdminLogIter(RequestIter): | |||
|                 ev.action.message._finish_init( | ||||
|                     self.client, entities, self.entity) | ||||
| 
 | ||||
|             result.append(custom.AdminLogEvent(ev, entities)) | ||||
|             self.buffer.append(custom.AdminLogEvent(ev, entities)) | ||||
| 
 | ||||
|         if len(r.events) < self.request.limit: | ||||
|             self.left = len(result) | ||||
| 
 | ||||
|         return result | ||||
|             return True | ||||
| 
 | ||||
| 
 | ||||
| class ChatMethods(UserMethods): | ||||
|  |  | |||
|  | @ -29,8 +29,6 @@ class _DialogsIter(RequestIter): | |||
|         self.ignore_migrated = ignore_migrated | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         result = [] | ||||
| 
 | ||||
|         self.request.limit = min(self.left, 100) | ||||
|         r = await self.client(self.request) | ||||
| 
 | ||||
|  | @ -62,34 +60,32 @@ class _DialogsIter(RequestIter): | |||
| 
 | ||||
|                 if not self.ignore_migrated or getattr( | ||||
|                         cd.entity, 'migrated_to', None) is None: | ||||
|                     result.append(cd) | ||||
|                     self.buffer.append(cd) | ||||
| 
 | ||||
|         if len(r.dialogs) < self.request.limit\ | ||||
|                 or not isinstance(r, types.messages.DialogsSlice): | ||||
|             # Less than we requested means we reached the end, or | ||||
|             # we didn't get a DialogsSlice which means we got all. | ||||
|             self.left = len(result) | ||||
| 
 | ||||
|         self.request.offset_date = r.messages[-1].date | ||||
|         self.request.offset_peer =\ | ||||
|             entities[utils.get_peer_id(r.dialogs[-1].peer)] | ||||
|             return True | ||||
| 
 | ||||
|         if self.request.offset_id == r.messages[-1].id: | ||||
|             # In some very rare cases this will get stuck in an infinite | ||||
|             # loop, where the offsets will get reused over and over. If | ||||
|             # the new offset is the same as the one before, break already. | ||||
|             self.left = len(result) | ||||
|             return True | ||||
| 
 | ||||
|         self.request.offset_id = r.messages[-1].id | ||||
|         self.request.exclude_pinned = True | ||||
|         return result | ||||
|         self.request.offset_date = r.messages[-1].date | ||||
|         self.request.offset_peer =\ | ||||
|             entities[utils.get_peer_id(r.dialogs[-1].peer)] | ||||
| 
 | ||||
| 
 | ||||
| class _DraftsIter(RequestIter): | ||||
|     async def _init(self, **kwargs): | ||||
|         r = await self.client(functions.messages.GetAllDraftsRequest()) | ||||
|         self.buffer = [custom.Draft._from_update(self.client, u) | ||||
|                        for u in r.updates] | ||||
|         self.buffer.extend(custom.Draft._from_update(self.client, u) | ||||
|                            for u in r.updates) | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         return [] | ||||
|  |  | |||
|  | @ -139,8 +139,6 @@ class _MessagesIter(RequestIter): | |||
|         self.last_id = 0 if self.reverse else float('inf') | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         result = [] | ||||
| 
 | ||||
|         self.request.limit = min(self.left, self.batch_size) | ||||
|         if self.reverse and self.request.limit != self.batch_size: | ||||
|             # Remember that we need -limit when going in reverse | ||||
|  | @ -159,8 +157,7 @@ class _MessagesIter(RequestIter): | |||
|                 continue | ||||
| 
 | ||||
|             if not self._message_in_range(message): | ||||
|                 self.left = len(result) | ||||
|                 break | ||||
|                 return True | ||||
| 
 | ||||
|             # There has been reports that on bad connections this method | ||||
|             # was returning duplicated IDs sometimes. Using ``last_id`` | ||||
|  | @ -168,15 +165,15 @@ class _MessagesIter(RequestIter): | |||
|             # IDs are returned in descending order (or asc if reverse). | ||||
|             self.last_id = message.id | ||||
|             message._finish_init(self.client, entities, self.entity) | ||||
|             result.append(message) | ||||
|             self.buffer.append(message) | ||||
| 
 | ||||
|         if len(r.messages) < self.request.limit: | ||||
|             self.left = len(result) | ||||
|             return True | ||||
| 
 | ||||
|         # Get the last message that's not empty (in some rare cases | ||||
|         # it can happen that the last message is :tl:`MessageEmpty`) | ||||
|         if result: | ||||
|             self._update_offset(result[-1]) | ||||
|         if self.buffer: | ||||
|             self._update_offset(self.buffer[-1]) | ||||
|         else: | ||||
|             # There are some cases where all the messages we get start | ||||
|             # being empty. This can happen on migrated mega-groups if | ||||
|  | @ -184,9 +181,7 @@ class _MessagesIter(RequestIter): | |||
|             # acts incredibly weird sometimes. Messages are returned but | ||||
|             # only "empty", not their contents. If this is the case we | ||||
|             # should just give up since there won't be any new Message. | ||||
|             self.left = len(result) | ||||
| 
 | ||||
|         return result | ||||
|             return True | ||||
| 
 | ||||
|     def _message_in_range(self, message): | ||||
|         """ | ||||
|  | @ -258,7 +253,7 @@ class _IDsIter(RequestIter): | |||
|                 from_id = await self.client.get_peer_id(entity) | ||||
| 
 | ||||
|         if isinstance(r, types.messages.MessagesNotModified): | ||||
|             self.buffer = [None] * len(ids) | ||||
|             self.buffer.extend(None for _ in ids) | ||||
|             return | ||||
| 
 | ||||
|         entities = {utils.get_peer_id(x): x | ||||
|  | @ -272,19 +267,16 @@ class _IDsIter(RequestIter): | |||
|         # The passed message IDs may not belong to the desired entity | ||||
|         # since the user can enter arbitrary numbers which can belong to | ||||
|         # arbitrary chats. Validate these unless ``from_id is None``. | ||||
|         result = [] | ||||
|         for message in r.messages: | ||||
|             if isinstance(message, types.MessageEmpty) or ( | ||||
|                     from_id and message.chat_id != from_id): | ||||
|                 result.append(None) | ||||
|                 self.buffer.append(None) | ||||
|             else: | ||||
|                 message._finish_init(self.client, entities, entity) | ||||
|                 result.append(message) | ||||
| 
 | ||||
|         self.buffer = result | ||||
|                 self.buffer.append(message) | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         return []  # no next chunk, all done in init | ||||
|         return True  # no next chunk, all done in init | ||||
| 
 | ||||
| 
 | ||||
| class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): | ||||
|  |  | |||
|  | @ -55,7 +55,8 @@ class RequestIter(abc.ABC): | |||
|         """ | ||||
| 
 | ||||
|     async def __anext__(self): | ||||
|         if self.buffer is (): | ||||
|         if self.buffer is None: | ||||
|             self.buffer = [] | ||||
|             await self._init(**self.kwargs) | ||||
| 
 | ||||
|         if self.left <= 0:  # <= 0 because subclasses may change it | ||||
|  | @ -71,7 +72,9 @@ class RequestIter(abc.ABC): | |||
|                 self.last_load = time.time() | ||||
| 
 | ||||
|             self.index = 0 | ||||
|             self.buffer = await self._load_next_chunk() | ||||
|             self.buffer = [] | ||||
|             if await self._load_next_chunk(): | ||||
|                 self.left = len(self.buffer) | ||||
| 
 | ||||
|         if not self.buffer: | ||||
|             raise StopAsyncIteration | ||||
|  | @ -82,7 +85,7 @@ class RequestIter(abc.ABC): | |||
|         return result | ||||
| 
 | ||||
|     def __aiter__(self): | ||||
|         self.buffer = () | ||||
|         self.buffer = None | ||||
|         self.index = 0 | ||||
|         self.last_load = 0 | ||||
|         self.left = self.limit | ||||
|  | @ -113,7 +116,12 @@ class RequestIter(abc.ABC): | |||
|     async def _load_next_chunk(self): | ||||
|         """ | ||||
|         Called when the next chunk is necessary. | ||||
|         It should *always* return a `list`. | ||||
| 
 | ||||
|         It should extend the `buffer` with new items. | ||||
| 
 | ||||
|         It should return ``True`` if it's the last chunk, | ||||
|         after which moment the method won't be called again | ||||
|         during the same iteration. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user