mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-29 06:57:50 +03:00 
			
		
		
		
	Use RequestIter in chat methods
This commit is contained in:
		
							parent
							
								
									4f647847e7
								
							
						
					
					
						commit
						40ded93c7c
					
				|  | @ -1,19 +1,202 @@ | |||
| import itertools | ||||
| import sys | ||||
| 
 | ||||
| from async_generator import async_generator, yield_ | ||||
| 
 | ||||
| from .users import UserMethods | ||||
| from .. import utils, helpers | ||||
| from .. import utils | ||||
| from ..requestiter import RequestIter | ||||
| from ..tl import types, functions, custom | ||||
| 
 | ||||
| 
 | ||||
| class _ParticipantsIter(RequestIter): | ||||
|     async def _init(self, entity, filter, search, aggressive): | ||||
|         if isinstance(filter, type): | ||||
|             if filter in (types.ChannelParticipantsBanned, | ||||
|                           types.ChannelParticipantsKicked, | ||||
|                           types.ChannelParticipantsSearch, | ||||
|                           types.ChannelParticipantsContacts): | ||||
|                 # These require a `q` parameter (support types for convenience) | ||||
|                 filter = filter('') | ||||
|             else: | ||||
|                 filter = filter() | ||||
| 
 | ||||
|         entity = await self.client.get_input_entity(entity) | ||||
|         if search and (filter | ||||
|                        or not isinstance(entity, types.InputPeerChannel)): | ||||
|             # We need to 'search' ourselves unless we have a PeerChannel | ||||
|             search = search.lower() | ||||
| 
 | ||||
|             self.filter_entity = lambda ent: ( | ||||
|                 search in utils.get_display_name(ent).lower() or | ||||
|                 search in (getattr(ent, 'username', '') or None).lower() | ||||
|             ) | ||||
|         else: | ||||
|             self.filter_entity = lambda ent: True | ||||
| 
 | ||||
|         if isinstance(entity, types.InputPeerChannel): | ||||
|             self.total = (await self.client( | ||||
|                 functions.channels.GetFullChannelRequest(entity) | ||||
|             )).full_chat.participants_count | ||||
| 
 | ||||
|             if self.limit == 0: | ||||
|                 raise StopAsyncIteration | ||||
| 
 | ||||
|             self.seen = set() | ||||
|             if aggressive and not filter: | ||||
|                 self.requests = [functions.channels.GetParticipantsRequest( | ||||
|                     channel=entity, | ||||
|                     filter=types.ChannelParticipantsSearch(x), | ||||
|                     offset=0, | ||||
|                     limit=200, | ||||
|                     hash=0 | ||||
|                 ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] | ||||
|             else: | ||||
|                 self.requests = [functions.channels.GetParticipantsRequest( | ||||
|                     channel=entity, | ||||
|                     filter=filter or types.ChannelParticipantsSearch(search), | ||||
|                     offset=0, | ||||
|                     limit=200, | ||||
|                     hash=0 | ||||
|                 )] | ||||
| 
 | ||||
|         elif isinstance(entity, types.InputPeerChat): | ||||
|             full = await self.client( | ||||
|                 functions.messages.GetFullChatRequest(entity.chat_id)) | ||||
|             if not isinstance( | ||||
|                     full.full_chat.participants, types.ChatParticipants): | ||||
|                 # ChatParticipantsForbidden won't have ``.participants`` | ||||
|                 self.total = 0 | ||||
|                 raise StopAsyncIteration | ||||
| 
 | ||||
|             self.total = len(full.full_chat.participants.participants) | ||||
| 
 | ||||
|             result = [] | ||||
|             users = {user.id: user for user in full.users} | ||||
|             for participant in full.full_chat.participants.participants: | ||||
|                 user = users[participant.user_id] | ||||
|                 if not self.filter_entity(user): | ||||
|                     continue | ||||
| 
 | ||||
|                 user = users[participant.user_id] | ||||
|                 user.participant = participant | ||||
|                 result.append(user) | ||||
| 
 | ||||
|             self.left = len(result) | ||||
|             self.buffer = result | ||||
|         else: | ||||
|             result = [] | ||||
|             self.total = 1 | ||||
|             if self.limit != 0: | ||||
|                 user = await self.client.get_entity(entity) | ||||
|                 if self.filter_entity(user): | ||||
|                     user.participant = None | ||||
|                     result.append(user) | ||||
| 
 | ||||
|             self.left = len(result) | ||||
|             self.buffer = result | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         result = [] | ||||
|         if not self.requests: | ||||
|             return result | ||||
| 
 | ||||
|         # Only care about the limit for the first request | ||||
|         # (small amount of people, won't be aggressive). | ||||
|         # | ||||
|         # Most people won't care about getting exactly 12,345 | ||||
|         # members so it doesn't really matter not to be 100% | ||||
|         # precise with being out of the offset/limit here. | ||||
|         self.requests[0].limit = min(self.limit - self.requests[0].offset, 200) | ||||
|         if self.requests[0].offset > self.limit: | ||||
|             return result | ||||
| 
 | ||||
|         results = await self.client(self.requests) | ||||
|         for i in reversed(range(len(self.requests))): | ||||
|             participants = results[i] | ||||
|             if not participants.users: | ||||
|                 self.requests.pop(i) | ||||
|                 continue | ||||
| 
 | ||||
|             self.requests[i].offset += len(participants.participants) | ||||
|             users = {user.id: user for user in participants.users} | ||||
|             for participant in participants.participants: | ||||
|                 user = users[participant.user_id] | ||||
|                 if not self.filter_entity(user) or user.id in self.seen: | ||||
|                     continue | ||||
| 
 | ||||
|                 self.seen.add(participant.user_id) | ||||
|                 user = users[participant.user_id] | ||||
|                 user.participant = participant | ||||
|                 result.append(user) | ||||
| 
 | ||||
|         return result | ||||
| 
 | ||||
| 
 | ||||
| class _AdminLogIter(RequestIter): | ||||
|     async def _init( | ||||
|             self, entity, admins, search, min_id, max_id, | ||||
|             join, leave, invite, restrict, unrestrict, ban, unban, | ||||
|             promote, demote, info, settings, pinned, edit, delete | ||||
|     ): | ||||
|         if any((join, leave, invite, restrict, unrestrict, ban, unban, | ||||
|                 promote, demote, info, settings, pinned, edit, delete)): | ||||
|             events_filter = types.ChannelAdminLogEventsFilter( | ||||
|                 join=join, leave=leave, invite=invite, ban=restrict, | ||||
|                 unban=unrestrict, kick=ban, unkick=unban, promote=promote, | ||||
|                 demote=demote, info=info, settings=settings, pinned=pinned, | ||||
|                 edit=edit, delete=delete | ||||
|             ) | ||||
|         else: | ||||
|             events_filter = None | ||||
| 
 | ||||
|         self.entity = await self.client.get_input_entity(entity) | ||||
| 
 | ||||
|         admin_list = [] | ||||
|         if admins: | ||||
|             if not utils.is_list_like(admins): | ||||
|                 admins = (admins,) | ||||
| 
 | ||||
|             for admin in admins: | ||||
|                 admin_list.append(await self.client.get_input_entity(admin)) | ||||
| 
 | ||||
|         self.request = functions.channels.GetAdminLogRequest( | ||||
|             self.entity, q=search or '', min_id=min_id, max_id=max_id, | ||||
|             limit=0, events_filter=events_filter, admins=admin_list or None | ||||
|         ) | ||||
| 
 | ||||
|     async def _load_next_chunk(self): | ||||
|         result = [] | ||||
|         self.request.limit = min(self.left, 100) | ||||
|         r = await self.client(self.request) | ||||
|         entities = {utils.get_peer_id(x): x | ||||
|                     for x in itertools.chain(r.users, r.chats)} | ||||
| 
 | ||||
|         self.request.max_id = min((e.id for e in r.events), default=0) | ||||
|         for ev in r.events: | ||||
|             if isinstance(ev.action, | ||||
|                           types.ChannelAdminLogEventActionEditMessage): | ||||
|                 ev.action.prev_message._finish_init( | ||||
|                     self.client, entities, self.entity) | ||||
| 
 | ||||
|                 ev.action.new_message._finish_init( | ||||
|                     self.client, entities, self.entity) | ||||
| 
 | ||||
|             elif isinstance(ev.action, | ||||
|                             types.ChannelAdminLogEventActionDeleteMessage): | ||||
|                 ev.action.message._finish_init( | ||||
|                     self.client, entities, self.entity) | ||||
| 
 | ||||
|             result.append(custom.AdminLogEvent(ev, entities)) | ||||
| 
 | ||||
|         if len(r.events) < self.request.limit: | ||||
|             self.left = len(result) | ||||
| 
 | ||||
|         return result | ||||
| 
 | ||||
| 
 | ||||
| class ChatMethods(UserMethods): | ||||
| 
 | ||||
|     # region Public methods | ||||
| 
 | ||||
|     @async_generator | ||||
|     async def iter_participants( | ||||
|     def iter_participants( | ||||
|             self, entity, limit=None, *, search='', | ||||
|             filter=None, aggressive=False, _total=None): | ||||
|         """ | ||||
|  | @ -62,138 +245,23 @@ class ChatMethods(UserMethods): | |||
|             matched :tl:`ChannelParticipant` type for channels/megagroups | ||||
|             or :tl:`ChatParticipants` for normal chats. | ||||
|         """ | ||||
|         if isinstance(filter, type): | ||||
|             if filter in (types.ChannelParticipantsBanned, | ||||
|                           types.ChannelParticipantsKicked, | ||||
|                           types.ChannelParticipantsSearch, | ||||
|                           types.ChannelParticipantsContacts): | ||||
|                 # These require a `q` parameter (support types for convenience) | ||||
|                 filter = filter('') | ||||
|             else: | ||||
|                 filter = filter() | ||||
| 
 | ||||
|         entity = await self.get_input_entity(entity) | ||||
|         if search and (filter | ||||
|                        or not isinstance(entity, types.InputPeerChannel)): | ||||
|             # We need to 'search' ourselves unless we have a PeerChannel | ||||
|             search = search.lower() | ||||
| 
 | ||||
|             def filter_entity(ent): | ||||
|                 return search in utils.get_display_name(ent).lower() or\ | ||||
|                        search in (getattr(ent, 'username', '') or None).lower() | ||||
|         else: | ||||
|             def filter_entity(ent): | ||||
|                 return True | ||||
| 
 | ||||
|         limit = float('inf') if limit is None else int(limit) | ||||
|         if isinstance(entity, types.InputPeerChannel): | ||||
|             if _total: | ||||
|                 _total[0] = (await self( | ||||
|                     functions.channels.GetFullChannelRequest(entity) | ||||
|                 )).full_chat.participants_count | ||||
| 
 | ||||
|             if limit == 0: | ||||
|                 return | ||||
| 
 | ||||
|             seen = set() | ||||
|             if aggressive and not filter: | ||||
|                 requests = [functions.channels.GetParticipantsRequest( | ||||
|                     channel=entity, | ||||
|                     filter=types.ChannelParticipantsSearch(x), | ||||
|                     offset=0, | ||||
|                     limit=200, | ||||
|                     hash=0 | ||||
|                 ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] | ||||
|             else: | ||||
|                 requests = [functions.channels.GetParticipantsRequest( | ||||
|                     channel=entity, | ||||
|                     filter=filter or types.ChannelParticipantsSearch(search), | ||||
|                     offset=0, | ||||
|                     limit=200, | ||||
|                     hash=0 | ||||
|                 )] | ||||
| 
 | ||||
|             while requests: | ||||
|                 # Only care about the limit for the first request | ||||
|                 # (small amount of people, won't be aggressive). | ||||
|                 # | ||||
|                 # Most people won't care about getting exactly 12,345 | ||||
|                 # members so it doesn't really matter not to be 100% | ||||
|                 # precise with being out of the offset/limit here. | ||||
|                 requests[0].limit = min(limit - requests[0].offset, 200) | ||||
|                 if requests[0].offset > limit: | ||||
|                     break | ||||
| 
 | ||||
|                 results = await self(requests) | ||||
|                 for i in reversed(range(len(requests))): | ||||
|                     participants = results[i] | ||||
|                     if not participants.users: | ||||
|                         requests.pop(i) | ||||
|                     else: | ||||
|                         requests[i].offset += len(participants.participants) | ||||
|                         users = {user.id: user for user in participants.users} | ||||
|                         for participant in participants.participants: | ||||
|                             user = users[participant.user_id] | ||||
|                             if not filter_entity(user) or user.id in seen: | ||||
|                                 continue | ||||
| 
 | ||||
|                             seen.add(participant.user_id) | ||||
|                             user = users[participant.user_id] | ||||
|                             user.participant = participant | ||||
|                             await yield_(user) | ||||
|                             if len(seen) >= limit: | ||||
|                                 return | ||||
| 
 | ||||
|         elif isinstance(entity, types.InputPeerChat): | ||||
|             full = await self( | ||||
|                 functions.messages.GetFullChatRequest(entity.chat_id)) | ||||
|             if not isinstance( | ||||
|                     full.full_chat.participants, types.ChatParticipants): | ||||
|                 # ChatParticipantsForbidden won't have ``.participants`` | ||||
|                 if _total: | ||||
|                     _total[0] = 0 | ||||
|                 return | ||||
| 
 | ||||
|             if _total: | ||||
|                 _total[0] = len(full.full_chat.participants.participants) | ||||
| 
 | ||||
|             have = 0 | ||||
|             users = {user.id: user for user in full.users} | ||||
|             for participant in full.full_chat.participants.participants: | ||||
|                 user = users[participant.user_id] | ||||
|                 if not filter_entity(user): | ||||
|                     continue | ||||
|                 have += 1 | ||||
|                 if have > limit: | ||||
|                     break | ||||
|                 else: | ||||
|                     user = users[participant.user_id] | ||||
|                     user.participant = participant | ||||
|                     await yield_(user) | ||||
|         else: | ||||
|             if _total: | ||||
|                 _total[0] = 1 | ||||
|             if limit != 0: | ||||
|                 user = await self.get_entity(entity) | ||||
|                 if filter_entity(user): | ||||
|                     user.participant = None | ||||
|                     await yield_(user) | ||||
|         return _ParticipantsIter( | ||||
|             self, | ||||
|             limit, | ||||
|             entity=entity, | ||||
|             filter=filter, | ||||
|             search=search, | ||||
|             aggressive=aggressive | ||||
|         ) | ||||
| 
 | ||||
|     async def get_participants(self, *args, **kwargs): | ||||
|         """ | ||||
|         Same as `iter_participants`, but returns a | ||||
|         `TotalList <telethon.helpers.TotalList>` instead. | ||||
|         """ | ||||
|         total = [0] | ||||
|         kwargs['_total'] = total | ||||
|         participants = helpers.TotalList() | ||||
|         async for x in self.iter_participants(*args, **kwargs): | ||||
|             participants.append(x) | ||||
|         participants.total = total[0] | ||||
|         return participants | ||||
|         return await self.iter_participants(*args, **kwargs).collect() | ||||
| 
 | ||||
|     @async_generator | ||||
|     async def iter_admin_log( | ||||
|     def iter_admin_log( | ||||
|             self, entity, limit=None, *, max_id=0, min_id=0, search=None, | ||||
|             admins=None, join=None, leave=None, invite=None, restrict=None, | ||||
|             unrestrict=None, ban=None, unban=None, promote=None, demote=None, | ||||
|  | @ -285,66 +353,34 @@ class ChatMethods(UserMethods): | |||
|         Yields: | ||||
|             Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`. | ||||
|         """ | ||||
|         if limit is None: | ||||
|             limit = sys.maxsize | ||||
|         elif limit <= 0: | ||||
|             return | ||||
| 
 | ||||
|         if any((join, leave, invite, restrict, unrestrict, ban, unban, | ||||
|                 promote, demote, info, settings, pinned, edit, delete)): | ||||
|             events_filter = types.ChannelAdminLogEventsFilter( | ||||
|                 join=join, leave=leave, invite=invite, ban=restrict, | ||||
|                 unban=unrestrict, kick=ban, unkick=unban, promote=promote, | ||||
|                 demote=demote, info=info, settings=settings, pinned=pinned, | ||||
|                 edit=edit, delete=delete | ||||
|             ) | ||||
|         else: | ||||
|             events_filter = None | ||||
| 
 | ||||
|         entity = await self.get_input_entity(entity) | ||||
| 
 | ||||
|         admin_list = [] | ||||
|         if admins: | ||||
|             if not utils.is_list_like(admins): | ||||
|                 admins = (admins,) | ||||
| 
 | ||||
|             for admin in admins: | ||||
|                 admin_list.append(await self.get_input_entity(admin)) | ||||
| 
 | ||||
|         request = functions.channels.GetAdminLogRequest( | ||||
|             entity, q=search or '', min_id=min_id, max_id=max_id, | ||||
|             limit=0, events_filter=events_filter, admins=admin_list or None | ||||
|         return _AdminLogIter( | ||||
|             self, | ||||
|             limit, | ||||
|             entity=entity, | ||||
|             admins=admins, | ||||
|             search=search, | ||||
|             min_id=min_id, | ||||
|             max_id=max_id, | ||||
|             join=join, | ||||
|             leave=leave, | ||||
|             invite=invite, | ||||
|             restrict=restrict, | ||||
|             unrestrict=unrestrict, | ||||
|             ban=ban, | ||||
|             unban=unban, | ||||
|             promote=promote, | ||||
|             demote=demote, | ||||
|             info=info, | ||||
|             settings=settings, | ||||
|             pinned=pinned, | ||||
|             edit=edit, | ||||
|             delete=delete | ||||
|         ) | ||||
|         while limit > 0: | ||||
|             request.limit = min(limit, 100) | ||||
|             result = await self(request) | ||||
|             limit -= len(result.events) | ||||
|             entities = {utils.get_peer_id(x): x | ||||
|                         for x in itertools.chain(result.users, result.chats)} | ||||
| 
 | ||||
|             request.max_id = min((e.id for e in result.events), default=0) | ||||
|             for ev in result.events: | ||||
|                 if isinstance(ev.action, | ||||
|                               types.ChannelAdminLogEventActionEditMessage): | ||||
|                     ev.action.prev_message._finish_init(self, entities, entity) | ||||
|                     ev.action.new_message._finish_init(self, entities, entity) | ||||
| 
 | ||||
|                 elif isinstance(ev.action, | ||||
|                                 types.ChannelAdminLogEventActionDeleteMessage): | ||||
|                     ev.action.message._finish_init(self, entities, entity) | ||||
| 
 | ||||
|                 await yield_(custom.AdminLogEvent(ev, entities)) | ||||
| 
 | ||||
|             if len(result.events) < request.limit: | ||||
|                 break | ||||
| 
 | ||||
|     async def get_admin_log(self, *args, **kwargs): | ||||
|         """ | ||||
|         Same as `iter_admin_log`, but returns a ``list`` instead. | ||||
|         """ | ||||
|         admin_log = [] | ||||
|         async for x in self.iter_admin_log(*args, **kwargs): | ||||
|             admin_log.append(x) | ||||
|         return admin_log | ||||
|         return await self.iter_admin_log(*args, **kwargs).collect() | ||||
| 
 | ||||
|     # endregion | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user