Mirror of https://github.com/LonamiWebs/Telethon.git (synced 2025-02-05 06:01:02 +03:00)

Commit c99157ade2: Merge pull request #1115 from LonamiWebs/requestiter

Overhaul asynchronous generators
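The diff below removes the `async_generator` dependency and introduces `telethon/requestiter.py`: the client's `iter_*` methods no longer build async generators but return `RequestIter` instances, which can be consumed with `async for`, with a plain `for` when no event loop is running, or collected all at once. A minimal usage sketch, not part of the commit; the session name and API credentials are placeholders:

import asyncio
from telethon import TelegramClient

api_id, api_hash = 12345, '0123456789abcdef'  # placeholder credentials

async def main():
    client = TelegramClient('session', api_id, api_hash)
    await client.start()

    # iter_messages() now builds a RequestIter; `async for` drives it lazily.
    async for message in client.iter_messages('me', limit=10):
        print(message.id, message.text)

    # .collect() gathers everything into a TotalList with a .total attribute.
    dialogs = await client.iter_dialogs(limit=5).collect()
    print(len(dialogs), 'fetched,', dialogs.total, 'in total')

    await client.disconnect()

asyncio.get_event_loop().run_until_complete(main())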
@@ -1,3 +1,2 @@
 pyaes
 rsa
-async_generator

setup.py (3 changes)
@@ -214,8 +214,7 @@ def main():
         packages=find_packages(exclude=[
             'telethon_*', 'run_tests.py', 'try_telethon.py'
         ]),
-        install_requires=['pyaes', 'rsa',
-                          'async_generator'],
+        install_requires=['pyaes', 'rsa'],
         extras_require={
             'cryptg': ['cryptg']
         }
telethon/client/chats.py
@@ -1,19 +1,192 @@
 import itertools
-import sys
-
-from async_generator import async_generator, yield_
 
 from .users import UserMethods
-from .. import utils, helpers
+from .. import utils
+from ..requestiter import RequestIter
 from ..tl import types, functions, custom

class _ParticipantsIter(RequestIter):
|
||||
async def _init(self, entity, filter, search, aggressive):
|
||||
if isinstance(filter, type):
|
||||
if filter in (types.ChannelParticipantsBanned,
|
||||
types.ChannelParticipantsKicked,
|
||||
types.ChannelParticipantsSearch,
|
||||
types.ChannelParticipantsContacts):
|
||||
# These require a `q` parameter (support types for convenience)
|
||||
filter = filter('')
|
||||
else:
|
||||
filter = filter()
|
||||
|
||||
entity = await self.client.get_input_entity(entity)
|
||||
if search and (filter
|
||||
or not isinstance(entity, types.InputPeerChannel)):
|
||||
# We need to 'search' ourselves unless we have a PeerChannel
|
||||
search = search.lower()
|
||||
|
||||
self.filter_entity = lambda ent: (
|
||||
search in utils.get_display_name(ent).lower() or
|
||||
search in (getattr(ent, 'username', '') or '').lower()
|
||||
)
|
||||
else:
|
||||
self.filter_entity = lambda ent: True
|
||||
|
||||
if isinstance(entity, types.InputPeerChannel):
|
||||
self.total = (await self.client(
|
||||
functions.channels.GetFullChannelRequest(entity)
|
||||
)).full_chat.participants_count
|
||||
|
||||
if self.limit == 0:
|
||||
raise StopAsyncIteration
|
||||
|
||||
self.seen = set()
|
||||
if aggressive and not filter:
|
||||
self.requests = [functions.channels.GetParticipantsRequest(
|
||||
channel=entity,
|
||||
filter=types.ChannelParticipantsSearch(x),
|
||||
offset=0,
|
||||
limit=200,
|
||||
hash=0
|
||||
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
|
||||
else:
|
||||
self.requests = [functions.channels.GetParticipantsRequest(
|
||||
channel=entity,
|
||||
filter=filter or types.ChannelParticipantsSearch(search),
|
||||
offset=0,
|
||||
limit=200,
|
||||
hash=0
|
||||
)]
|
||||
|
||||
elif isinstance(entity, types.InputPeerChat):
|
||||
full = await self.client(
|
||||
functions.messages.GetFullChatRequest(entity.chat_id))
|
||||
if not isinstance(
|
||||
full.full_chat.participants, types.ChatParticipants):
|
||||
# ChatParticipantsForbidden won't have ``.participants``
|
||||
self.total = 0
|
||||
raise StopAsyncIteration
|
||||
|
||||
self.total = len(full.full_chat.participants.participants)
|
||||
|
||||
users = {user.id: user for user in full.users}
|
||||
for participant in full.full_chat.participants.participants:
|
||||
user = users[participant.user_id]
|
||||
if not self.filter_entity(user):
|
||||
continue
|
||||
|
||||
user = users[participant.user_id]
|
||||
user.participant = participant
|
||||
self.buffer.append(user)
|
||||
|
||||
return True
|
||||
else:
|
||||
self.total = 1
|
||||
if self.limit != 0:
|
||||
user = await self.client.get_entity(entity)
|
||||
if self.filter_entity(user):
|
||||
user.participant = None
|
||||
self.buffer.append(user)
|
||||
|
||||
return True
|
||||
|
||||
async def _load_next_chunk(self):
|
||||
if not self.requests:
|
||||
return True
|
||||
|
||||
# Only care about the limit for the first request
|
||||
# (small amount of people, won't be aggressive).
|
||||
#
|
||||
# Most people won't care about getting exactly 12,345
|
||||
# members so it doesn't really matter not to be 100%
|
||||
# precise with being out of the offset/limit here.
|
||||
self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
|
||||
if self.requests[0].offset > self.limit:
|
||||
return True
|
||||
|
||||
results = await self.client(self.requests)
|
||||
for i in reversed(range(len(self.requests))):
|
||||
participants = results[i]
|
||||
if not participants.users:
|
||||
self.requests.pop(i)
|
||||
continue
|
||||
|
||||
self.requests[i].offset += len(participants.participants)
|
||||
users = {user.id: user for user in participants.users}
|
||||
for participant in participants.participants:
|
||||
user = users[participant.user_id]
|
||||
if not self.filter_entity(user) or user.id in self.seen:
|
||||
continue
|
||||
|
||||
self.seen.add(participant.user_id)
|
||||
user = users[participant.user_id]
|
||||
user.participant = participant
|
||||
self.buffer.append(user)
|
||||
|
||||
|
||||
class _AdminLogIter(RequestIter):
|
||||
async def _init(
|
||||
self, entity, admins, search, min_id, max_id,
|
||||
join, leave, invite, restrict, unrestrict, ban, unban,
|
||||
promote, demote, info, settings, pinned, edit, delete
|
||||
):
|
||||
if any((join, leave, invite, restrict, unrestrict, ban, unban,
|
||||
promote, demote, info, settings, pinned, edit, delete)):
|
||||
events_filter = types.ChannelAdminLogEventsFilter(
|
||||
join=join, leave=leave, invite=invite, ban=restrict,
|
||||
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
|
||||
demote=demote, info=info, settings=settings, pinned=pinned,
|
||||
edit=edit, delete=delete
|
||||
)
|
||||
else:
|
||||
events_filter = None
|
||||
|
||||
self.entity = await self.client.get_input_entity(entity)
|
||||
|
||||
admin_list = []
|
||||
if admins:
|
||||
if not utils.is_list_like(admins):
|
||||
admins = (admins,)
|
||||
|
||||
for admin in admins:
|
||||
admin_list.append(await self.client.get_input_entity(admin))
|
||||
|
||||
self.request = functions.channels.GetAdminLogRequest(
|
||||
self.entity, q=search or '', min_id=min_id, max_id=max_id,
|
||||
limit=0, events_filter=events_filter, admins=admin_list or None
|
||||
)
|
||||
|
||||
async def _load_next_chunk(self):
|
||||
self.request.limit = min(self.left, 100)
|
||||
r = await self.client(self.request)
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
self.request.max_id = min((e.id for e in r.events), default=0)
|
||||
for ev in r.events:
|
||||
if isinstance(ev.action,
|
||||
types.ChannelAdminLogEventActionEditMessage):
|
||||
ev.action.prev_message._finish_init(
|
||||
self.client, entities, self.entity)
|
||||
|
||||
ev.action.new_message._finish_init(
|
||||
self.client, entities, self.entity)
|
||||
|
||||
elif isinstance(ev.action,
|
||||
types.ChannelAdminLogEventActionDeleteMessage):
|
||||
ev.action.message._finish_init(
|
||||
self.client, entities, self.entity)
|
||||
|
||||
self.buffer.append(custom.AdminLogEvent(ev, entities))
|
||||
|
||||
if len(r.events) < self.request.limit:
|
||||
return True
|
||||
|
||||
|
||||
class ChatMethods(UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
@async_generator
|
||||
async def iter_participants(
|
||||
def iter_participants(
|
||||
self, entity, limit=None, *, search='',
|
||||
filter=None, aggressive=False, _total=None):
|
||||
"""
|
||||
|
@@ -62,138 +235,23 @@ class ChatMethods(UserMethods):
|
|||
matched :tl:`ChannelParticipant` type for channels/megagroups
|
||||
or :tl:`ChatParticipants` for normal chats.
|
||||
"""
|
||||
if isinstance(filter, type):
|
||||
if filter in (types.ChannelParticipantsBanned,
|
||||
types.ChannelParticipantsKicked,
|
||||
types.ChannelParticipantsSearch,
|
||||
types.ChannelParticipantsContacts):
|
||||
# These require a `q` parameter (support types for convenience)
|
||||
filter = filter('')
|
||||
else:
|
||||
filter = filter()
|
||||
|
||||
entity = await self.get_input_entity(entity)
|
||||
if search and (filter
|
||||
or not isinstance(entity, types.InputPeerChannel)):
|
||||
# We need to 'search' ourselves unless we have a PeerChannel
|
||||
search = search.lower()
|
||||
|
||||
def filter_entity(ent):
|
||||
return search in utils.get_display_name(ent).lower() or\
|
||||
search in (getattr(ent, 'username', '') or '').lower()
|
||||
else:
|
||||
def filter_entity(ent):
|
||||
return True
|
||||
|
||||
limit = float('inf') if limit is None else int(limit)
|
||||
if isinstance(entity, types.InputPeerChannel):
|
||||
if _total:
|
||||
_total[0] = (await self(
|
||||
functions.channels.GetFullChannelRequest(entity)
|
||||
)).full_chat.participants_count
|
||||
|
||||
if limit == 0:
|
||||
return
|
||||
|
||||
seen = set()
|
||||
if aggressive and not filter:
|
||||
requests = [functions.channels.GetParticipantsRequest(
|
||||
channel=entity,
|
||||
filter=types.ChannelParticipantsSearch(x),
|
||||
offset=0,
|
||||
limit=200,
|
||||
hash=0
|
||||
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
|
||||
else:
|
||||
requests = [functions.channels.GetParticipantsRequest(
|
||||
channel=entity,
|
||||
filter=filter or types.ChannelParticipantsSearch(search),
|
||||
offset=0,
|
||||
limit=200,
|
||||
hash=0
|
||||
)]
|
||||
|
||||
while requests:
|
||||
# Only care about the limit for the first request
|
||||
# (small amount of people, won't be aggressive).
|
||||
#
|
||||
# Most people won't care about getting exactly 12,345
|
||||
# members so it doesn't really matter not to be 100%
|
||||
# precise with being out of the offset/limit here.
|
||||
requests[0].limit = min(limit - requests[0].offset, 200)
|
||||
if requests[0].offset > limit:
|
||||
break
|
||||
|
||||
results = await self(requests)
|
||||
for i in reversed(range(len(requests))):
|
||||
participants = results[i]
|
||||
if not participants.users:
|
||||
requests.pop(i)
|
||||
else:
|
||||
requests[i].offset += len(participants.participants)
|
||||
users = {user.id: user for user in participants.users}
|
||||
for participant in participants.participants:
|
||||
user = users[participant.user_id]
|
||||
if not filter_entity(user) or user.id in seen:
|
||||
continue
|
||||
|
||||
seen.add(participant.user_id)
|
||||
user = users[participant.user_id]
|
||||
user.participant = participant
|
||||
await yield_(user)
|
||||
if len(seen) >= limit:
|
||||
return
|
||||
|
||||
elif isinstance(entity, types.InputPeerChat):
|
||||
full = await self(
|
||||
functions.messages.GetFullChatRequest(entity.chat_id))
|
||||
if not isinstance(
|
||||
full.full_chat.participants, types.ChatParticipants):
|
||||
# ChatParticipantsForbidden won't have ``.participants``
|
||||
if _total:
|
||||
_total[0] = 0
|
||||
return
|
||||
|
||||
if _total:
|
||||
_total[0] = len(full.full_chat.participants.participants)
|
||||
|
||||
have = 0
|
||||
users = {user.id: user for user in full.users}
|
||||
for participant in full.full_chat.participants.participants:
|
||||
user = users[participant.user_id]
|
||||
if not filter_entity(user):
|
||||
continue
|
||||
have += 1
|
||||
if have > limit:
|
||||
break
|
||||
else:
|
||||
user = users[participant.user_id]
|
||||
user.participant = participant
|
||||
await yield_(user)
|
||||
else:
|
||||
if _total:
|
||||
_total[0] = 1
|
||||
if limit != 0:
|
||||
user = await self.get_entity(entity)
|
||||
if filter_entity(user):
|
||||
user.participant = None
|
||||
await yield_(user)
|
||||
return _ParticipantsIter(
|
||||
self,
|
||||
limit,
|
||||
entity=entity,
|
||||
filter=filter,
|
||||
search=search,
|
||||
aggressive=aggressive
|
||||
)
|
||||
|
||||
async def get_participants(self, *args, **kwargs):
|
||||
"""
|
||||
Same as `iter_participants`, but returns a
|
||||
`TotalList <telethon.helpers.TotalList>` instead.
|
||||
"""
|
||||
total = [0]
|
||||
kwargs['_total'] = total
|
||||
participants = helpers.TotalList()
|
||||
async for x in self.iter_participants(*args, **kwargs):
|
||||
participants.append(x)
|
||||
participants.total = total[0]
|
||||
return participants
|
||||
return await self.iter_participants(*args, **kwargs).collect()
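With the iterator in place, `get_participants` simply collects it, and the total now comes from `RequestIter.total` instead of the old `_total` list. A usage sketch (the group reference and helper name are illustrative):

from telethon.tl.types import ChannelParticipantsAdmins

async def show_members(client, group):
    admins = await client.get_participants(group, filter=ChannelParticipantsAdmins)
    print(len(admins), 'admins;', admins.total, 'members reported by Telegram')

    # The same data as a lazy iterator, stopping early after 50 matches.
    async for user in client.iter_participants(group, limit=50, search='bot'):
        print(user.id, user.username)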
|
||||
|
||||
@async_generator
|
||||
async def iter_admin_log(
|
||||
def iter_admin_log(
|
||||
self, entity, limit=None, *, max_id=0, min_id=0, search=None,
|
||||
admins=None, join=None, leave=None, invite=None, restrict=None,
|
||||
unrestrict=None, ban=None, unban=None, promote=None, demote=None,
|
||||
|
@@ -285,66 +343,34 @@ class ChatMethods(UserMethods):
|
|||
Yields:
|
||||
Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
|
||||
"""
|
||||
if limit is None:
|
||||
limit = sys.maxsize
|
||||
elif limit <= 0:
|
||||
return
|
||||
|
||||
if any((join, leave, invite, restrict, unrestrict, ban, unban,
|
||||
promote, demote, info, settings, pinned, edit, delete)):
|
||||
events_filter = types.ChannelAdminLogEventsFilter(
|
||||
join=join, leave=leave, invite=invite, ban=restrict,
|
||||
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
|
||||
demote=demote, info=info, settings=settings, pinned=pinned,
|
||||
edit=edit, delete=delete
|
||||
)
|
||||
else:
|
||||
events_filter = None
|
||||
|
||||
entity = await self.get_input_entity(entity)
|
||||
|
||||
admin_list = []
|
||||
if admins:
|
||||
if not utils.is_list_like(admins):
|
||||
admins = (admins,)
|
||||
|
||||
for admin in admins:
|
||||
admin_list.append(await self.get_input_entity(admin))
|
||||
|
||||
request = functions.channels.GetAdminLogRequest(
|
||||
entity, q=search or '', min_id=min_id, max_id=max_id,
|
||||
limit=0, events_filter=events_filter, admins=admin_list or None
|
||||
return _AdminLogIter(
|
||||
self,
|
||||
limit,
|
||||
entity=entity,
|
||||
admins=admins,
|
||||
search=search,
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
join=join,
|
||||
leave=leave,
|
||||
invite=invite,
|
||||
restrict=restrict,
|
||||
unrestrict=unrestrict,
|
||||
ban=ban,
|
||||
unban=unban,
|
||||
promote=promote,
|
||||
demote=demote,
|
||||
info=info,
|
||||
settings=settings,
|
||||
pinned=pinned,
|
||||
edit=edit,
|
||||
delete=delete
|
||||
)
|
||||
while limit > 0:
|
||||
request.limit = min(limit, 100)
|
||||
result = await self(request)
|
||||
limit -= len(result.events)
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(result.users, result.chats)}
|
||||
|
||||
request.max_id = min((e.id for e in result.events), default=0)
|
||||
for ev in result.events:
|
||||
if isinstance(ev.action,
|
||||
types.ChannelAdminLogEventActionEditMessage):
|
||||
ev.action.prev_message._finish_init(self, entities, entity)
|
||||
ev.action.new_message._finish_init(self, entities, entity)
|
||||
|
||||
elif isinstance(ev.action,
|
||||
types.ChannelAdminLogEventActionDeleteMessage):
|
||||
ev.action.message._finish_init(self, entities, entity)
|
||||
|
||||
await yield_(custom.AdminLogEvent(ev, entities))
|
||||
|
||||
if len(result.events) < request.limit:
|
||||
break
|
||||
|
||||
async def get_admin_log(self, *args, **kwargs):
|
||||
"""
|
||||
Same as `iter_admin_log`, but returns a ``list`` instead.
|
||||
"""
|
||||
admin_log = []
|
||||
async for x in self.iter_admin_log(*args, **kwargs):
|
||||
admin_log.append(x)
|
||||
return admin_log
|
||||
return await self.iter_admin_log(*args, **kwargs).collect()
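Note how the public flag names (`restrict`, `ban`, ...) are remapped onto the :tl:`ChannelAdminLogEventsFilter` fields (`ban`, `kick`, ...) inside `_AdminLogIter._init`. A hypothetical call that only asks for ban and restriction events:

async def recent_restrictions(client, channel):
    # ban=True maps to the filter's `kick` flag, restrict=True to its `ban` flag.
    async for event in client.iter_admin_log(channel, limit=20, ban=True, restrict=True):
        print(event)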
|
||||
|
||||
# endregion
|
||||
|
|
|
telethon/client/dialogs.py
@@ -1,18 +1,101 @@
 import itertools
 
-from async_generator import async_generator, yield_
-
 from .users import UserMethods
-from .. import utils, helpers
+from .. import utils
+from ..requestiter import RequestIter
 from ..tl import types, functions, custom

class _DialogsIter(RequestIter):
|
||||
async def _init(
|
||||
self, offset_date, offset_id, offset_peer, ignore_migrated
|
||||
):
|
||||
self.request = functions.messages.GetDialogsRequest(
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
offset_peer=offset_peer,
|
||||
limit=1,
|
||||
hash=0
|
||||
)
|
||||
|
||||
if self.limit == 0:
|
||||
# Special case, get a single dialog and determine count
|
||||
dialogs = await self.client(self.request)
|
||||
self.total = getattr(dialogs, 'count', len(dialogs.dialogs))
|
||||
raise StopAsyncIteration
|
||||
|
||||
self.seen = set()
|
||||
self.offset_date = offset_date
|
||||
self.ignore_migrated = ignore_migrated
|
||||
|
||||
async def _load_next_chunk(self):
|
||||
self.request.limit = min(self.left, 100)
|
||||
r = await self.client(self.request)
|
||||
|
||||
self.total = getattr(r, 'count', len(r.dialogs))
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
messages = {}
|
||||
for m in r.messages:
|
||||
m._finish_init(self, entities, None)
|
||||
messages[m.id] = m
|
||||
|
||||
for d in r.dialogs:
|
||||
# We check the offset date here because Telegram may ignore it
|
||||
if self.offset_date:
|
||||
date = getattr(messages.get(
|
||||
d.top_message, None), 'date', None)
|
||||
|
||||
if not date or date.timestamp() > self.offset_date.timestamp():
|
||||
continue
|
||||
|
||||
peer_id = utils.get_peer_id(d.peer)
|
||||
if peer_id not in self.seen:
|
||||
self.seen.add(peer_id)
|
||||
cd = custom.Dialog(self, d, entities, messages)
|
||||
if cd.dialog.pts:
|
||||
self.client._channel_pts[cd.id] = cd.dialog.pts
|
||||
|
||||
if not self.ignore_migrated or getattr(
|
||||
cd.entity, 'migrated_to', None) is None:
|
||||
self.buffer.append(cd)
|
||||
|
||||
if len(r.dialogs) < self.request.limit\
|
||||
or not isinstance(r, types.messages.DialogsSlice):
|
||||
# Less than we requested means we reached the end, or
|
||||
# we didn't get a DialogsSlice which means we got all.
|
||||
return True
|
||||
|
||||
if self.request.offset_id == r.messages[-1].id:
|
||||
# In some very rare cases this will get stuck in an infinite
|
||||
# loop, where the offsets will get reused over and over. If
|
||||
# the new offset is the same as the one before, break already.
|
||||
return True
|
||||
|
||||
self.request.offset_id = r.messages[-1].id
|
||||
self.request.exclude_pinned = True
|
||||
self.request.offset_date = r.messages[-1].date
|
||||
self.request.offset_peer =\
|
||||
entities[utils.get_peer_id(r.dialogs[-1].peer)]
|
||||
|
||||
|
||||
class _DraftsIter(RequestIter):
|
||||
async def _init(self, **kwargs):
|
||||
r = await self.client(functions.messages.GetAllDraftsRequest())
|
||||
self.buffer.extend(custom.Draft._from_update(self.client, u)
|
||||
for u in r.updates)
|
||||
|
||||
async def _load_next_chunk(self):
|
||||
return []
|
||||
|
||||
|
||||
class DialogMethods(UserMethods):
|
||||
|
||||
# region Public methods
|
||||
|
||||
@async_generator
|
||||
async def iter_dialogs(
|
||||
def iter_dialogs(
|
||||
self, limit=None, *, offset_date=None, offset_id=0,
|
||||
offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
|
||||
_total=None):
|
||||
|
@@ -50,99 +133,23 @@ class DialogMethods(UserMethods):
|
|||
Yields:
|
||||
Instances of `telethon.tl.custom.dialog.Dialog`.
|
||||
"""
|
||||
limit = float('inf') if limit is None else int(limit)
|
||||
if limit == 0:
|
||||
if not _total:
|
||||
return
|
||||
# Special case, get a single dialog and determine count
|
||||
dialogs = await self(functions.messages.GetDialogsRequest(
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
offset_peer=offset_peer,
|
||||
limit=1,
|
||||
hash=0
|
||||
))
|
||||
_total[0] = getattr(dialogs, 'count', len(dialogs.dialogs))
|
||||
return
|
||||
|
||||
seen = set()
|
||||
req = functions.messages.GetDialogsRequest(
|
||||
return _DialogsIter(
|
||||
self,
|
||||
limit,
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
offset_peer=offset_peer,
|
||||
limit=0,
|
||||
hash=0
|
||||
ignore_migrated=ignore_migrated
|
||||
)
|
||||
while len(seen) < limit:
|
||||
req.limit = min(limit - len(seen), 100)
|
||||
r = await self(req)
|
||||
|
||||
if _total:
|
||||
_total[0] = getattr(r, 'count', len(r.dialogs))
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
messages = {}
|
||||
for m in r.messages:
|
||||
m._finish_init(self, entities, None)
|
||||
messages[m.id] = m
|
||||
|
||||
# Happens when there are pinned dialogs
|
||||
if len(r.dialogs) > limit:
|
||||
r.dialogs = r.dialogs[:limit]
|
||||
|
||||
for d in r.dialogs:
|
||||
if offset_date:
|
||||
date = getattr(messages.get(
|
||||
d.top_message, None), 'date', None)
|
||||
|
||||
if not date or date.timestamp() > offset_date.timestamp():
|
||||
continue
|
||||
|
||||
peer_id = utils.get_peer_id(d.peer)
|
||||
if peer_id not in seen:
|
||||
seen.add(peer_id)
|
||||
cd = custom.Dialog(self, d, entities, messages)
|
||||
if cd.dialog.pts:
|
||||
self._channel_pts[cd.id] = cd.dialog.pts
|
||||
|
||||
if not ignore_migrated or getattr(
|
||||
cd.entity, 'migrated_to', None) is None:
|
||||
await yield_(cd)
|
||||
|
||||
if len(r.dialogs) < req.limit\
|
||||
or not isinstance(r, types.messages.DialogsSlice):
|
||||
# Less than we requested means we reached the end, or
|
||||
# we didn't get a DialogsSlice which means we got all.
|
||||
break
|
||||
|
||||
req.offset_date = r.messages[-1].date
|
||||
req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)]
|
||||
if req.offset_id == r.messages[-1].id:
|
||||
# In some very rare cases this will get stuck in an infinite
|
||||
# loop, where the offsets will get reused over and over. If
|
||||
# the new offset is the same as the one before, break already.
|
||||
break
|
||||
|
||||
req.offset_id = r.messages[-1].id
|
||||
req.exclude_pinned = True
|
||||
|
||||
async def get_dialogs(self, *args, **kwargs):
|
||||
"""
|
||||
Same as `iter_dialogs`, but returns a
|
||||
`TotalList <telethon.helpers.TotalList>` instead.
|
||||
"""
|
||||
total = [0]
|
||||
kwargs['_total'] = total
|
||||
dialogs = helpers.TotalList()
|
||||
async for x in self.iter_dialogs(*args, **kwargs):
|
||||
dialogs.append(x)
|
||||
dialogs.total = total[0]
|
||||
return dialogs
|
||||
return await self.iter_dialogs(*args, **kwargs).collect()
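A short sketch of the dialog iterator; `ignore_migrated=True` relies on the `migrated_to` check done in `_DialogsIter._load_next_chunk`, and `.total` is the count reported by Telegram:

async def list_groups(client):
    dialogs = await client.get_dialogs(limit=None, ignore_migrated=True)
    print('fetched', len(dialogs), 'of', dialogs.total)
    for dialog in dialogs:
        if dialog.is_group:
            print(dialog.id, dialog.name)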
|
||||
|
||||
@async_generator
|
||||
async def iter_drafts(self):
|
||||
def iter_drafts(self):
|
||||
"""
|
||||
Iterator over all open draft messages.
|
||||
|
||||
|
@@ -151,18 +158,14 @@ class DialogMethods(UserMethods):
|
|||
to change the message or `telethon.tl.custom.draft.Draft.delete`
|
||||
among other things.
|
||||
"""
|
||||
r = await self(functions.messages.GetAllDraftsRequest())
|
||||
for update in r.updates:
|
||||
await yield_(custom.Draft._from_update(self, update))
|
||||
# TODO Passing a limit here makes no sense
|
||||
return _DraftsIter(self, None)
|
||||
|
||||
async def get_drafts(self):
|
||||
"""
|
||||
Same as :meth:`iter_drafts`, but returns a list instead.
|
||||
"""
|
||||
result = []
|
||||
async for x in self.iter_drafts():
|
||||
result.append(x)
|
||||
return result
|
||||
return await self.iter_drafts().collect()
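Since `_DraftsIter` does all of its work in `_init`, iterating drafts costs a single `GetAllDraftsRequest`. A hypothetical cleanup helper; `Draft.delete` is referenced in the docstring above, while `Draft.is_empty` is assumed here:

async def clear_drafts(client):
    for draft in await client.get_drafts():
        if not draft.is_empty:      # assumed Draft property
            await draft.delete()    # Draft.delete is mentioned in the docstring above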
|
||||
|
||||
def conversation(
|
||||
self, entity,
|
||||
|
|
|
telethon/client/messages.py
@@ -1,14 +1,282 @@
 import asyncio
 import itertools
 import time
 
-from async_generator import async_generator, yield_
-
 from .messageparse import MessageParseMethods
 from .uploads import UploadMethods
 from .buttons import ButtonMethods
 from .. import helpers, utils, errors
 from ..tl import types, functions
+from ..requestiter import RequestIter

# TODO Maybe RequestIter could rather have the update offset here?
|
||||
# Maybe init should return the request to be used and it be
|
||||
# called automatically? And another method to just process it.
|
||||
class _MessagesIter(RequestIter):
|
||||
"""
|
||||
Common factor for all requests that need to iterate over messages.
|
||||
"""
|
||||
async def _init(
|
||||
self, entity, offset_id, min_id, max_id, from_user,
|
||||
batch_size, offset_date, add_offset, filter, search
|
||||
):
|
||||
# Note that entity being ``None`` will perform a global search.
|
||||
if entity:
|
||||
self.entity = await self.client.get_input_entity(entity)
|
||||
else:
|
||||
self.entity = None
|
||||
if self.reverse:
|
||||
raise ValueError('Cannot reverse global search')
|
||||
|
||||
# Telegram doesn't like min_id/max_id. If these IDs are low enough
|
||||
# (starting from last_id - 100), the request will return nothing.
|
||||
#
|
||||
# We can emulate their behaviour locally by setting offset = max_id
|
||||
# and simply stopping once we hit a message with ID <= min_id.
|
||||
if self.reverse:
|
||||
offset_id = max(offset_id, min_id)
|
||||
if offset_id and max_id:
|
||||
if max_id - offset_id <= 1:
|
||||
raise StopAsyncIteration
|
||||
|
||||
if not max_id:
|
||||
max_id = float('inf')
|
||||
else:
|
||||
offset_id = max(offset_id, max_id)
|
||||
if offset_id and min_id:
|
||||
if offset_id - min_id <= 1:
|
||||
raise StopAsyncIteration
|
||||
|
||||
if self.reverse:
|
||||
if offset_id:
|
||||
offset_id += 1
|
||||
else:
|
||||
offset_id = 1
|
||||
|
||||
if from_user:
|
||||
from_user = await self.client.get_input_entity(from_user)
|
||||
if not isinstance(from_user, (
|
||||
types.InputPeerUser, types.InputPeerSelf)):
|
||||
from_user = None # Ignore from_user unless it's a user
|
||||
|
||||
if from_user:
|
||||
self.from_id = await self.client.get_peer_id(from_user)
|
||||
else:
|
||||
self.from_id = None
|
||||
|
||||
if not self.entity:
|
||||
self.request = functions.messages.SearchGlobalRequest(
|
||||
q=search or '',
|
||||
offset_date=offset_date,
|
||||
offset_peer=types.InputPeerEmpty(),
|
||||
offset_id=offset_id,
|
||||
limit=1
|
||||
)
|
||||
elif search is not None or filter or from_user:
|
||||
if filter is None:
|
||||
filter = types.InputMessagesFilterEmpty()
|
||||
|
||||
# Telegram completely ignores `from_id` in private chats
|
||||
if isinstance(
|
||||
self.entity, (types.InputPeerUser, types.InputPeerSelf)):
|
||||
# Don't bother sending `from_user` (it's ignored anyway),
|
||||
# but keep `from_id` defined above to check it locally.
|
||||
from_user = None
|
||||
else:
|
||||
# Do send `from_user` to do the filtering server-side,
|
||||
# and set `from_id` to None to avoid checking it locally.
|
||||
self.from_id = None
|
||||
|
||||
self.request = functions.messages.SearchRequest(
|
||||
peer=self.entity,
|
||||
q=search or '',
|
||||
filter=filter() if isinstance(filter, type) else filter,
|
||||
min_date=None,
|
||||
max_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
add_offset=add_offset,
|
||||
limit=0, # Search actually returns 0 items if we ask it to
|
||||
max_id=0,
|
||||
min_id=0,
|
||||
hash=0,
|
||||
from_id=from_user
|
||||
)
|
||||
else:
|
||||
self.request = functions.messages.GetHistoryRequest(
|
||||
peer=self.entity,
|
||||
limit=1,
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
min_id=0,
|
||||
max_id=0,
|
||||
add_offset=add_offset,
|
||||
hash=0
|
||||
)
|
||||
|
||||
if self.limit == 0:
|
||||
# No messages, but we still need to know the total message count
|
||||
result = await self.client(self.request)
|
||||
if isinstance(result, types.messages.MessagesNotModified):
|
||||
self.total = result.count
|
||||
else:
|
||||
self.total = getattr(result, 'count', len(result.messages))
|
||||
raise StopAsyncIteration
|
||||
|
||||
if self.wait_time is None:
|
||||
self.wait_time = 1 if self.limit > 3000 else 0
|
||||
|
||||
# Telegram has a hard limit of 100.
|
||||
# We don't need to fetch 100 if the limit is less.
|
||||
self.batch_size = min(max(batch_size, 1), min(100, self.limit))
|
||||
|
||||
# When going in reverse we need an offset of `-limit`, but we
|
||||
# also want to respect what the user passed, so add them together.
|
||||
if self.reverse:
|
||||
self.request.add_offset -= self.batch_size
|
||||
|
||||
self.add_offset = add_offset
|
||||
self.max_id = max_id
|
||||
self.min_id = min_id
|
||||
self.last_id = 0 if self.reverse else float('inf')
|
||||
|
||||
async def _load_next_chunk(self):
|
||||
self.request.limit = min(self.left, self.batch_size)
|
||||
if self.reverse and self.request.limit != self.batch_size:
|
||||
# Remember that we need -limit when going in reverse
|
||||
self.request.add_offset = self.add_offset - self.request.limit
|
||||
|
||||
r = await self.client(self.request)
|
||||
self.total = getattr(r, 'count', len(r.messages))
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
messages = reversed(r.messages) if self.reverse else r.messages
|
||||
for message in messages:
|
||||
if (isinstance(message, types.MessageEmpty)
|
||||
or self.from_id and message.from_id != self.from_id):
|
||||
continue
|
||||
|
||||
if not self._message_in_range(message):
|
||||
return True
|
||||
|
||||
# There have been reports that on bad connections this method
|
||||
# was returning duplicated IDs sometimes. Using ``last_id``
|
||||
# is an attempt to avoid these duplicates, since the message
|
||||
# IDs are returned in descending order (or asc if reverse).
|
||||
self.last_id = message.id
|
||||
message._finish_init(self.client, entities, self.entity)
|
||||
self.buffer.append(message)
|
||||
|
||||
if len(r.messages) < self.request.limit:
|
||||
return True
|
||||
|
||||
# Get the last message that's not empty (in some rare cases
|
||||
# it can happen that the last message is :tl:`MessageEmpty`)
|
||||
if self.buffer:
|
||||
self._update_offset(self.buffer[-1])
|
||||
else:
|
||||
# There are some cases where all the messages we get start
|
||||
# being empty. This can happen on migrated mega-groups if
|
||||
# the history was cleared, and we're using search. Telegram
|
||||
# acts incredibly weird sometimes. Messages are returned but
|
||||
# only "empty", not their contents. If this is the case we
|
||||
# should just give up since there won't be any new Message.
|
||||
return True
|
||||
|
||||
def _message_in_range(self, message):
|
||||
"""
|
||||
Determine whether the given message is in the range or
|
||||
it should be ignored (and avoid loading more chunks).
|
||||
"""
|
||||
# No entity means message IDs between chats may vary
|
||||
if self.entity:
|
||||
if self.reverse:
|
||||
if message.id <= self.last_id or message.id >= self.max_id:
|
||||
return False
|
||||
else:
|
||||
if message.id >= self.last_id or message.id <= self.min_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _update_offset(self, last_message):
|
||||
"""
|
||||
After making the request, update its offset with the last message.
|
||||
"""
|
||||
self.request.offset_id = last_message.id
|
||||
if self.reverse:
|
||||
# We want to skip the one we already have
|
||||
self.request.offset_id += 1
|
||||
|
||||
if isinstance(self.request, functions.messages.SearchRequest):
|
||||
# Unlike getHistory and searchGlobal that use *offset* date,
|
||||
# this is *max* date. This means that doing a search in reverse
|
||||
# will break it. Since it's not really needed once we're going
|
||||
# (only for the first request), it's safe to just clear it off.
|
||||
self.request.max_date = None
|
||||
else:
|
||||
# getHistory and searchGlobal call it offset_date
|
||||
self.request.offset_date = last_message.date
|
||||
|
||||
if isinstance(self.request, functions.messages.SearchGlobalRequest):
|
||||
self.request.offset_peer = last_message.input_chat
|
||||
|
||||
|
||||
class _IDsIter(RequestIter):
|
||||
async def _init(self, entity, ids):
|
||||
# TODO We never actually split IDs in chunks, but maybe we should
|
||||
if not utils.is_list_like(ids):
|
||||
self.ids = [ids]
|
||||
elif not ids:
|
||||
raise StopAsyncIteration
|
||||
elif self.reverse:
|
||||
self.ids = list(reversed(ids))
|
||||
else:
|
||||
self.ids = ids
|
||||
|
||||
if entity:
|
||||
entity = await self.client.get_input_entity(entity)
|
||||
|
||||
self.total = len(ids)
|
||||
|
||||
from_id = None # By default, no need to validate from_id
|
||||
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
|
||||
try:
|
||||
r = await self.client(
|
||||
functions.channels.GetMessagesRequest(entity, ids))
|
||||
except errors.MessageIdsEmptyError:
|
||||
# All IDs were invalid, use a dummy result
|
||||
r = types.messages.MessagesNotModified(len(ids))
|
||||
else:
|
||||
r = await self.client(functions.messages.GetMessagesRequest(ids))
|
||||
if entity:
|
||||
from_id = await self.client.get_peer_id(entity)
|
||||
|
||||
if isinstance(r, types.messages.MessagesNotModified):
|
||||
self.buffer.extend(None for _ in ids)
|
||||
return
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
# Telegram seems to return the messages in the order in which
|
||||
# we asked them for, so we don't need to check it ourselves,
|
||||
# unless some messages were invalid in which case Telegram
|
||||
# may decide to not send them at all.
|
||||
#
|
||||
# The passed message IDs may not belong to the desired entity
|
||||
# since the user can enter arbitrary numbers which can belong to
|
||||
# arbitrary chats. Validate these unless ``from_id is None``.
|
||||
for message in r.messages:
|
||||
if isinstance(message, types.MessageEmpty) or (
|
||||
from_id and message.chat_id != from_id):
|
||||
self.buffer.append(None)
|
||||
else:
|
||||
message._finish_init(self.client, entities, entity)
|
||||
self.buffer.append(message)
|
||||
|
||||
async def _load_next_chunk(self):
|
||||
return True # no next chunk, all done in init
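Because `_IDsIter` appends ``None`` for invalid or foreign IDs, the result of `get_messages(..., ids=[...])` always lines up with the requested list. For example (illustrative IDs and helper name):

async def fetch_by_id(client, chat):
    wanted = [1, 2, 999999]
    messages = await client.get_messages(chat, ids=wanted)
    for msg_id, message in zip(wanted, messages):
        print(msg_id, '->', message.text if message else 'missing')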
|
||||
|
||||
|
||||
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||
|
@@ -17,8 +285,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
|||
|
||||
# region Message retrieval
|
||||
|
||||
@async_generator
|
||||
async def iter_messages(
|
||||
def iter_messages(
|
||||
self, entity, limit=None, *, offset_date=None, offset_id=0,
|
||||
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
|
||||
from_user=None, batch_size=100, wait_time=None, ids=None,
|
||||
|
@@ -133,208 +400,26 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
|||
a higher limit, so you're free to set the ``batch_size`` that
|
||||
you think may be good.
|
||||
"""
|
||||
# Note that entity being ``None`` is intended to get messages by
|
||||
# ID under no specific chat, and also to request a global search.
|
||||
if entity:
|
||||
entity = await self.get_input_entity(entity)
|
||||
|
||||
if ids:
|
||||
if not utils.is_list_like(ids):
|
||||
ids = (ids,)
|
||||
if reverse:
|
||||
ids = list(reversed(ids))
|
||||
async for x in self._iter_ids(entity, ids, total=_total):
|
||||
await yield_(x)
|
||||
return
|
||||
if ids is not None:
|
||||
return _IDsIter(self, limit, entity=entity, ids=ids)
|
||||
|
||||
# Telegram doesn't like min_id/max_id. If these IDs are low enough
|
||||
# (starting from last_id - 100), the request will return nothing.
|
||||
#
|
||||
# We can emulate their behaviour locally by setting offset = max_id
|
||||
# and simply stopping once we hit a message with ID <= min_id.
|
||||
if reverse:
|
||||
offset_id = max(offset_id, min_id)
|
||||
if offset_id and max_id:
|
||||
if max_id - offset_id <= 1:
|
||||
return
|
||||
|
||||
if not max_id:
|
||||
max_id = float('inf')
|
||||
else:
|
||||
offset_id = max(offset_id, max_id)
|
||||
if offset_id and min_id:
|
||||
if offset_id - min_id <= 1:
|
||||
return
|
||||
|
||||
if reverse:
|
||||
if offset_id:
|
||||
offset_id += 1
|
||||
else:
|
||||
offset_id = 1
|
||||
|
||||
if from_user:
|
||||
from_user = await self.get_input_entity(from_user)
|
||||
if not isinstance(from_user, (
|
||||
types.InputPeerUser, types.InputPeerSelf)):
|
||||
from_user = None # Ignore from_user unless it's a user
|
||||
|
||||
from_id = (await self.get_peer_id(from_user)) if from_user else None
|
||||
|
||||
limit = float('inf') if limit is None else int(limit)
|
||||
if not entity:
|
||||
if reverse:
|
||||
raise ValueError('Cannot reverse global search')
|
||||
|
||||
reverse = None
|
||||
request = functions.messages.SearchGlobalRequest(
|
||||
q=search or '',
|
||||
offset_date=offset_date,
|
||||
offset_peer=types.InputPeerEmpty(),
|
||||
offset_id=offset_id,
|
||||
limit=1
|
||||
)
|
||||
elif search is not None or filter or from_user:
|
||||
if filter is None:
|
||||
filter = types.InputMessagesFilterEmpty()
|
||||
|
||||
# Telegram completely ignores `from_id` in private chats
|
||||
if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)):
|
||||
# Don't bother sending `from_user` (it's ignored anyway),
|
||||
# but keep `from_id` defined above to check it locally.
|
||||
from_user = None
|
||||
else:
|
||||
# Do send `from_user` to do the filtering server-side,
|
||||
# and set `from_id` to None to avoid checking it locally.
|
||||
from_id = None
|
||||
|
||||
request = functions.messages.SearchRequest(
|
||||
peer=entity,
|
||||
q=search or '',
|
||||
filter=filter() if isinstance(filter, type) else filter,
|
||||
min_date=None,
|
||||
max_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
add_offset=add_offset,
|
||||
limit=0, # Search actually returns 0 items if we ask it to
|
||||
max_id=0,
|
||||
min_id=0,
|
||||
hash=0,
|
||||
from_id=from_user
|
||||
)
|
||||
else:
|
||||
request = functions.messages.GetHistoryRequest(
|
||||
peer=entity,
|
||||
limit=1,
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
min_id=0,
|
||||
max_id=0,
|
||||
add_offset=add_offset,
|
||||
hash=0
|
||||
)
|
||||
|
||||
if limit == 0:
|
||||
if not _total:
|
||||
return
|
||||
# No messages, but we still need to know the total message count
|
||||
result = await self(request)
|
||||
if isinstance(result, types.messages.MessagesNotModified):
|
||||
_total[0] = result.count
|
||||
else:
|
||||
_total[0] = getattr(result, 'count', len(result.messages))
|
||||
return
|
||||
|
||||
if wait_time is None:
|
||||
wait_time = 1 if limit > 3000 else 0
|
||||
|
||||
have = 0
|
||||
last_id = 0 if reverse else float('inf')
|
||||
|
||||
# Telegram has a hard limit of 100.
|
||||
# We don't need to fetch 100 if the limit is less.
|
||||
batch_size = min(max(batch_size, 1), min(100, limit))
|
||||
|
||||
# When going in reverse we need an offset of `-limit`, but we
|
||||
# also want to respect what the user passed, so add them together.
|
||||
if reverse:
|
||||
request.add_offset -= batch_size
|
||||
|
||||
while have < limit:
|
||||
start = time.time()
|
||||
|
||||
request.limit = min(limit - have, batch_size)
|
||||
if reverse and request.limit != batch_size:
|
||||
# Remember that we need -limit when going in reverse
|
||||
request.add_offset = add_offset - request.limit
|
||||
|
||||
r = await self(request)
|
||||
if _total:
|
||||
_total[0] = getattr(r, 'count', len(r.messages))
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
messages = reversed(r.messages) if reverse else r.messages
|
||||
for message in messages:
|
||||
if (isinstance(message, types.MessageEmpty)
|
||||
or from_id and message.from_id != from_id):
|
||||
continue
|
||||
|
||||
if reverse is None:
|
||||
pass
|
||||
elif reverse:
|
||||
if message.id <= last_id or message.id >= max_id:
|
||||
return
|
||||
else:
|
||||
if message.id >= last_id or message.id <= min_id:
|
||||
return
|
||||
|
||||
# There has been reports that on bad connections this method
|
||||
# was returning duplicated IDs sometimes. Using ``last_id``
|
||||
# is an attempt to avoid these duplicates, since the message
|
||||
# IDs are returned in descending order (or asc if reverse).
|
||||
last_id = message.id
|
||||
|
||||
message._finish_init(self, entities, entity)
|
||||
await yield_(message)
|
||||
have += 1
|
||||
|
||||
if len(r.messages) < request.limit:
|
||||
break
|
||||
|
||||
# Find the first message that's not empty (in some rare cases
|
||||
# it can happen that the last message is :tl:`MessageEmpty`)
|
||||
last_message = None
|
||||
messages = r.messages if reverse else reversed(r.messages)
|
||||
for m in messages:
|
||||
if not isinstance(m, types.MessageEmpty):
|
||||
last_message = m
|
||||
break
|
||||
|
||||
if last_message is None:
|
||||
# There are some cases where all the messages we get start
|
||||
# being empty. This can happen on migrated mega-groups if
|
||||
# the history was cleared, and we're using search. Telegram
|
||||
# acts incredibly weird sometimes. Messages are returned but
|
||||
# only "empty", not their contents. If this is the case we
|
||||
# should just give up since there won't be any new Message.
|
||||
break
|
||||
else:
|
||||
request.offset_id = last_message.id
|
||||
if isinstance(request, functions.messages.SearchRequest):
|
||||
request.max_date = last_message.date
|
||||
else:
|
||||
# getHistory and searchGlobal call it offset_date
|
||||
request.offset_date = last_message.date
|
||||
|
||||
if isinstance(request, functions.messages.SearchGlobalRequest):
|
||||
request.offset_peer = last_message.input_chat
|
||||
elif reverse:
|
||||
# We want to skip the one we already have
|
||||
request.offset_id += 1
|
||||
|
||||
await asyncio.sleep(
|
||||
max(wait_time - (time.time() - start), 0), loop=self._loop)
|
||||
return _MessagesIter(
|
||||
client=self,
|
||||
reverse=reverse,
|
||||
wait_time=wait_time,
|
||||
limit=limit,
|
||||
entity=entity,
|
||||
offset_id=offset_id,
|
||||
min_id=min_id,
|
||||
max_id=max_id,
|
||||
from_user=from_user,
|
||||
batch_size=batch_size,
|
||||
offset_date=offset_date,
|
||||
add_offset=add_offset,
|
||||
filter=filter,
|
||||
search=search
|
||||
)
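All of the offset handling above is hidden behind the returned iterator; for example, a reverse history dump with throttling (illustrative helper):

async def dump_oldest_first(client, chat):
    # reverse=True walks from oldest to newest; wait_time sleeps between
    # consecutive history requests, which RequestIter now handles for us.
    async for message in client.iter_messages(chat, reverse=True, wait_time=1):
        print(message.id, message.date, message.text)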
|
||||
|
||||
async def get_messages(self, *args, **kwargs):
|
||||
"""
|
||||
|
@@ -353,23 +438,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
|||
a single `Message <telethon.tl.custom.message.Message>` will be
|
||||
returned for convenience instead of a list.
|
||||
"""
|
||||
total = [0]
|
||||
kwargs['_total'] = total
|
||||
if len(args) == 1 and 'limit' not in kwargs:
|
||||
if 'min_id' in kwargs and 'max_id' in kwargs:
|
||||
kwargs['limit'] = None
|
||||
else:
|
||||
kwargs['limit'] = 1
|
||||
|
||||
msgs = helpers.TotalList()
|
||||
async for x in self.iter_messages(*args, **kwargs):
|
||||
msgs.append(x)
|
||||
msgs.total = total[0]
|
||||
if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']):
|
||||
# Check for empty list to handle InputMessageReplyTo
|
||||
return msgs[0] if msgs else None
|
||||
it = self.iter_messages(*args, **kwargs)
|
||||
|
||||
return msgs
|
||||
ids = kwargs.get('ids')
|
||||
if ids and not utils.is_list_like(ids):
|
||||
async for message in it:
|
||||
return message
|
||||
else:
|
||||
# Iterator exhausted = empty, to handle InputMessageReplyTo
|
||||
return None
|
||||
|
||||
return await it.collect()
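The convenience rules implemented above, spelled out (inside an ``async def``, with 'me' as an example target):

async def examples(client):
    last = await client.get_messages('me')              # bare call: limit defaults to 1
    some = await client.get_messages('me', limit=20)    # explicit limit
    span = await client.get_messages('me', min_id=1, max_id=50)  # min/max given: no limit
    single = await client.get_messages('me', ids=4)     # one Message (or None), not a list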
|
||||
|
||||
# endregion
|
||||
|
||||
|
@@ -799,52 +884,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
|||
# endregion
|
||||
|
||||
# endregion
|
||||
|
||||
# region Private methods
|
||||
|
||||
@async_generator
|
||||
async def _iter_ids(self, entity, ids, total):
|
||||
"""
|
||||
Special case for `iter_messages` when it should only fetch some IDs.
|
||||
"""
|
||||
if total:
|
||||
total[0] = len(ids)
|
||||
|
||||
from_id = None # By default, no need to validate from_id
|
||||
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
|
||||
try:
|
||||
r = await self(
|
||||
functions.channels.GetMessagesRequest(entity, ids))
|
||||
except errors.MessageIdsEmptyError:
|
||||
# All IDs were invalid, use a dummy result
|
||||
r = types.messages.MessagesNotModified(len(ids))
|
||||
else:
|
||||
r = await self(functions.messages.GetMessagesRequest(ids))
|
||||
if entity:
|
||||
from_id = utils.get_peer_id(entity)
|
||||
|
||||
if isinstance(r, types.messages.MessagesNotModified):
|
||||
for _ in ids:
|
||||
await yield_(None)
|
||||
return
|
||||
|
||||
entities = {utils.get_peer_id(x): x
|
||||
for x in itertools.chain(r.users, r.chats)}
|
||||
|
||||
# Telegram seems to return the messages in the order in which
|
||||
# we asked them for, so we don't need to check it ourselves,
|
||||
# unless some messages were invalid in which case Telegram
|
||||
# may decide to not send them at all.
|
||||
#
|
||||
# The passed message IDs may not belong to the desired entity
|
||||
# since the user can enter arbitrary numbers which can belong to
|
||||
# arbitrary chats. Validate these unless ``from_id is None``.
|
||||
for message in r.messages:
|
||||
if isinstance(message, types.MessageEmpty) or (
|
||||
from_id and message.chat_id != from_id):
|
||||
await yield_(None)
|
||||
else:
|
||||
message._finish_init(self, entities, entity)
|
||||
await yield_(message)
|
||||
|
||||
# endregion
|
||||
|
|
telethon/requestiter.py (new file, 130 lines)
@@ -0,0 +1,130 @@
import abc
import asyncio
import time

from . import helpers


# TODO There are two types of iterators for requests.
#      One has a limit of items to retrieve, and the
#      other has a list that must be called in chunks.
#      Make classes for both here so it's easy to use.
class RequestIter(abc.ABC):
    """
    Helper class to deal with requests that need offsets to iterate.

    It has some facilities, such as automatically sleeping a desired
    amount of time between requests if needed (but not more).

    Can be used synchronously if the event loop is not running and
    as an asynchronous iterator otherwise.

    `limit` is the total amount of items that the iterator should return.
    This is handled by this base class and will always be ``>= 0``.

    `left` will be reset every time the iterator is used and will indicate
    the amount of items that are left to be emitted, so that subclasses can
    be more efficient and fetch only as many items as they need.

    Iterators may be used with ``reversed``, and their `reverse` flag will
    be set to ``True`` if that's the case. Note that if this flag is set,
    `buffer` should be filled in reverse too.
    """
    def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
        self.client = client
        self.reverse = reverse
        self.wait_time = wait_time
        self.kwargs = kwargs
        self.limit = max(float('inf') if limit is None else limit, 0)
        self.left = None
        self.buffer = None
        self.index = None
        self.total = None
        self.last_load = None

    async def _init(self, **kwargs):
        """
        Called when asynchronous initialization is necessary. All keyword
        arguments passed to `__init__` will be forwarded here, and it's
        preferable to use named arguments in the subclasses without defaults
        to avoid forgetting or misspelling any of them.

        This method may ``raise StopAsyncIteration`` if it cannot continue.

        This method may actually fill the initial buffer if it needs to.
        """

    async def __anext__(self):
        if self.buffer is None:
            self.buffer = []
            await self._init(**self.kwargs)

        if self.left <= 0:  # <= 0 because subclasses may change it
            raise StopAsyncIteration

        if self.index == len(self.buffer):
            # asyncio will handle times <= 0 to sleep 0 seconds
            if self.wait_time:
                await asyncio.sleep(
                    self.wait_time - (time.time() - self.last_load),
                    loop=self.client.loop
                )
                self.last_load = time.time()

            self.index = 0
            self.buffer = []
            if await self._load_next_chunk():
                self.left = len(self.buffer)

        if not self.buffer:
            raise StopAsyncIteration

        result = self.buffer[self.index]
        self.left -= 1
        self.index += 1
        return result

    def __aiter__(self):
        self.buffer = None
        self.index = 0
        self.last_load = 0
        self.left = self.limit
        return self

    def __iter__(self):
        if self.client.loop.is_running():
            raise RuntimeError(
                'You must use "async for" if the event loop '
                'is running (i.e. you are inside an "async def")'
            )

        return self.__aiter__()

    async def collect(self):
        """
        Create a `self` iterator and collect it into a `TotalList`
        (a normal list with a `.total` attribute).
        """
        result = helpers.TotalList()
        async for message in self:
            result.append(message)

        result.total = self.total
        return result

    @abc.abstractmethod
    async def _load_next_chunk(self):
        """
        Called when the next chunk is necessary.

        It should extend the `buffer` with new items.

        It should return ``True`` if it's the last chunk, after which
        the method won't be called again during the same iteration.
        """
        raise NotImplementedError

    def __reversed__(self):
        self.reverse = not self.reverse
        return self  # __aiter__ will be called after, too
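The two abstract hooks above are all a subclass needs. As an illustration (not part of the commit), a hypothetical single-request iterator in the same spirit as `_DraftsIter`, assuming `GetContactsRequest` is available in the generated `functions.contacts` module:

from telethon.requestiter import RequestIter
from telethon.tl import functions

class _ContactsIter(RequestIter):
    async def _init(self, **kwargs):
        # Everything fits in one request, so fill the buffer here.
        result = await self.client(functions.contacts.GetContactsRequest(hash=0))
        self.total = len(result.users)
        self.buffer.extend(result.users)

    async def _load_next_chunk(self):
        return True  # nothing else to load; True marks the last chunk

# Usage (inside an `async def`):
#     async for user in _ContactsIter(client, limit=None):
#         print(user.id, user.first_name)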
telethon/sync.py
@@ -14,8 +14,6 @@ import asyncio
 import functools
 import inspect
 
-from async_generator import isasyncgenfunction
-
 from .client.telegramclient import TelegramClient
 from .tl.custom import (
     Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
@@ -24,22 +22,7 @@ from .tl.custom.chatgetter import ChatGetter
 from .tl.custom.sendergetter import SenderGetter
 
 
-class _SyncGen:
-    def __init__(self, gen):
-        self.gen = gen
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            return asyncio.get_event_loop() \
-                .run_until_complete(self.gen.__anext__())
-        except StopAsyncIteration:
-            raise StopIteration from None
-
-
-def _syncify_wrap(t, method_name, gen):
+def _syncify_wrap(t, method_name):
     method = getattr(t, method_name)
 
     @functools.wraps(method)
@@ -48,8 +31,6 @@ def _syncify_wrap(t, method_name):
         loop = asyncio.get_event_loop()
         if loop.is_running():
             return coro
-        elif gen:
-            return _SyncGen(coro)
         else:
             return loop.run_until_complete(coro)
 
@@ -64,13 +45,14 @@ def syncify(*types):
     into synchronous, which return either the coroutine or the result
     based on whether ``asyncio's`` event loop is running.
     """
+    # Our asynchronous generators are all `RequestIter` instances, which
+    # already provide a synchronous iterator variant, so we don't need to
+    # worry about asyncgenfunctions here.
     for t in types:
         for name in dir(t):
             if not name.startswith('_') or name == '__call__':
                 if inspect.iscoroutinefunction(getattr(t, name)):
-                    _syncify_wrap(t, name, gen=False)
-                elif isasyncgenfunction(getattr(t, name)):
-                    _syncify_wrap(t, name, gen=True)
+                    _syncify_wrap(t, name)
 
 
 syncify(TelegramClient, Draft, Dialog, MessageButton,
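Since `RequestIter` implements both `__aiter__` and `__iter__`, the synchronous convenience import keeps working without any generator-specific wrapping. A small sketch with a placeholder session name and credentials:

from telethon.sync import TelegramClient  # importing from .sync applies syncify()

with TelegramClient('session', 12345, '0123456789abcdef') as client:
    # No event loop is running here, so a plain `for` drives the RequestIter.
    for message in client.iter_messages('me', limit=3):
        print(message.sender_id, message.text)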