Merge pull request #1115 from LonamiWebs/requestiter

Overhaul asynchronous generators
This commit is contained in:
Lonami 2019-02-27 12:49:12 +01:00 committed by GitHub
commit c99157ade2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 746 additions and 571 deletions

View File

@ -1,3 +1,2 @@
pyaes
rsa
async_generator

View File

@ -214,8 +214,7 @@ def main():
packages=find_packages(exclude=[
'telethon_*', 'run_tests.py', 'try_telethon.py'
]),
install_requires=['pyaes', 'rsa',
'async_generator'],
install_requires=['pyaes', 'rsa'],
extras_require={
'cryptg': ['cryptg']
}

View File

@ -1,19 +1,192 @@
import itertools
import sys
from async_generator import async_generator, yield_
from .users import UserMethods
from .. import utils, helpers
from .. import utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom
class _ParticipantsIter(RequestIter):
async def _init(self, entity, filter, search, aggressive):
if isinstance(filter, type):
if filter in (types.ChannelParticipantsBanned,
types.ChannelParticipantsKicked,
types.ChannelParticipantsSearch,
types.ChannelParticipantsContacts):
# These require a `q` parameter (support types for convenience)
filter = filter('')
else:
filter = filter()
entity = await self.client.get_input_entity(entity)
if search and (filter
or not isinstance(entity, types.InputPeerChannel)):
# We need to 'search' ourselves unless we have a PeerChannel
search = search.lower()
self.filter_entity = lambda ent: (
search in utils.get_display_name(ent).lower() or
search in (getattr(ent, 'username', '') or None).lower()
)
else:
self.filter_entity = lambda ent: True
if isinstance(entity, types.InputPeerChannel):
self.total = (await self.client(
functions.channels.GetFullChannelRequest(entity)
)).full_chat.participants_count
if self.limit == 0:
raise StopAsyncIteration
self.seen = set()
if aggressive and not filter:
self.requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=types.ChannelParticipantsSearch(x),
offset=0,
limit=200,
hash=0
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
else:
self.requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=filter or types.ChannelParticipantsSearch(search),
offset=0,
limit=200,
hash=0
)]
elif isinstance(entity, types.InputPeerChat):
full = await self.client(
functions.messages.GetFullChatRequest(entity.chat_id))
if not isinstance(
full.full_chat.participants, types.ChatParticipants):
# ChatParticipantsForbidden won't have ``.participants``
self.total = 0
raise StopAsyncIteration
self.total = len(full.full_chat.participants.participants)
users = {user.id: user for user in full.users}
for participant in full.full_chat.participants.participants:
user = users[participant.user_id]
if not self.filter_entity(user):
continue
user = users[participant.user_id]
user.participant = participant
self.buffer.append(user)
return True
else:
self.total = 1
if self.limit != 0:
user = await self.client.get_entity(entity)
if self.filter_entity(user):
user.participant = None
self.buffer.append(user)
return True
async def _load_next_chunk(self):
if not self.requests:
return True
# Only care about the limit for the first request
# (small amount of people, won't be aggressive).
#
# Most people won't care about getting exactly 12,345
# members so it doesn't really matter not to be 100%
# precise with being out of the offset/limit here.
self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
if self.requests[0].offset > self.limit:
return True
results = await self.client(self.requests)
for i in reversed(range(len(self.requests))):
participants = results[i]
if not participants.users:
self.requests.pop(i)
continue
self.requests[i].offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
user = users[participant.user_id]
if not self.filter_entity(user) or user.id in self.seen:
continue
self.seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
self.buffer.append(user)
class _AdminLogIter(RequestIter):
async def _init(
self, entity, admins, search, min_id, max_id,
join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete
):
if any((join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete)):
events_filter = types.ChannelAdminLogEventsFilter(
join=join, leave=leave, invite=invite, ban=restrict,
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
demote=demote, info=info, settings=settings, pinned=pinned,
edit=edit, delete=delete
)
else:
events_filter = None
self.entity = await self.client.get_input_entity(entity)
admin_list = []
if admins:
if not utils.is_list_like(admins):
admins = (admins,)
for admin in admins:
admin_list.append(await self.client.get_input_entity(admin))
self.request = functions.channels.GetAdminLogRequest(
self.entity, q=search or '', min_id=min_id, max_id=max_id,
limit=0, events_filter=events_filter, admins=admin_list or None
)
async def _load_next_chunk(self):
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
self.request.max_id = min((e.id for e in r.events), default=0)
for ev in r.events:
if isinstance(ev.action,
types.ChannelAdminLogEventActionEditMessage):
ev.action.prev_message._finish_init(
self.client, entities, self.entity)
ev.action.new_message._finish_init(
self.client, entities, self.entity)
elif isinstance(ev.action,
types.ChannelAdminLogEventActionDeleteMessage):
ev.action.message._finish_init(
self.client, entities, self.entity)
self.buffer.append(custom.AdminLogEvent(ev, entities))
if len(r.events) < self.request.limit:
return True
class ChatMethods(UserMethods):
# region Public methods
@async_generator
async def iter_participants(
def iter_participants(
self, entity, limit=None, *, search='',
filter=None, aggressive=False, _total=None):
"""
@ -62,138 +235,23 @@ class ChatMethods(UserMethods):
matched :tl:`ChannelParticipant` type for channels/megagroups
or :tl:`ChatParticipants` for normal chats.
"""
if isinstance(filter, type):
if filter in (types.ChannelParticipantsBanned,
types.ChannelParticipantsKicked,
types.ChannelParticipantsSearch,
types.ChannelParticipantsContacts):
# These require a `q` parameter (support types for convenience)
filter = filter('')
else:
filter = filter()
entity = await self.get_input_entity(entity)
if search and (filter
or not isinstance(entity, types.InputPeerChannel)):
# We need to 'search' ourselves unless we have a PeerChannel
search = search.lower()
def filter_entity(ent):
return search in utils.get_display_name(ent).lower() or\
search in (getattr(ent, 'username', '') or None).lower()
else:
def filter_entity(ent):
return True
limit = float('inf') if limit is None else int(limit)
if isinstance(entity, types.InputPeerChannel):
if _total:
_total[0] = (await self(
functions.channels.GetFullChannelRequest(entity)
)).full_chat.participants_count
if limit == 0:
return
seen = set()
if aggressive and not filter:
requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=types.ChannelParticipantsSearch(x),
offset=0,
limit=200,
hash=0
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
else:
requests = [functions.channels.GetParticipantsRequest(
channel=entity,
filter=filter or types.ChannelParticipantsSearch(search),
offset=0,
limit=200,
hash=0
)]
while requests:
# Only care about the limit for the first request
# (small amount of people, won't be aggressive).
#
# Most people won't care about getting exactly 12,345
# members so it doesn't really matter not to be 100%
# precise with being out of the offset/limit here.
requests[0].limit = min(limit - requests[0].offset, 200)
if requests[0].offset > limit:
break
results = await self(requests)
for i in reversed(range(len(requests))):
participants = results[i]
if not participants.users:
requests.pop(i)
else:
requests[i].offset += len(participants.participants)
users = {user.id: user for user in participants.users}
for participant in participants.participants:
user = users[participant.user_id]
if not filter_entity(user) or user.id in seen:
continue
seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
await yield_(user)
if len(seen) >= limit:
return
elif isinstance(entity, types.InputPeerChat):
full = await self(
functions.messages.GetFullChatRequest(entity.chat_id))
if not isinstance(
full.full_chat.participants, types.ChatParticipants):
# ChatParticipantsForbidden won't have ``.participants``
if _total:
_total[0] = 0
return
if _total:
_total[0] = len(full.full_chat.participants.participants)
have = 0
users = {user.id: user for user in full.users}
for participant in full.full_chat.participants.participants:
user = users[participant.user_id]
if not filter_entity(user):
continue
have += 1
if have > limit:
break
else:
user = users[participant.user_id]
user.participant = participant
await yield_(user)
else:
if _total:
_total[0] = 1
if limit != 0:
user = await self.get_entity(entity)
if filter_entity(user):
user.participant = None
await yield_(user)
return _ParticipantsIter(
self,
limit,
entity=entity,
filter=filter,
search=search,
aggressive=aggressive
)
async def get_participants(self, *args, **kwargs):
"""
Same as `iter_participants`, but returns a
`TotalList <telethon.helpers.TotalList>` instead.
"""
total = [0]
kwargs['_total'] = total
participants = helpers.TotalList()
async for x in self.iter_participants(*args, **kwargs):
participants.append(x)
participants.total = total[0]
return participants
return await self.iter_participants(*args, **kwargs).collect()
@async_generator
async def iter_admin_log(
def iter_admin_log(
self, entity, limit=None, *, max_id=0, min_id=0, search=None,
admins=None, join=None, leave=None, invite=None, restrict=None,
unrestrict=None, ban=None, unban=None, promote=None, demote=None,
@ -285,66 +343,34 @@ class ChatMethods(UserMethods):
Yields:
Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
"""
if limit is None:
limit = sys.maxsize
elif limit <= 0:
return
if any((join, leave, invite, restrict, unrestrict, ban, unban,
promote, demote, info, settings, pinned, edit, delete)):
events_filter = types.ChannelAdminLogEventsFilter(
join=join, leave=leave, invite=invite, ban=restrict,
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
demote=demote, info=info, settings=settings, pinned=pinned,
edit=edit, delete=delete
)
else:
events_filter = None
entity = await self.get_input_entity(entity)
admin_list = []
if admins:
if not utils.is_list_like(admins):
admins = (admins,)
for admin in admins:
admin_list.append(await self.get_input_entity(admin))
request = functions.channels.GetAdminLogRequest(
entity, q=search or '', min_id=min_id, max_id=max_id,
limit=0, events_filter=events_filter, admins=admin_list or None
return _AdminLogIter(
self,
limit,
entity=entity,
admins=admins,
search=search,
min_id=min_id,
max_id=max_id,
join=join,
leave=leave,
invite=invite,
restrict=restrict,
unrestrict=unrestrict,
ban=ban,
unban=unban,
promote=promote,
demote=demote,
info=info,
settings=settings,
pinned=pinned,
edit=edit,
delete=delete
)
while limit > 0:
request.limit = min(limit, 100)
result = await self(request)
limit -= len(result.events)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(result.users, result.chats)}
request.max_id = min((e.id for e in result.events), default=0)
for ev in result.events:
if isinstance(ev.action,
types.ChannelAdminLogEventActionEditMessage):
ev.action.prev_message._finish_init(self, entities, entity)
ev.action.new_message._finish_init(self, entities, entity)
elif isinstance(ev.action,
types.ChannelAdminLogEventActionDeleteMessage):
ev.action.message._finish_init(self, entities, entity)
await yield_(custom.AdminLogEvent(ev, entities))
if len(result.events) < request.limit:
break
async def get_admin_log(self, *args, **kwargs):
"""
Same as `iter_admin_log`, but returns a ``list`` instead.
"""
admin_log = []
async for x in self.iter_admin_log(*args, **kwargs):
admin_log.append(x)
return admin_log
return await self.iter_admin_log(*args, **kwargs).collect()
# endregion

View File

@ -1,18 +1,101 @@
import itertools
from async_generator import async_generator, yield_
from .users import UserMethods
from .. import utils, helpers
from .. import utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom
class _DialogsIter(RequestIter):
async def _init(
self, offset_date, offset_id, offset_peer, ignore_migrated
):
self.request = functions.messages.GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=1,
hash=0
)
if self.limit == 0:
# Special case, get a single dialog and determine count
dialogs = await self.client(self.request)
self.total = getattr(dialogs, 'count', len(dialogs.dialogs))
raise StopAsyncIteration
self.seen = set()
self.offset_date = offset_date
self.ignore_migrated = ignore_migrated
async def _load_next_chunk(self):
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
self.total = getattr(r, 'count', len(r.dialogs))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = {}
for m in r.messages:
m._finish_init(self, entities, None)
messages[m.id] = m
for d in r.dialogs:
# We check the offset date here because Telegram may ignore it
if self.offset_date:
date = getattr(messages.get(
d.top_message, None), 'date', None)
if not date or date.timestamp() > self.offset_date.timestamp():
continue
peer_id = utils.get_peer_id(d.peer)
if peer_id not in self.seen:
self.seen.add(peer_id)
cd = custom.Dialog(self, d, entities, messages)
if cd.dialog.pts:
self.client._channel_pts[cd.id] = cd.dialog.pts
if not self.ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None:
self.buffer.append(cd)
if len(r.dialogs) < self.request.limit\
or not isinstance(r, types.messages.DialogsSlice):
# Less than we requested means we reached the end, or
# we didn't get a DialogsSlice which means we got all.
return True
if self.request.offset_id == r.messages[-1].id:
# In some very rare cases this will get stuck in an infinite
# loop, where the offsets will get reused over and over. If
# the new offset is the same as the one before, break already.
return True
self.request.offset_id = r.messages[-1].id
self.request.exclude_pinned = True
self.request.offset_date = r.messages[-1].date
self.request.offset_peer =\
entities[utils.get_peer_id(r.dialogs[-1].peer)]
class _DraftsIter(RequestIter):
async def _init(self, **kwargs):
r = await self.client(functions.messages.GetAllDraftsRequest())
self.buffer.extend(custom.Draft._from_update(self.client, u)
for u in r.updates)
async def _load_next_chunk(self):
return []
class DialogMethods(UserMethods):
# region Public methods
@async_generator
async def iter_dialogs(
def iter_dialogs(
self, limit=None, *, offset_date=None, offset_id=0,
offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
_total=None):
@ -50,99 +133,23 @@ class DialogMethods(UserMethods):
Yields:
Instances of `telethon.tl.custom.dialog.Dialog`.
"""
limit = float('inf') if limit is None else int(limit)
if limit == 0:
if not _total:
return
# Special case, get a single dialog and determine count
dialogs = await self(functions.messages.GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=1,
hash=0
))
_total[0] = getattr(dialogs, 'count', len(dialogs.dialogs))
return
seen = set()
req = functions.messages.GetDialogsRequest(
return _DialogsIter(
self,
limit,
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=0,
hash=0
ignore_migrated=ignore_migrated
)
while len(seen) < limit:
req.limit = min(limit - len(seen), 100)
r = await self(req)
if _total:
_total[0] = getattr(r, 'count', len(r.dialogs))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = {}
for m in r.messages:
m._finish_init(self, entities, None)
messages[m.id] = m
# Happens when there are pinned dialogs
if len(r.dialogs) > limit:
r.dialogs = r.dialogs[:limit]
for d in r.dialogs:
if offset_date:
date = getattr(messages.get(
d.top_message, None), 'date', None)
if not date or date.timestamp() > offset_date.timestamp():
continue
peer_id = utils.get_peer_id(d.peer)
if peer_id not in seen:
seen.add(peer_id)
cd = custom.Dialog(self, d, entities, messages)
if cd.dialog.pts:
self._channel_pts[cd.id] = cd.dialog.pts
if not ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None:
await yield_(cd)
if len(r.dialogs) < req.limit\
or not isinstance(r, types.messages.DialogsSlice):
# Less than we requested means we reached the end, or
# we didn't get a DialogsSlice which means we got all.
break
req.offset_date = r.messages[-1].date
req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)]
if req.offset_id == r.messages[-1].id:
# In some very rare cases this will get stuck in an infinite
# loop, where the offsets will get reused over and over. If
# the new offset is the same as the one before, break already.
break
req.offset_id = r.messages[-1].id
req.exclude_pinned = True
async def get_dialogs(self, *args, **kwargs):
"""
Same as `iter_dialogs`, but returns a
`TotalList <telethon.helpers.TotalList>` instead.
"""
total = [0]
kwargs['_total'] = total
dialogs = helpers.TotalList()
async for x in self.iter_dialogs(*args, **kwargs):
dialogs.append(x)
dialogs.total = total[0]
return dialogs
return await self.iter_dialogs(*args, **kwargs).collect()
@async_generator
async def iter_drafts(self):
def iter_drafts(self):
"""
Iterator over all open draft messages.
@ -151,18 +158,14 @@ class DialogMethods(UserMethods):
to change the message or `telethon.tl.custom.draft.Draft.delete`
among other things.
"""
r = await self(functions.messages.GetAllDraftsRequest())
for update in r.updates:
await yield_(custom.Draft._from_update(self, update))
# TODO Passing a limit here makes no sense
return _DraftsIter(self, None)
async def get_drafts(self):
"""
Same as :meth:`iter_drafts`, but returns a list instead.
"""
result = []
async for x in self.iter_drafts():
result.append(x)
return result
return await self.iter_drafts().collect()
def conversation(
self, entity,

View File

@ -1,14 +1,282 @@
import asyncio
import itertools
import time
from async_generator import async_generator, yield_
from .messageparse import MessageParseMethods
from .uploads import UploadMethods
from .buttons import ButtonMethods
from .. import helpers, utils, errors
from ..tl import types, functions
from ..requestiter import RequestIter
# TODO Maybe RequestIter could rather have the update offset here?
# Maybe init should return the request to be used and it be
# called automatically? And another method to just process it.
class _MessagesIter(RequestIter):
"""
Common factor for all requests that need to iterate over messages.
"""
async def _init(
self, entity, offset_id, min_id, max_id, from_user,
batch_size, offset_date, add_offset, filter, search
):
# Note that entity being ``None`` will perform a global search.
if entity:
self.entity = await self.client.get_input_entity(entity)
else:
self.entity = None
if self.reverse:
raise ValueError('Cannot reverse global search')
# Telegram doesn't like min_id/max_id. If these IDs are low enough
# (starting from last_id - 100), the request will return nothing.
#
# We can emulate their behaviour locally by setting offset = max_id
# and simply stopping once we hit a message with ID <= min_id.
if self.reverse:
offset_id = max(offset_id, min_id)
if offset_id and max_id:
if max_id - offset_id <= 1:
raise StopAsyncIteration
if not max_id:
max_id = float('inf')
else:
offset_id = max(offset_id, max_id)
if offset_id and min_id:
if offset_id - min_id <= 1:
raise StopAsyncIteration
if self.reverse:
if offset_id:
offset_id += 1
else:
offset_id = 1
if from_user:
from_user = await self.client.get_input_entity(from_user)
if not isinstance(from_user, (
types.InputPeerUser, types.InputPeerSelf)):
from_user = None # Ignore from_user unless it's a user
if from_user:
self.from_id = await self.client.get_peer_id(from_user)
else:
self.from_id = None
if not self.entity:
self.request = functions.messages.SearchGlobalRequest(
q=search or '',
offset_date=offset_date,
offset_peer=types.InputPeerEmpty(),
offset_id=offset_id,
limit=1
)
elif search is not None or filter or from_user:
if filter is None:
filter = types.InputMessagesFilterEmpty()
# Telegram completely ignores `from_id` in private chats
if isinstance(
self.entity, (types.InputPeerUser, types.InputPeerSelf)):
# Don't bother sending `from_user` (it's ignored anyway),
# but keep `from_id` defined above to check it locally.
from_user = None
else:
# Do send `from_user` to do the filtering server-side,
# and set `from_id` to None to avoid checking it locally.
self.from_id = None
self.request = functions.messages.SearchRequest(
peer=self.entity,
q=search or '',
filter=filter() if isinstance(filter, type) else filter,
min_date=None,
max_date=offset_date,
offset_id=offset_id,
add_offset=add_offset,
limit=0, # Search actually returns 0 items if we ask it to
max_id=0,
min_id=0,
hash=0,
from_id=from_user
)
else:
self.request = functions.messages.GetHistoryRequest(
peer=self.entity,
limit=1,
offset_date=offset_date,
offset_id=offset_id,
min_id=0,
max_id=0,
add_offset=add_offset,
hash=0
)
if self.limit == 0:
# No messages, but we still need to know the total message count
result = await self.client(self.request)
if isinstance(result, types.messages.MessagesNotModified):
self.total = result.count
else:
self.total = getattr(result, 'count', len(result.messages))
raise StopAsyncIteration
if self.wait_time is None:
self.wait_time = 1 if self.limit > 3000 else 0
# Telegram has a hard limit of 100.
# We don't need to fetch 100 if the limit is less.
self.batch_size = min(max(batch_size, 1), min(100, self.limit))
# When going in reverse we need an offset of `-limit`, but we
# also want to respect what the user passed, so add them together.
if self.reverse:
self.request.add_offset -= self.batch_size
self.add_offset = add_offset
self.max_id = max_id
self.min_id = min_id
self.last_id = 0 if self.reverse else float('inf')
async def _load_next_chunk(self):
self.request.limit = min(self.left, self.batch_size)
if self.reverse and self.request.limit != self.batch_size:
# Remember that we need -limit when going in reverse
self.request.add_offset = self.add_offset - self.request.limit
r = await self.client(self.request)
self.total = getattr(r, 'count', len(r.messages))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = reversed(r.messages) if self.reverse else r.messages
for message in messages:
if (isinstance(message, types.MessageEmpty)
or self.from_id and message.from_id != self.from_id):
continue
if not self._message_in_range(message):
return True
# There has been reports that on bad connections this method
# was returning duplicated IDs sometimes. Using ``last_id``
# is an attempt to avoid these duplicates, since the message
# IDs are returned in descending order (or asc if reverse).
self.last_id = message.id
message._finish_init(self.client, entities, self.entity)
self.buffer.append(message)
if len(r.messages) < self.request.limit:
return True
# Get the last message that's not empty (in some rare cases
# it can happen that the last message is :tl:`MessageEmpty`)
if self.buffer:
self._update_offset(self.buffer[-1])
else:
# There are some cases where all the messages we get start
# being empty. This can happen on migrated mega-groups if
# the history was cleared, and we're using search. Telegram
# acts incredibly weird sometimes. Messages are returned but
# only "empty", not their contents. If this is the case we
# should just give up since there won't be any new Message.
return True
def _message_in_range(self, message):
"""
Determine whether the given message is in the range or
it should be ignored (and avoid loading more chunks).
"""
# No entity means message IDs between chats may vary
if self.entity:
if self.reverse:
if message.id <= self.last_id or message.id >= self.max_id:
return False
else:
if message.id >= self.last_id or message.id <= self.min_id:
return False
return True
def _update_offset(self, last_message):
"""
After making the request, update its offset with the last message.
"""
self.request.offset_id = last_message.id
if self.reverse:
# We want to skip the one we already have
self.request.offset_id += 1
if isinstance(self.request, functions.messages.SearchRequest):
# Unlike getHistory and searchGlobal that use *offset* date,
# this is *max* date. This means that doing a search in reverse
# will break it. Since it's not really needed once we're going
# (only for the first request), it's safe to just clear it off.
self.request.max_date = None
else:
# getHistory and searchGlobal call it offset_date
self.request.offset_date = last_message.date
if isinstance(self.request, functions.messages.SearchGlobalRequest):
self.request.offset_peer = last_message.input_chat
class _IDsIter(RequestIter):
async def _init(self, entity, ids):
# TODO We never actually split IDs in chunks, but maybe we should
if not utils.is_list_like(ids):
self.ids = [ids]
elif not ids:
raise StopAsyncIteration
elif self.reverse:
self.ids = list(reversed(ids))
else:
self.ids = ids
if entity:
entity = await self.client.get_input_entity(entity)
self.total = len(ids)
from_id = None # By default, no need to validate from_id
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
try:
r = await self.client(
functions.channels.GetMessagesRequest(entity, ids))
except errors.MessageIdsEmptyError:
# All IDs were invalid, use a dummy result
r = types.messages.MessagesNotModified(len(ids))
else:
r = await self.client(functions.messages.GetMessagesRequest(ids))
if entity:
from_id = await self.client.get_peer_id(entity)
if isinstance(r, types.messages.MessagesNotModified):
self.buffer.extend(None for _ in ids)
return
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
# Telegram seems to return the messages in the order in which
# we asked them for, so we don't need to check it ourselves,
# unless some messages were invalid in which case Telegram
# may decide to not send them at all.
#
# The passed message IDs may not belong to the desired entity
# since the user can enter arbitrary numbers which can belong to
# arbitrary chats. Validate these unless ``from_id is None``.
for message in r.messages:
if isinstance(message, types.MessageEmpty) or (
from_id and message.chat_id != from_id):
self.buffer.append(None)
else:
message._finish_init(self.client, entities, entity)
self.buffer.append(message)
async def _load_next_chunk(self):
return True # no next chunk, all done in init
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
@ -17,8 +285,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
# region Message retrieval
@async_generator
async def iter_messages(
def iter_messages(
self, entity, limit=None, *, offset_date=None, offset_id=0,
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
from_user=None, batch_size=100, wait_time=None, ids=None,
@ -133,208 +400,26 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
an higher limit, so you're free to set the ``batch_size`` that
you think may be good.
"""
# Note that entity being ``None`` is intended to get messages by
# ID under no specific chat, and also to request a global search.
if entity:
entity = await self.get_input_entity(entity)
if ids:
if not utils.is_list_like(ids):
ids = (ids,)
if reverse:
ids = list(reversed(ids))
async for x in self._iter_ids(entity, ids, total=_total):
await yield_(x)
return
if ids is not None:
return _IDsIter(self, limit, entity=entity, ids=ids)
# Telegram doesn't like min_id/max_id. If these IDs are low enough
# (starting from last_id - 100), the request will return nothing.
#
# We can emulate their behaviour locally by setting offset = max_id
# and simply stopping once we hit a message with ID <= min_id.
if reverse:
offset_id = max(offset_id, min_id)
if offset_id and max_id:
if max_id - offset_id <= 1:
return
if not max_id:
max_id = float('inf')
else:
offset_id = max(offset_id, max_id)
if offset_id and min_id:
if offset_id - min_id <= 1:
return
if reverse:
if offset_id:
offset_id += 1
else:
offset_id = 1
if from_user:
from_user = await self.get_input_entity(from_user)
if not isinstance(from_user, (
types.InputPeerUser, types.InputPeerSelf)):
from_user = None # Ignore from_user unless it's a user
from_id = (await self.get_peer_id(from_user)) if from_user else None
limit = float('inf') if limit is None else int(limit)
if not entity:
if reverse:
raise ValueError('Cannot reverse global search')
reverse = None
request = functions.messages.SearchGlobalRequest(
q=search or '',
offset_date=offset_date,
offset_peer=types.InputPeerEmpty(),
offset_id=offset_id,
limit=1
)
elif search is not None or filter or from_user:
if filter is None:
filter = types.InputMessagesFilterEmpty()
# Telegram completely ignores `from_id` in private chats
if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)):
# Don't bother sending `from_user` (it's ignored anyway),
# but keep `from_id` defined above to check it locally.
from_user = None
else:
# Do send `from_user` to do the filtering server-side,
# and set `from_id` to None to avoid checking it locally.
from_id = None
request = functions.messages.SearchRequest(
peer=entity,
q=search or '',
filter=filter() if isinstance(filter, type) else filter,
min_date=None,
max_date=offset_date,
offset_id=offset_id,
add_offset=add_offset,
limit=0, # Search actually returns 0 items if we ask it to
max_id=0,
min_id=0,
hash=0,
from_id=from_user
)
else:
request = functions.messages.GetHistoryRequest(
peer=entity,
limit=1,
offset_date=offset_date,
offset_id=offset_id,
min_id=0,
max_id=0,
add_offset=add_offset,
hash=0
)
if limit == 0:
if not _total:
return
# No messages, but we still need to know the total message count
result = await self(request)
if isinstance(result, types.messages.MessagesNotModified):
_total[0] = result.count
else:
_total[0] = getattr(result, 'count', len(result.messages))
return
if wait_time is None:
wait_time = 1 if limit > 3000 else 0
have = 0
last_id = 0 if reverse else float('inf')
# Telegram has a hard limit of 100.
# We don't need to fetch 100 if the limit is less.
batch_size = min(max(batch_size, 1), min(100, limit))
# When going in reverse we need an offset of `-limit`, but we
# also want to respect what the user passed, so add them together.
if reverse:
request.add_offset -= batch_size
while have < limit:
start = time.time()
request.limit = min(limit - have, batch_size)
if reverse and request.limit != batch_size:
# Remember that we need -limit when going in reverse
request.add_offset = add_offset - request.limit
r = await self(request)
if _total:
_total[0] = getattr(r, 'count', len(r.messages))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = reversed(r.messages) if reverse else r.messages
for message in messages:
if (isinstance(message, types.MessageEmpty)
or from_id and message.from_id != from_id):
continue
if reverse is None:
pass
elif reverse:
if message.id <= last_id or message.id >= max_id:
return
else:
if message.id >= last_id or message.id <= min_id:
return
# There has been reports that on bad connections this method
# was returning duplicated IDs sometimes. Using ``last_id``
# is an attempt to avoid these duplicates, since the message
# IDs are returned in descending order (or asc if reverse).
last_id = message.id
message._finish_init(self, entities, entity)
await yield_(message)
have += 1
if len(r.messages) < request.limit:
break
# Find the first message that's not empty (in some rare cases
# it can happen that the last message is :tl:`MessageEmpty`)
last_message = None
messages = r.messages if reverse else reversed(r.messages)
for m in messages:
if not isinstance(m, types.MessageEmpty):
last_message = m
break
if last_message is None:
# There are some cases where all the messages we get start
# being empty. This can happen on migrated mega-groups if
# the history was cleared, and we're using search. Telegram
# acts incredibly weird sometimes. Messages are returned but
# only "empty", not their contents. If this is the case we
# should just give up since there won't be any new Message.
break
else:
request.offset_id = last_message.id
if isinstance(request, functions.messages.SearchRequest):
request.max_date = last_message.date
else:
# getHistory and searchGlobal call it offset_date
request.offset_date = last_message.date
if isinstance(request, functions.messages.SearchGlobalRequest):
request.offset_peer = last_message.input_chat
elif reverse:
# We want to skip the one we already have
request.offset_id += 1
await asyncio.sleep(
max(wait_time - (time.time() - start), 0), loop=self._loop)
return _MessagesIter(
client=self,
reverse=reverse,
wait_time=wait_time,
limit=limit,
entity=entity,
offset_id=offset_id,
min_id=min_id,
max_id=max_id,
from_user=from_user,
batch_size=batch_size,
offset_date=offset_date,
add_offset=add_offset,
filter=filter,
search=search
)
async def get_messages(self, *args, **kwargs):
"""
@ -353,23 +438,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
a single `Message <telethon.tl.custom.message.Message>` will be
returned for convenience instead of a list.
"""
total = [0]
kwargs['_total'] = total
if len(args) == 1 and 'limit' not in kwargs:
if 'min_id' in kwargs and 'max_id' in kwargs:
kwargs['limit'] = None
else:
kwargs['limit'] = 1
msgs = helpers.TotalList()
async for x in self.iter_messages(*args, **kwargs):
msgs.append(x)
msgs.total = total[0]
if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']):
# Check for empty list to handle InputMessageReplyTo
return msgs[0] if msgs else None
it = self.iter_messages(*args, **kwargs)
return msgs
ids = kwargs.get('ids')
if ids and not utils.is_list_like(ids):
async for message in it:
return message
else:
# Iterator exhausted = empty, to handle InputMessageReplyTo
return None
return await it.collect()
# endregion
@ -799,52 +884,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
# endregion
# endregion
# region Private methods
@async_generator
async def _iter_ids(self, entity, ids, total):
"""
Special case for `iter_messages` when it should only fetch some IDs.
"""
if total:
total[0] = len(ids)
from_id = None # By default, no need to validate from_id
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
try:
r = await self(
functions.channels.GetMessagesRequest(entity, ids))
except errors.MessageIdsEmptyError:
# All IDs were invalid, use a dummy result
r = types.messages.MessagesNotModified(len(ids))
else:
r = await self(functions.messages.GetMessagesRequest(ids))
if entity:
from_id = utils.get_peer_id(entity)
if isinstance(r, types.messages.MessagesNotModified):
for _ in ids:
await yield_(None)
return
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
# Telegram seems to return the messages in the order in which
# we asked them for, so we don't need to check it ourselves,
# unless some messages were invalid in which case Telegram
# may decide to not send them at all.
#
# The passed message IDs may not belong to the desired entity
# since the user can enter arbitrary numbers which can belong to
# arbitrary chats. Validate these unless ``from_id is None``.
for message in r.messages:
if isinstance(message, types.MessageEmpty) or (
from_id and message.chat_id != from_id):
await yield_(None)
else:
message._finish_init(self, entities, entity)
await yield_(message)
# endregion

130
telethon/requestiter.py Normal file
View File

@ -0,0 +1,130 @@
import abc
import asyncio
import time
from . import helpers
# TODO There are two types of iterators for requests.
# One has a limit of items to retrieve, and the
# other has a list that must be called in chunks.
# Make classes for both here so it's easy to use.
class RequestIter(abc.ABC):
"""
Helper class to deal with requests that need offsets to iterate.
It has some facilities, such as automatically sleeping a desired
amount of time between requests if needed (but not more).
Can be used synchronously if the event loop is not running and
as an asynchronous iterator otherwise.
`limit` is the total amount of items that the iterator should return.
This is handled on this base class, and will be always ``>= 0``.
`left` will be reset every time the iterator is used and will indicate
the amount of items that should be emitted left, so that subclasses can
be more efficient and fetch only as many items as they need.
Iterators may be used with ``reversed``, and their `reverse` flag will
be set to ``True`` if that's the case. Note that if this flag is set,
`buffer` should be filled in reverse too.
"""
def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
self.client = client
self.reverse = reverse
self.wait_time = wait_time
self.kwargs = kwargs
self.limit = max(float('inf') if limit is None else limit, 0)
self.left = None
self.buffer = None
self.index = None
self.total = None
self.last_load = None
async def _init(self, **kwargs):
"""
Called when asynchronous initialization is necessary. All keyword
arguments passed to `__init__` will be forwarded here, and it's
preferable to use named arguments in the subclasses without defaults
to avoid forgetting or misspelling any of them.
This method may ``raise StopAsyncIteration`` if it cannot continue.
This method may actually fill the initial buffer if it needs to.
"""
async def __anext__(self):
if self.buffer is None:
self.buffer = []
await self._init(**self.kwargs)
if self.left <= 0: # <= 0 because subclasses may change it
raise StopAsyncIteration
if self.index == len(self.buffer):
# asyncio will handle times <= 0 to sleep 0 seconds
if self.wait_time:
await asyncio.sleep(
self.wait_time - (time.time() - self.last_load),
loop=self.client.loop
)
self.last_load = time.time()
self.index = 0
self.buffer = []
if await self._load_next_chunk():
self.left = len(self.buffer)
if not self.buffer:
raise StopAsyncIteration
result = self.buffer[self.index]
self.left -= 1
self.index += 1
return result
def __aiter__(self):
self.buffer = None
self.index = 0
self.last_load = 0
self.left = self.limit
return self
def __iter__(self):
if self.client.loop.is_running():
raise RuntimeError(
'You must use "async for" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self.__aiter__()
async def collect(self):
"""
Create a `self` iterator and collect it into a `TotalList`
(a normal list with a `.total` attribute).
"""
result = helpers.TotalList()
async for message in self:
result.append(message)
result.total = self.total
return result
@abc.abstractmethod
async def _load_next_chunk(self):
"""
Called when the next chunk is necessary.
It should extend the `buffer` with new items.
It should return ``True`` if it's the last chunk,
after which moment the method won't be called again
during the same iteration.
"""
raise NotImplementedError
def __reversed__(self):
self.reverse = not self.reverse
return self # __aiter__ will be called after, too

View File

@ -14,8 +14,6 @@ import asyncio
import functools
import inspect
from async_generator import isasyncgenfunction
from .client.telegramclient import TelegramClient
from .tl.custom import (
Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
@ -24,22 +22,7 @@ from .tl.custom.chatgetter import ChatGetter
from .tl.custom.sendergetter import SenderGetter
class _SyncGen:
def __init__(self, gen):
self.gen = gen
def __iter__(self):
return self
def __next__(self):
try:
return asyncio.get_event_loop() \
.run_until_complete(self.gen.__anext__())
except StopAsyncIteration:
raise StopIteration from None
def _syncify_wrap(t, method_name, gen):
def _syncify_wrap(t, method_name):
method = getattr(t, method_name)
@functools.wraps(method)
@ -48,8 +31,6 @@ def _syncify_wrap(t, method_name, gen):
loop = asyncio.get_event_loop()
if loop.is_running():
return coro
elif gen:
return _SyncGen(coro)
else:
return loop.run_until_complete(coro)
@ -64,13 +45,14 @@ def syncify(*types):
into synchronous, which return either the coroutine or the result
based on whether ``asyncio's`` event loop is running.
"""
# Our asynchronous generators all are `RequestIter`, which already
# provide a synchronous iterator variant, so we don't need to worry
# about asyncgenfunction's here.
for t in types:
for name in dir(t):
if not name.startswith('_') or name == '__call__':
if inspect.iscoroutinefunction(getattr(t, name)):
_syncify_wrap(t, name, gen=False)
elif isasyncgenfunction(getattr(t, name)):
_syncify_wrap(t, name, gen=True)
_syncify_wrap(t, name)
syncify(TelegramClient, Draft, Dialog, MessageButton,