Use RequestIter in chat methods

This commit is contained in:
Lonami Exo 2019-02-27 11:12:05 +01:00
parent 4f647847e7
commit 40ded93c7c

View File

@ -1,19 +1,202 @@
import itertools import itertools
import sys
from async_generator import async_generator, yield_
from .users import UserMethods from .users import UserMethods
from .. import utils, helpers from .. import utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom from ..tl import types, functions, custom
class _ParticipantsIter(RequestIter):
async def _init(self, entity, filter, search, aggressive):
if isinstance(filter, type):
if filter in (types.ChannelParticipantsBanned,
types.ChannelParticipantsKicked,
types.ChannelParticipantsSearch,
types.ChannelParticipantsContacts):
# These require a `q` parameter (support types for convenience)
filter = filter('')
else:
filter = filter()
entity = await self.client.get_input_entity(entity)
if search and (filter
or not isinstance(entity, types.InputPeerChannel)):
# We need to 'search' ourselves unless we have a PeerChannel
search = search.lower()
self.filter_entity = lambda ent: (
search in utils.get_display_name(ent).lower() or
search in (getattr(ent, 'username', '') or None).lower()
)
else:
self.filter_entity = lambda ent: True
if isinstance(entity, types.InputPeerChannel):
self.total = (await self.client(
functions.channels.GetFullChannelRequest(entity)
)).full_chat.participants_count
if self.limit == 0:
raise StopAsyncIteration
self.seen = set()
if aggressive and not filter:
self.requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=types.ChannelParticipantsSearch(x),
offset=0,
limit=200,
hash=0
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
else:
self.requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=filter or types.ChannelParticipantsSearch(search),
offset=0,
limit=200,
hash=0
)]
elif isinstance(entity, types.InputPeerChat):
full = await self.client(
functions.messages.GetFullChatRequest(entity.chat_id))
if not isinstance(
full.full_chat.participants, types.ChatParticipants):
# ChatParticipantsForbidden won't have ``.participants``
self.total = 0
raise StopAsyncIteration
self.total = len(full.full_chat.participants.participants)
result = []
users = {user.id: user for user in full.users}
for participant in full.full_chat.participants.participants:
user = users[participant.user_id]
if not self.filter_entity(user):
continue
user = users[participant.user_id]
user.participant = participant
result.append(user)
self.left = len(result)
self.buffer = result
else:
result = []
self.total = 1
if self.limit != 0:
user = await self.client.get_entity(entity)
if self.filter_entity(user):
user.participant = None
result.append(user)
self.left = len(result)
self.buffer = result
async def _load_next_chunk(self):
result = []
if not self.requests:
return result
# Only care about the limit for the first request
# (small amount of people, won't be aggressive).
#
# Most people won't care about getting exactly 12,345
# members so it doesn't really matter not to be 100%
# precise with being out of the offset/limit here.
self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
if self.requests[0].offset > self.limit:
return result
results = await self.client(self.requests)
for i in reversed(range(len(self.requests))):
participants = results[i]
if not participants.users:
self.requests.pop(i)
continue
self.requests[i].offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
user = users[participant.user_id]
if not self.filter_entity(user) or user.id in self.seen:
continue
self.seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
result.append(user)
return result
class _AdminLogIter(RequestIter):
async def _init(
self, entity, admins, search, min_id, max_id,
join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete
):
if any((join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete)):
events_filter = types.ChannelAdminLogEventsFilter(
join=join, leave=leave, invite=invite, ban=restrict,
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
demote=demote, info=info, settings=settings, pinned=pinned,
edit=edit, delete=delete
)
else:
events_filter = None
self.entity = await self.client.get_input_entity(entity)
admin_list = []
if admins:
if not utils.is_list_like(admins):
admins = (admins,)
for admin in admins:
admin_list.append(await self.client.get_input_entity(admin))
self.request = functions.channels.GetAdminLogRequest(
self.entity, q=search or '', min_id=min_id, max_id=max_id,
limit=0, events_filter=events_filter, admins=admin_list or None
)
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
self.request.max_id = min((e.id for e in r.events), default=0)
for ev in r.events:
if isinstance(ev.action,
types.ChannelAdminLogEventActionEditMessage):
ev.action.prev_message._finish_init(
self.client, entities, self.entity)
ev.action.new_message._finish_init(
self.client, entities, self.entity)
elif isinstance(ev.action,
types.ChannelAdminLogEventActionDeleteMessage):
ev.action.message._finish_init(
self.client, entities, self.entity)
result.append(custom.AdminLogEvent(ev, entities))
if len(r.events) < self.request.limit:
self.left = len(result)
return result
class ChatMethods(UserMethods): class ChatMethods(UserMethods):
# region Public methods # region Public methods
@async_generator def iter_participants(
async def iter_participants(
self, entity, limit=None, *, search='', self, entity, limit=None, *, search='',
filter=None, aggressive=False, _total=None): filter=None, aggressive=False, _total=None):
""" """
@ -62,138 +245,23 @@ class ChatMethods(UserMethods):
matched :tl:`ChannelParticipant` type for channels/megagroups matched :tl:`ChannelParticipant` type for channels/megagroups
or :tl:`ChatParticipants` for normal chats. or :tl:`ChatParticipants` for normal chats.
""" """
if isinstance(filter, type): return _ParticipantsIter(
if filter in (types.ChannelParticipantsBanned, self,
types.ChannelParticipantsKicked, limit,
types.ChannelParticipantsSearch, entity=entity,
types.ChannelParticipantsContacts): filter=filter,
# These require a `q` parameter (support types for convenience) search=search,
filter = filter('') aggressive=aggressive
else: )
filter = filter()
entity = await self.get_input_entity(entity)
if search and (filter
or not isinstance(entity, types.InputPeerChannel)):
# We need to 'search' ourselves unless we have a PeerChannel
search = search.lower()
def filter_entity(ent):
return search in utils.get_display_name(ent).lower() or\
search in (getattr(ent, 'username', '') or None).lower()
else:
def filter_entity(ent):
return True
limit = float('inf') if limit is None else int(limit)
if isinstance(entity, types.InputPeerChannel):
if _total:
_total[0] = (await self(
functions.channels.GetFullChannelRequest(entity)
)).full_chat.participants_count
if limit == 0:
return
seen = set()
if aggressive and not filter:
requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=types.ChannelParticipantsSearch(x),
offset=0,
limit=200,
hash=0
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
else:
requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=filter or types.ChannelParticipantsSearch(search),
offset=0,
limit=200,
hash=0
)]
while requests:
# Only care about the limit for the first request
# (small amount of people, won't be aggressive).
#
# Most people won't care about getting exactly 12,345
# members so it doesn't really matter not to be 100%
# precise with being out of the offset/limit here.
requests[0].limit = min(limit - requests[0].offset, 200)
if requests[0].offset > limit:
break
results = await self(requests)
for i in reversed(range(len(requests))):
participants = results[i]
if not participants.users:
requests.pop(i)
else:
requests[i].offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
user = users[participant.user_id]
if not filter_entity(user) or user.id in seen:
continue
seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
await yield_(user)
if len(seen) >= limit:
return
elif isinstance(entity, types.InputPeerChat):
full = await self(
functions.messages.GetFullChatRequest(entity.chat_id))
if not isinstance(
full.full_chat.participants, types.ChatParticipants):
# ChatParticipantsForbidden won't have ``.participants``
if _total:
_total[0] = 0
return
if _total:
_total[0] = len(full.full_chat.participants.participants)
have = 0
users = {user.id: user for user in full.users}
for participant in full.full_chat.participants.participants:
user = users[participant.user_id]
if not filter_entity(user):
continue
have += 1
if have > limit:
break
else:
user = users[participant.user_id]
user.participant = participant
await yield_(user)
else:
if _total:
_total[0] = 1
if limit != 0:
user = await self.get_entity(entity)
if filter_entity(user):
user.participant = None
await yield_(user)
async def get_participants(self, *args, **kwargs): async def get_participants(self, *args, **kwargs):
""" """
Same as `iter_participants`, but returns a Same as `iter_participants`, but returns a
`TotalList <telethon.helpers.TotalList>` instead. `TotalList <telethon.helpers.TotalList>` instead.
""" """
total = [0] return await self.iter_participants(*args, **kwargs).collect()
kwargs['_total'] = total
participants = helpers.TotalList()
async for x in self.iter_participants(*args, **kwargs):
participants.append(x)
participants.total = total[0]
return participants
@async_generator def iter_admin_log(
async def iter_admin_log(
self, entity, limit=None, *, max_id=0, min_id=0, search=None, self, entity, limit=None, *, max_id=0, min_id=0, search=None,
admins=None, join=None, leave=None, invite=None, restrict=None, admins=None, join=None, leave=None, invite=None, restrict=None,
unrestrict=None, ban=None, unban=None, promote=None, demote=None, unrestrict=None, ban=None, unban=None, promote=None, demote=None,
@ -285,66 +353,34 @@ class ChatMethods(UserMethods):
Yields: Yields:
Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`. Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
""" """
if limit is None: return _AdminLogIter(
limit = sys.maxsize self,
elif limit <= 0: limit,
return entity=entity,
admins=admins,
if any((join, leave, invite, restrict, unrestrict, ban, unban, search=search,
promote, demote, info, settings, pinned, edit, delete)): min_id=min_id,
events_filter = types.ChannelAdminLogEventsFilter( max_id=max_id,
join=join, leave=leave, invite=invite, ban=restrict, join=join,
unban=unrestrict, kick=ban, unkick=unban, promote=promote, leave=leave,
demote=demote, info=info, settings=settings, pinned=pinned, invite=invite,
edit=edit, delete=delete restrict=restrict,
) unrestrict=unrestrict,
else: ban=ban,
events_filter = None unban=unban,
promote=promote,
entity = await self.get_input_entity(entity) demote=demote,
info=info,
admin_list = [] settings=settings,
if admins: pinned=pinned,
if not utils.is_list_like(admins): edit=edit,
admins = (admins,) delete=delete
for admin in admins:
admin_list.append(await self.get_input_entity(admin))
request = functions.channels.GetAdminLogRequest(
entity, q=search or '', min_id=min_id, max_id=max_id,
limit=0, events_filter=events_filter, admins=admin_list or None
) )
while limit > 0:
request.limit = min(limit, 100)
result = await self(request)
limit -= len(result.events)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(result.users, result.chats)}
request.max_id = min((e.id for e in result.events), default=0)
for ev in result.events:
if isinstance(ev.action,
types.ChannelAdminLogEventActionEditMessage):
ev.action.prev_message._finish_init(self, entities, entity)
ev.action.new_message._finish_init(self, entities, entity)
elif isinstance(ev.action,
types.ChannelAdminLogEventActionDeleteMessage):
ev.action.message._finish_init(self, entities, entity)
await yield_(custom.AdminLogEvent(ev, entities))
if len(result.events) < request.limit:
break
async def get_admin_log(self, *args, **kwargs): async def get_admin_log(self, *args, **kwargs):
""" """
Same as `iter_admin_log`, but returns a ``list`` instead. Same as `iter_admin_log`, but returns a ``list`` instead.
""" """
admin_log = [] return await self.iter_admin_log(*args, **kwargs).collect()
async for x in self.iter_admin_log(*args, **kwargs):
admin_log.append(x)
return admin_log
# endregion # endregion