Use RequestIter in chat methods

This commit is contained in:
Lonami Exo 2019-02-27 11:12:05 +01:00
parent 4f647847e7
commit 40ded93c7c

View File

@ -1,19 +1,202 @@
import itertools
import sys
from async_generator import async_generator, yield_
from .users import UserMethods
from .. import utils, helpers
from .. import utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom
class _ParticipantsIter(RequestIter):
async def _init(self, entity, filter, search, aggressive):
if isinstance(filter, type):
if filter in (types.ChannelParticipantsBanned,
types.ChannelParticipantsKicked,
types.ChannelParticipantsSearch,
types.ChannelParticipantsContacts):
# These require a `q` parameter (support types for convenience)
filter = filter('')
else:
filter = filter()
entity = await self.client.get_input_entity(entity)
if search and (filter
or not isinstance(entity, types.InputPeerChannel)):
# We need to 'search' ourselves unless we have a PeerChannel
search = search.lower()
self.filter_entity = lambda ent: (
search in utils.get_display_name(ent).lower() or
search in (getattr(ent, 'username', '') or None).lower()
)
else:
self.filter_entity = lambda ent: True
if isinstance(entity, types.InputPeerChannel):
self.total = (await self.client(
functions.channels.GetFullChannelRequest(entity)
)).full_chat.participants_count
if self.limit == 0:
raise StopAsyncIteration
self.seen = set()
if aggressive and not filter:
self.requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=types.ChannelParticipantsSearch(x),
offset=0,
limit=200,
hash=0
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
else:
self.requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=filter or types.ChannelParticipantsSearch(search),
offset=0,
limit=200,
hash=0
)]
elif isinstance(entity, types.InputPeerChat):
full = await self.client(
functions.messages.GetFullChatRequest(entity.chat_id))
if not isinstance(
full.full_chat.participants, types.ChatParticipants):
# ChatParticipantsForbidden won't have ``.participants``
self.total = 0
raise StopAsyncIteration
self.total = len(full.full_chat.participants.participants)
result = []
users = {user.id: user for user in full.users}
for participant in full.full_chat.participants.participants:
user = users[participant.user_id]
if not self.filter_entity(user):
continue
user = users[participant.user_id]
user.participant = participant
result.append(user)
self.left = len(result)
self.buffer = result
else:
result = []
self.total = 1
if self.limit != 0:
user = await self.client.get_entity(entity)
if self.filter_entity(user):
user.participant = None
result.append(user)
self.left = len(result)
self.buffer = result
async def _load_next_chunk(self):
result = []
if not self.requests:
return result
# Only care about the limit for the first request
# (small amount of people, won't be aggressive).
#
# Most people won't care about getting exactly 12,345
# members so it doesn't really matter not to be 100%
# precise with being out of the offset/limit here.
self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
if self.requests[0].offset > self.limit:
return result
results = await self.client(self.requests)
for i in reversed(range(len(self.requests))):
participants = results[i]
if not participants.users:
self.requests.pop(i)
continue
self.requests[i].offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
user = users[participant.user_id]
if not self.filter_entity(user) or user.id in self.seen:
continue
self.seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
result.append(user)
return result
class _AdminLogIter(RequestIter):
async def _init(
self, entity, admins, search, min_id, max_id,
join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete
):
if any((join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete)):
events_filter = types.ChannelAdminLogEventsFilter(
join=join, leave=leave, invite=invite, ban=restrict,
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
demote=demote, info=info, settings=settings, pinned=pinned,
edit=edit, delete=delete
)
else:
events_filter = None
self.entity = await self.client.get_input_entity(entity)
admin_list = []
if admins:
if not utils.is_list_like(admins):
admins = (admins,)
for admin in admins:
admin_list.append(await self.client.get_input_entity(admin))
self.request = functions.channels.GetAdminLogRequest(
self.entity, q=search or '', min_id=min_id, max_id=max_id,
limit=0, events_filter=events_filter, admins=admin_list or None
)
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
self.request.max_id = min((e.id for e in r.events), default=0)
for ev in r.events:
if isinstance(ev.action,
types.ChannelAdminLogEventActionEditMessage):
ev.action.prev_message._finish_init(
self.client, entities, self.entity)
ev.action.new_message._finish_init(
self.client, entities, self.entity)
elif isinstance(ev.action,
types.ChannelAdminLogEventActionDeleteMessage):
ev.action.message._finish_init(
self.client, entities, self.entity)
result.append(custom.AdminLogEvent(ev, entities))
if len(r.events) < self.request.limit:
self.left = len(result)
return result
class ChatMethods(UserMethods):
# region Public methods
@async_generator
async def iter_participants(
def iter_participants(
self, entity, limit=None, *, search='',
filter=None, aggressive=False, _total=None):
"""
@ -62,138 +245,23 @@ class ChatMethods(UserMethods):
matched :tl:`ChannelParticipant` type for channels/megagroups
or :tl:`ChatParticipants` for normal chats.
"""
if isinstance(filter, type):
if filter in (types.ChannelParticipantsBanned,
types.ChannelParticipantsKicked,
types.ChannelParticipantsSearch,
types.ChannelParticipantsContacts):
# These require a `q` parameter (support types for convenience)
filter = filter('')
else:
filter = filter()
entity = await self.get_input_entity(entity)
if search and (filter
or not isinstance(entity, types.InputPeerChannel)):
# We need to 'search' ourselves unless we have a PeerChannel
search = search.lower()
def filter_entity(ent):
return search in utils.get_display_name(ent).lower() or\
search in (getattr(ent, 'username', '') or None).lower()
else:
def filter_entity(ent):
return True
limit = float('inf') if limit is None else int(limit)
if isinstance(entity, types.InputPeerChannel):
if _total:
_total[0] = (await self(
functions.channels.GetFullChannelRequest(entity)
)).full_chat.participants_count
if limit == 0:
return
seen = set()
if aggressive and not filter:
requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=types.ChannelParticipantsSearch(x),
offset=0,
limit=200,
hash=0
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
else:
requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=filter or types.ChannelParticipantsSearch(search),
offset=0,
limit=200,
hash=0
)]
while requests:
# Only care about the limit for the first request
# (small amount of people, won't be aggressive).
#
# Most people won't care about getting exactly 12,345
# members so it doesn't really matter not to be 100%
# precise with being out of the offset/limit here.
requests[0].limit = min(limit - requests[0].offset, 200)
if requests[0].offset > limit:
break
results = await self(requests)
for i in reversed(range(len(requests))):
participants = results[i]
if not participants.users:
requests.pop(i)
else:
requests[i].offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
user = users[participant.user_id]
if not filter_entity(user) or user.id in seen:
continue
seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
await yield_(user)
if len(seen) >= limit:
return
elif isinstance(entity, types.InputPeerChat):
full = await self(
functions.messages.GetFullChatRequest(entity.chat_id))
if not isinstance(
full.full_chat.participants, types.ChatParticipants):
# ChatParticipantsForbidden won't have ``.participants``
if _total:
_total[0] = 0
return
if _total:
_total[0] = len(full.full_chat.participants.participants)
have = 0
users = {user.id: user for user in full.users}
for participant in full.full_chat.participants.participants:
user = users[participant.user_id]
if not filter_entity(user):
continue
have += 1
if have > limit:
break
else:
user = users[participant.user_id]
user.participant = participant
await yield_(user)
else:
if _total:
_total[0] = 1
if limit != 0:
user = await self.get_entity(entity)
if filter_entity(user):
user.participant = None
await yield_(user)
return _ParticipantsIter(
self,
limit,
entity=entity,
filter=filter,
search=search,
aggressive=aggressive
)
async def get_participants(self, *args, **kwargs):
"""
Same as `iter_participants`, but returns a
`TotalList <telethon.helpers.TotalList>` instead.
"""
total = [0]
kwargs['_total'] = total
participants = helpers.TotalList()
async for x in self.iter_participants(*args, **kwargs):
participants.append(x)
participants.total = total[0]
return participants
return await self.iter_participants(*args, **kwargs).collect()
@async_generator
async def iter_admin_log(
def iter_admin_log(
self, entity, limit=None, *, max_id=0, min_id=0, search=None,
admins=None, join=None, leave=None, invite=None, restrict=None,
unrestrict=None, ban=None, unban=None, promote=None, demote=None,
@ -285,66 +353,34 @@ class ChatMethods(UserMethods):
Yields:
Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
"""
if limit is None:
limit = sys.maxsize
elif limit <= 0:
return
if any((join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete)):
events_filter = types.ChannelAdminLogEventsFilter(
join=join, leave=leave, invite=invite, ban=restrict,
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
demote=demote, info=info, settings=settings, pinned=pinned,
edit=edit, delete=delete
)
else:
events_filter = None
entity = await self.get_input_entity(entity)
admin_list = []
if admins:
if not utils.is_list_like(admins):
admins = (admins,)
for admin in admins:
admin_list.append(await self.get_input_entity(admin))
request = functions.channels.GetAdminLogRequest(
entity, q=search or '', min_id=min_id, max_id=max_id,
limit=0, events_filter=events_filter, admins=admin_list or None
return _AdminLogIter(
self,
limit,
entity=entity,
admins=admins,
search=search,
min_id=min_id,
max_id=max_id,
join=join,
leave=leave,
invite=invite,
restrict=restrict,
unrestrict=unrestrict,
ban=ban,
unban=unban,
promote=promote,
demote=demote,
info=info,
settings=settings,
pinned=pinned,
edit=edit,
delete=delete
)
while limit > 0:
request.limit = min(limit, 100)
result = await self(request)
limit -= len(result.events)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(result.users, result.chats)}
request.max_id = min((e.id for e in result.events), default=0)
for ev in result.events:
if isinstance(ev.action,
types.ChannelAdminLogEventActionEditMessage):
ev.action.prev_message._finish_init(self, entities, entity)
ev.action.new_message._finish_init(self, entities, entity)
elif isinstance(ev.action,
types.ChannelAdminLogEventActionDeleteMessage):
ev.action.message._finish_init(self, entities, entity)
await yield_(custom.AdminLogEvent(ev, entities))
if len(result.events) < request.limit:
break
async def get_admin_log(self, *args, **kwargs):
"""
Same as `iter_admin_log`, but returns a ``list`` instead.
"""
admin_log = []
async for x in self.iter_admin_log(*args, **kwargs):
admin_log.append(x)
return admin_log
return await self.iter_admin_log(*args, **kwargs).collect()
# endregion