From 40ded93c7c4bfe9aa377a6f3b4ca90e64bf11d28 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 27 Feb 2019 11:12:05 +0100 Subject: [PATCH] Use RequestIter in chat methods --- telethon/client/chats.py | 408 +++++++++++++++++++++------------------ 1 file changed, 222 insertions(+), 186 deletions(-) diff --git a/telethon/client/chats.py b/telethon/client/chats.py index 6578c785..7afe4655 100644 --- a/telethon/client/chats.py +++ b/telethon/client/chats.py @@ -1,19 +1,202 @@ import itertools -import sys - -from async_generator import async_generator, yield_ from .users import UserMethods -from .. import utils, helpers +from .. import utils +from ..requestiter import RequestIter from ..tl import types, functions, custom +class _ParticipantsIter(RequestIter): + async def _init(self, entity, filter, search, aggressive): + if isinstance(filter, type): + if filter in (types.ChannelParticipantsBanned, + types.ChannelParticipantsKicked, + types.ChannelParticipantsSearch, + types.ChannelParticipantsContacts): + # These require a `q` parameter (support types for convenience) + filter = filter('') + else: + filter = filter() + + entity = await self.client.get_input_entity(entity) + if search and (filter + or not isinstance(entity, types.InputPeerChannel)): + # We need to 'search' ourselves unless we have a PeerChannel + search = search.lower() + + self.filter_entity = lambda ent: ( + search in utils.get_display_name(ent).lower() or + search in (getattr(ent, 'username', '') or None).lower() + ) + else: + self.filter_entity = lambda ent: True + + if isinstance(entity, types.InputPeerChannel): + self.total = (await self.client( + functions.channels.GetFullChannelRequest(entity) + )).full_chat.participants_count + + if self.limit == 0: + raise StopAsyncIteration + + self.seen = set() + if aggressive and not filter: + self.requests = [functions.channels.GetParticipantsRequest( + channel=entity, + filter=types.ChannelParticipantsSearch(x), + offset=0, + limit=200, + hash=0 + ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] + else: + self.requests = [functions.channels.GetParticipantsRequest( + channel=entity, + filter=filter or types.ChannelParticipantsSearch(search), + offset=0, + limit=200, + hash=0 + )] + + elif isinstance(entity, types.InputPeerChat): + full = await self.client( + functions.messages.GetFullChatRequest(entity.chat_id)) + if not isinstance( + full.full_chat.participants, types.ChatParticipants): + # ChatParticipantsForbidden won't have ``.participants`` + self.total = 0 + raise StopAsyncIteration + + self.total = len(full.full_chat.participants.participants) + + result = [] + users = {user.id: user for user in full.users} + for participant in full.full_chat.participants.participants: + user = users[participant.user_id] + if not self.filter_entity(user): + continue + + user = users[participant.user_id] + user.participant = participant + result.append(user) + + self.left = len(result) + self.buffer = result + else: + result = [] + self.total = 1 + if self.limit != 0: + user = await self.client.get_entity(entity) + if self.filter_entity(user): + user.participant = None + result.append(user) + + self.left = len(result) + self.buffer = result + + async def _load_next_chunk(self): + result = [] + if not self.requests: + return result + + # Only care about the limit for the first request + # (small amount of people, won't be aggressive). + # + # Most people won't care about getting exactly 12,345 + # members so it doesn't really matter not to be 100% + # precise with being out of the offset/limit here. + self.requests[0].limit = min(self.limit - self.requests[0].offset, 200) + if self.requests[0].offset > self.limit: + return result + + results = await self.client(self.requests) + for i in reversed(range(len(self.requests))): + participants = results[i] + if not participants.users: + self.requests.pop(i) + continue + + self.requests[i].offset += len(participants.participants) + users = {user.id: user for user in participants.users} + for participant in participants.participants: + user = users[participant.user_id] + if not self.filter_entity(user) or user.id in self.seen: + continue + + self.seen.add(participant.user_id) + user = users[participant.user_id] + user.participant = participant + result.append(user) + + return result + + +class _AdminLogIter(RequestIter): + async def _init( + self, entity, admins, search, min_id, max_id, + join, leave, invite, restrict, unrestrict, ban, unban, + promote, demote, info, settings, pinned, edit, delete + ): + if any((join, leave, invite, restrict, unrestrict, ban, unban, + promote, demote, info, settings, pinned, edit, delete)): + events_filter = types.ChannelAdminLogEventsFilter( + join=join, leave=leave, invite=invite, ban=restrict, + unban=unrestrict, kick=ban, unkick=unban, promote=promote, + demote=demote, info=info, settings=settings, pinned=pinned, + edit=edit, delete=delete + ) + else: + events_filter = None + + self.entity = await self.client.get_input_entity(entity) + + admin_list = [] + if admins: + if not utils.is_list_like(admins): + admins = (admins,) + + for admin in admins: + admin_list.append(await self.client.get_input_entity(admin)) + + self.request = functions.channels.GetAdminLogRequest( + self.entity, q=search or '', min_id=min_id, max_id=max_id, + limit=0, events_filter=events_filter, admins=admin_list or None + ) + + async def _load_next_chunk(self): + result = [] + self.request.limit = min(self.left, 100) + r = await self.client(self.request) + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + self.request.max_id = min((e.id for e in r.events), default=0) + for ev in r.events: + if isinstance(ev.action, + types.ChannelAdminLogEventActionEditMessage): + ev.action.prev_message._finish_init( + self.client, entities, self.entity) + + ev.action.new_message._finish_init( + self.client, entities, self.entity) + + elif isinstance(ev.action, + types.ChannelAdminLogEventActionDeleteMessage): + ev.action.message._finish_init( + self.client, entities, self.entity) + + result.append(custom.AdminLogEvent(ev, entities)) + + if len(r.events) < self.request.limit: + self.left = len(result) + + return result + + class ChatMethods(UserMethods): # region Public methods - @async_generator - async def iter_participants( + def iter_participants( self, entity, limit=None, *, search='', filter=None, aggressive=False, _total=None): """ @@ -62,138 +245,23 @@ class ChatMethods(UserMethods): matched :tl:`ChannelParticipant` type for channels/megagroups or :tl:`ChatParticipants` for normal chats. """ - if isinstance(filter, type): - if filter in (types.ChannelParticipantsBanned, - types.ChannelParticipantsKicked, - types.ChannelParticipantsSearch, - types.ChannelParticipantsContacts): - # These require a `q` parameter (support types for convenience) - filter = filter('') - else: - filter = filter() - - entity = await self.get_input_entity(entity) - if search and (filter - or not isinstance(entity, types.InputPeerChannel)): - # We need to 'search' ourselves unless we have a PeerChannel - search = search.lower() - - def filter_entity(ent): - return search in utils.get_display_name(ent).lower() or\ - search in (getattr(ent, 'username', '') or None).lower() - else: - def filter_entity(ent): - return True - - limit = float('inf') if limit is None else int(limit) - if isinstance(entity, types.InputPeerChannel): - if _total: - _total[0] = (await self( - functions.channels.GetFullChannelRequest(entity) - )).full_chat.participants_count - - if limit == 0: - return - - seen = set() - if aggressive and not filter: - requests = [functions.channels.GetParticipantsRequest( - channel=entity, - filter=types.ChannelParticipantsSearch(x), - offset=0, - limit=200, - hash=0 - ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] - else: - requests = [functions.channels.GetParticipantsRequest( - channel=entity, - filter=filter or types.ChannelParticipantsSearch(search), - offset=0, - limit=200, - hash=0 - )] - - while requests: - # Only care about the limit for the first request - # (small amount of people, won't be aggressive). - # - # Most people won't care about getting exactly 12,345 - # members so it doesn't really matter not to be 100% - # precise with being out of the offset/limit here. - requests[0].limit = min(limit - requests[0].offset, 200) - if requests[0].offset > limit: - break - - results = await self(requests) - for i in reversed(range(len(requests))): - participants = results[i] - if not participants.users: - requests.pop(i) - else: - requests[i].offset += len(participants.participants) - users = {user.id: user for user in participants.users} - for participant in participants.participants: - user = users[participant.user_id] - if not filter_entity(user) or user.id in seen: - continue - - seen.add(participant.user_id) - user = users[participant.user_id] - user.participant = participant - await yield_(user) - if len(seen) >= limit: - return - - elif isinstance(entity, types.InputPeerChat): - full = await self( - functions.messages.GetFullChatRequest(entity.chat_id)) - if not isinstance( - full.full_chat.participants, types.ChatParticipants): - # ChatParticipantsForbidden won't have ``.participants`` - if _total: - _total[0] = 0 - return - - if _total: - _total[0] = len(full.full_chat.participants.participants) - - have = 0 - users = {user.id: user for user in full.users} - for participant in full.full_chat.participants.participants: - user = users[participant.user_id] - if not filter_entity(user): - continue - have += 1 - if have > limit: - break - else: - user = users[participant.user_id] - user.participant = participant - await yield_(user) - else: - if _total: - _total[0] = 1 - if limit != 0: - user = await self.get_entity(entity) - if filter_entity(user): - user.participant = None - await yield_(user) + return _ParticipantsIter( + self, + limit, + entity=entity, + filter=filter, + search=search, + aggressive=aggressive + ) async def get_participants(self, *args, **kwargs): """ Same as `iter_participants`, but returns a `TotalList ` instead. """ - total = [0] - kwargs['_total'] = total - participants = helpers.TotalList() - async for x in self.iter_participants(*args, **kwargs): - participants.append(x) - participants.total = total[0] - return participants + return await self.iter_participants(*args, **kwargs).collect() - @async_generator - async def iter_admin_log( + def iter_admin_log( self, entity, limit=None, *, max_id=0, min_id=0, search=None, admins=None, join=None, leave=None, invite=None, restrict=None, unrestrict=None, ban=None, unban=None, promote=None, demote=None, @@ -285,66 +353,34 @@ class ChatMethods(UserMethods): Yields: Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`. """ - if limit is None: - limit = sys.maxsize - elif limit <= 0: - return - - if any((join, leave, invite, restrict, unrestrict, ban, unban, - promote, demote, info, settings, pinned, edit, delete)): - events_filter = types.ChannelAdminLogEventsFilter( - join=join, leave=leave, invite=invite, ban=restrict, - unban=unrestrict, kick=ban, unkick=unban, promote=promote, - demote=demote, info=info, settings=settings, pinned=pinned, - edit=edit, delete=delete - ) - else: - events_filter = None - - entity = await self.get_input_entity(entity) - - admin_list = [] - if admins: - if not utils.is_list_like(admins): - admins = (admins,) - - for admin in admins: - admin_list.append(await self.get_input_entity(admin)) - - request = functions.channels.GetAdminLogRequest( - entity, q=search or '', min_id=min_id, max_id=max_id, - limit=0, events_filter=events_filter, admins=admin_list or None + return _AdminLogIter( + self, + limit, + entity=entity, + admins=admins, + search=search, + min_id=min_id, + max_id=max_id, + join=join, + leave=leave, + invite=invite, + restrict=restrict, + unrestrict=unrestrict, + ban=ban, + unban=unban, + promote=promote, + demote=demote, + info=info, + settings=settings, + pinned=pinned, + edit=edit, + delete=delete ) - while limit > 0: - request.limit = min(limit, 100) - result = await self(request) - limit -= len(result.events) - entities = {utils.get_peer_id(x): x - for x in itertools.chain(result.users, result.chats)} - - request.max_id = min((e.id for e in result.events), default=0) - for ev in result.events: - if isinstance(ev.action, - types.ChannelAdminLogEventActionEditMessage): - ev.action.prev_message._finish_init(self, entities, entity) - ev.action.new_message._finish_init(self, entities, entity) - - elif isinstance(ev.action, - types.ChannelAdminLogEventActionDeleteMessage): - ev.action.message._finish_init(self, entities, entity) - - await yield_(custom.AdminLogEvent(ev, entities)) - - if len(result.events) < request.limit: - break async def get_admin_log(self, *args, **kwargs): """ Same as `iter_admin_log`, but returns a ``list`` instead. """ - admin_log = [] - async for x in self.iter_admin_log(*args, **kwargs): - admin_log.append(x) - return admin_log + return await self.iter_admin_log(*args, **kwargs).collect() # endregion