diff --git a/requirements.txt b/requirements.txt index 43e88e96..2b650ec4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ pyaes rsa -async_generator diff --git a/setup.py b/setup.py index cec63880..1456cbcb 100755 --- a/setup.py +++ b/setup.py @@ -214,8 +214,7 @@ def main(): packages=find_packages(exclude=[ 'telethon_*', 'run_tests.py', 'try_telethon.py' ]), - install_requires=['pyaes', 'rsa', - 'async_generator'], + install_requires=['pyaes', 'rsa'], extras_require={ 'cryptg': ['cryptg'] } diff --git a/telethon/client/chats.py b/telethon/client/chats.py index 6578c785..83f8ce5e 100644 --- a/telethon/client/chats.py +++ b/telethon/client/chats.py @@ -1,19 +1,192 @@ import itertools -import sys - -from async_generator import async_generator, yield_ from .users import UserMethods -from .. import utils, helpers +from .. import utils +from ..requestiter import RequestIter from ..tl import types, functions, custom +class _ParticipantsIter(RequestIter): + async def _init(self, entity, filter, search, aggressive): + if isinstance(filter, type): + if filter in (types.ChannelParticipantsBanned, + types.ChannelParticipantsKicked, + types.ChannelParticipantsSearch, + types.ChannelParticipantsContacts): + # These require a `q` parameter (support types for convenience) + filter = filter('') + else: + filter = filter() + + entity = await self.client.get_input_entity(entity) + if search and (filter + or not isinstance(entity, types.InputPeerChannel)): + # We need to 'search' ourselves unless we have a PeerChannel + search = search.lower() + + self.filter_entity = lambda ent: ( + search in utils.get_display_name(ent).lower() or + search in (getattr(ent, 'username', '') or None).lower() + ) + else: + self.filter_entity = lambda ent: True + + if isinstance(entity, types.InputPeerChannel): + self.total = (await self.client( + functions.channels.GetFullChannelRequest(entity) + )).full_chat.participants_count + + if self.limit == 0: + raise StopAsyncIteration + + self.seen = set() + if aggressive and not filter: + self.requests = [functions.channels.GetParticipantsRequest( + channel=entity, + filter=types.ChannelParticipantsSearch(x), + offset=0, + limit=200, + hash=0 + ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] + else: + self.requests = [functions.channels.GetParticipantsRequest( + channel=entity, + filter=filter or types.ChannelParticipantsSearch(search), + offset=0, + limit=200, + hash=0 + )] + + elif isinstance(entity, types.InputPeerChat): + full = await self.client( + functions.messages.GetFullChatRequest(entity.chat_id)) + if not isinstance( + full.full_chat.participants, types.ChatParticipants): + # ChatParticipantsForbidden won't have ``.participants`` + self.total = 0 + raise StopAsyncIteration + + self.total = len(full.full_chat.participants.participants) + + users = {user.id: user for user in full.users} + for participant in full.full_chat.participants.participants: + user = users[participant.user_id] + if not self.filter_entity(user): + continue + + user = users[participant.user_id] + user.participant = participant + self.buffer.append(user) + + return True + else: + self.total = 1 + if self.limit != 0: + user = await self.client.get_entity(entity) + if self.filter_entity(user): + user.participant = None + self.buffer.append(user) + + return True + + async def _load_next_chunk(self): + if not self.requests: + return True + + # Only care about the limit for the first request + # (small amount of people, won't be aggressive). + # + # Most people won't care about getting exactly 12,345 + # members so it doesn't really matter not to be 100% + # precise with being out of the offset/limit here. + self.requests[0].limit = min(self.limit - self.requests[0].offset, 200) + if self.requests[0].offset > self.limit: + return True + + results = await self.client(self.requests) + for i in reversed(range(len(self.requests))): + participants = results[i] + if not participants.users: + self.requests.pop(i) + continue + + self.requests[i].offset += len(participants.participants) + users = {user.id: user for user in participants.users} + for participant in participants.participants: + user = users[participant.user_id] + if not self.filter_entity(user) or user.id in self.seen: + continue + + self.seen.add(participant.user_id) + user = users[participant.user_id] + user.participant = participant + self.buffer.append(user) + + +class _AdminLogIter(RequestIter): + async def _init( + self, entity, admins, search, min_id, max_id, + join, leave, invite, restrict, unrestrict, ban, unban, + promote, demote, info, settings, pinned, edit, delete + ): + if any((join, leave, invite, restrict, unrestrict, ban, unban, + promote, demote, info, settings, pinned, edit, delete)): + events_filter = types.ChannelAdminLogEventsFilter( + join=join, leave=leave, invite=invite, ban=restrict, + unban=unrestrict, kick=ban, unkick=unban, promote=promote, + demote=demote, info=info, settings=settings, pinned=pinned, + edit=edit, delete=delete + ) + else: + events_filter = None + + self.entity = await self.client.get_input_entity(entity) + + admin_list = [] + if admins: + if not utils.is_list_like(admins): + admins = (admins,) + + for admin in admins: + admin_list.append(await self.client.get_input_entity(admin)) + + self.request = functions.channels.GetAdminLogRequest( + self.entity, q=search or '', min_id=min_id, max_id=max_id, + limit=0, events_filter=events_filter, admins=admin_list or None + ) + + async def _load_next_chunk(self): + self.request.limit = min(self.left, 100) + r = await self.client(self.request) + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + self.request.max_id = min((e.id for e in r.events), default=0) + for ev in r.events: + if isinstance(ev.action, + types.ChannelAdminLogEventActionEditMessage): + ev.action.prev_message._finish_init( + self.client, entities, self.entity) + + ev.action.new_message._finish_init( + self.client, entities, self.entity) + + elif isinstance(ev.action, + types.ChannelAdminLogEventActionDeleteMessage): + ev.action.message._finish_init( + self.client, entities, self.entity) + + self.buffer.append(custom.AdminLogEvent(ev, entities)) + + if len(r.events) < self.request.limit: + return True + + class ChatMethods(UserMethods): # region Public methods - @async_generator - async def iter_participants( + def iter_participants( self, entity, limit=None, *, search='', filter=None, aggressive=False, _total=None): """ @@ -62,138 +235,23 @@ class ChatMethods(UserMethods): matched :tl:`ChannelParticipant` type for channels/megagroups or :tl:`ChatParticipants` for normal chats. """ - if isinstance(filter, type): - if filter in (types.ChannelParticipantsBanned, - types.ChannelParticipantsKicked, - types.ChannelParticipantsSearch, - types.ChannelParticipantsContacts): - # These require a `q` parameter (support types for convenience) - filter = filter('') - else: - filter = filter() - - entity = await self.get_input_entity(entity) - if search and (filter - or not isinstance(entity, types.InputPeerChannel)): - # We need to 'search' ourselves unless we have a PeerChannel - search = search.lower() - - def filter_entity(ent): - return search in utils.get_display_name(ent).lower() or\ - search in (getattr(ent, 'username', '') or None).lower() - else: - def filter_entity(ent): - return True - - limit = float('inf') if limit is None else int(limit) - if isinstance(entity, types.InputPeerChannel): - if _total: - _total[0] = (await self( - functions.channels.GetFullChannelRequest(entity) - )).full_chat.participants_count - - if limit == 0: - return - - seen = set() - if aggressive and not filter: - requests = [functions.channels.GetParticipantsRequest( - channel=entity, - filter=types.ChannelParticipantsSearch(x), - offset=0, - limit=200, - hash=0 - ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))] - else: - requests = [functions.channels.GetParticipantsRequest( - channel=entity, - filter=filter or types.ChannelParticipantsSearch(search), - offset=0, - limit=200, - hash=0 - )] - - while requests: - # Only care about the limit for the first request - # (small amount of people, won't be aggressive). - # - # Most people won't care about getting exactly 12,345 - # members so it doesn't really matter not to be 100% - # precise with being out of the offset/limit here. - requests[0].limit = min(limit - requests[0].offset, 200) - if requests[0].offset > limit: - break - - results = await self(requests) - for i in reversed(range(len(requests))): - participants = results[i] - if not participants.users: - requests.pop(i) - else: - requests[i].offset += len(participants.participants) - users = {user.id: user for user in participants.users} - for participant in participants.participants: - user = users[participant.user_id] - if not filter_entity(user) or user.id in seen: - continue - - seen.add(participant.user_id) - user = users[participant.user_id] - user.participant = participant - await yield_(user) - if len(seen) >= limit: - return - - elif isinstance(entity, types.InputPeerChat): - full = await self( - functions.messages.GetFullChatRequest(entity.chat_id)) - if not isinstance( - full.full_chat.participants, types.ChatParticipants): - # ChatParticipantsForbidden won't have ``.participants`` - if _total: - _total[0] = 0 - return - - if _total: - _total[0] = len(full.full_chat.participants.participants) - - have = 0 - users = {user.id: user for user in full.users} - for participant in full.full_chat.participants.participants: - user = users[participant.user_id] - if not filter_entity(user): - continue - have += 1 - if have > limit: - break - else: - user = users[participant.user_id] - user.participant = participant - await yield_(user) - else: - if _total: - _total[0] = 1 - if limit != 0: - user = await self.get_entity(entity) - if filter_entity(user): - user.participant = None - await yield_(user) + return _ParticipantsIter( + self, + limit, + entity=entity, + filter=filter, + search=search, + aggressive=aggressive + ) async def get_participants(self, *args, **kwargs): """ Same as `iter_participants`, but returns a `TotalList ` instead. """ - total = [0] - kwargs['_total'] = total - participants = helpers.TotalList() - async for x in self.iter_participants(*args, **kwargs): - participants.append(x) - participants.total = total[0] - return participants + return await self.iter_participants(*args, **kwargs).collect() - @async_generator - async def iter_admin_log( + def iter_admin_log( self, entity, limit=None, *, max_id=0, min_id=0, search=None, admins=None, join=None, leave=None, invite=None, restrict=None, unrestrict=None, ban=None, unban=None, promote=None, demote=None, @@ -285,66 +343,34 @@ class ChatMethods(UserMethods): Yields: Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`. """ - if limit is None: - limit = sys.maxsize - elif limit <= 0: - return - - if any((join, leave, invite, restrict, unrestrict, ban, unban, - promote, demote, info, settings, pinned, edit, delete)): - events_filter = types.ChannelAdminLogEventsFilter( - join=join, leave=leave, invite=invite, ban=restrict, - unban=unrestrict, kick=ban, unkick=unban, promote=promote, - demote=demote, info=info, settings=settings, pinned=pinned, - edit=edit, delete=delete - ) - else: - events_filter = None - - entity = await self.get_input_entity(entity) - - admin_list = [] - if admins: - if not utils.is_list_like(admins): - admins = (admins,) - - for admin in admins: - admin_list.append(await self.get_input_entity(admin)) - - request = functions.channels.GetAdminLogRequest( - entity, q=search or '', min_id=min_id, max_id=max_id, - limit=0, events_filter=events_filter, admins=admin_list or None + return _AdminLogIter( + self, + limit, + entity=entity, + admins=admins, + search=search, + min_id=min_id, + max_id=max_id, + join=join, + leave=leave, + invite=invite, + restrict=restrict, + unrestrict=unrestrict, + ban=ban, + unban=unban, + promote=promote, + demote=demote, + info=info, + settings=settings, + pinned=pinned, + edit=edit, + delete=delete ) - while limit > 0: - request.limit = min(limit, 100) - result = await self(request) - limit -= len(result.events) - entities = {utils.get_peer_id(x): x - for x in itertools.chain(result.users, result.chats)} - - request.max_id = min((e.id for e in result.events), default=0) - for ev in result.events: - if isinstance(ev.action, - types.ChannelAdminLogEventActionEditMessage): - ev.action.prev_message._finish_init(self, entities, entity) - ev.action.new_message._finish_init(self, entities, entity) - - elif isinstance(ev.action, - types.ChannelAdminLogEventActionDeleteMessage): - ev.action.message._finish_init(self, entities, entity) - - await yield_(custom.AdminLogEvent(ev, entities)) - - if len(result.events) < request.limit: - break async def get_admin_log(self, *args, **kwargs): """ Same as `iter_admin_log`, but returns a ``list`` instead. """ - admin_log = [] - async for x in self.iter_admin_log(*args, **kwargs): - admin_log.append(x) - return admin_log + return await self.iter_admin_log(*args, **kwargs).collect() # endregion diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index 6c7edd24..453dfdec 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -1,18 +1,101 @@ import itertools -from async_generator import async_generator, yield_ - from .users import UserMethods -from .. import utils, helpers +from .. import utils +from ..requestiter import RequestIter from ..tl import types, functions, custom +class _DialogsIter(RequestIter): + async def _init( + self, offset_date, offset_id, offset_peer, ignore_migrated + ): + self.request = functions.messages.GetDialogsRequest( + offset_date=offset_date, + offset_id=offset_id, + offset_peer=offset_peer, + limit=1, + hash=0 + ) + + if self.limit == 0: + # Special case, get a single dialog and determine count + dialogs = await self.client(self.request) + self.total = getattr(dialogs, 'count', len(dialogs.dialogs)) + raise StopAsyncIteration + + self.seen = set() + self.offset_date = offset_date + self.ignore_migrated = ignore_migrated + + async def _load_next_chunk(self): + self.request.limit = min(self.left, 100) + r = await self.client(self.request) + + self.total = getattr(r, 'count', len(r.dialogs)) + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + messages = {} + for m in r.messages: + m._finish_init(self, entities, None) + messages[m.id] = m + + for d in r.dialogs: + # We check the offset date here because Telegram may ignore it + if self.offset_date: + date = getattr(messages.get( + d.top_message, None), 'date', None) + + if not date or date.timestamp() > self.offset_date.timestamp(): + continue + + peer_id = utils.get_peer_id(d.peer) + if peer_id not in self.seen: + self.seen.add(peer_id) + cd = custom.Dialog(self, d, entities, messages) + if cd.dialog.pts: + self.client._channel_pts[cd.id] = cd.dialog.pts + + if not self.ignore_migrated or getattr( + cd.entity, 'migrated_to', None) is None: + self.buffer.append(cd) + + if len(r.dialogs) < self.request.limit\ + or not isinstance(r, types.messages.DialogsSlice): + # Less than we requested means we reached the end, or + # we didn't get a DialogsSlice which means we got all. + return True + + if self.request.offset_id == r.messages[-1].id: + # In some very rare cases this will get stuck in an infinite + # loop, where the offsets will get reused over and over. If + # the new offset is the same as the one before, break already. + return True + + self.request.offset_id = r.messages[-1].id + self.request.exclude_pinned = True + self.request.offset_date = r.messages[-1].date + self.request.offset_peer =\ + entities[utils.get_peer_id(r.dialogs[-1].peer)] + + +class _DraftsIter(RequestIter): + async def _init(self, **kwargs): + r = await self.client(functions.messages.GetAllDraftsRequest()) + self.buffer.extend(custom.Draft._from_update(self.client, u) + for u in r.updates) + + async def _load_next_chunk(self): + return [] + + class DialogMethods(UserMethods): # region Public methods - @async_generator - async def iter_dialogs( + def iter_dialogs( self, limit=None, *, offset_date=None, offset_id=0, offset_peer=types.InputPeerEmpty(), ignore_migrated=False, _total=None): @@ -50,99 +133,23 @@ class DialogMethods(UserMethods): Yields: Instances of `telethon.tl.custom.dialog.Dialog`. """ - limit = float('inf') if limit is None else int(limit) - if limit == 0: - if not _total: - return - # Special case, get a single dialog and determine count - dialogs = await self(functions.messages.GetDialogsRequest( - offset_date=offset_date, - offset_id=offset_id, - offset_peer=offset_peer, - limit=1, - hash=0 - )) - _total[0] = getattr(dialogs, 'count', len(dialogs.dialogs)) - return - - seen = set() - req = functions.messages.GetDialogsRequest( + return _DialogsIter( + self, + limit, offset_date=offset_date, offset_id=offset_id, offset_peer=offset_peer, - limit=0, - hash=0 + ignore_migrated=ignore_migrated ) - while len(seen) < limit: - req.limit = min(limit - len(seen), 100) - r = await self(req) - - if _total: - _total[0] = getattr(r, 'count', len(r.dialogs)) - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - messages = {} - for m in r.messages: - m._finish_init(self, entities, None) - messages[m.id] = m - - # Happens when there are pinned dialogs - if len(r.dialogs) > limit: - r.dialogs = r.dialogs[:limit] - - for d in r.dialogs: - if offset_date: - date = getattr(messages.get( - d.top_message, None), 'date', None) - - if not date or date.timestamp() > offset_date.timestamp(): - continue - - peer_id = utils.get_peer_id(d.peer) - if peer_id not in seen: - seen.add(peer_id) - cd = custom.Dialog(self, d, entities, messages) - if cd.dialog.pts: - self._channel_pts[cd.id] = cd.dialog.pts - - if not ignore_migrated or getattr( - cd.entity, 'migrated_to', None) is None: - await yield_(cd) - - if len(r.dialogs) < req.limit\ - or not isinstance(r, types.messages.DialogsSlice): - # Less than we requested means we reached the end, or - # we didn't get a DialogsSlice which means we got all. - break - - req.offset_date = r.messages[-1].date - req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)] - if req.offset_id == r.messages[-1].id: - # In some very rare cases this will get stuck in an infinite - # loop, where the offsets will get reused over and over. If - # the new offset is the same as the one before, break already. - break - - req.offset_id = r.messages[-1].id - req.exclude_pinned = True async def get_dialogs(self, *args, **kwargs): """ Same as `iter_dialogs`, but returns a `TotalList ` instead. """ - total = [0] - kwargs['_total'] = total - dialogs = helpers.TotalList() - async for x in self.iter_dialogs(*args, **kwargs): - dialogs.append(x) - dialogs.total = total[0] - return dialogs + return await self.iter_dialogs(*args, **kwargs).collect() - @async_generator - async def iter_drafts(self): + def iter_drafts(self): """ Iterator over all open draft messages. @@ -151,18 +158,14 @@ class DialogMethods(UserMethods): to change the message or `telethon.tl.custom.draft.Draft.delete` among other things. """ - r = await self(functions.messages.GetAllDraftsRequest()) - for update in r.updates: - await yield_(custom.Draft._from_update(self, update)) + # TODO Passing a limit here makes no sense + return _DraftsIter(self, None) async def get_drafts(self): """ Same as :meth:`iter_drafts`, but returns a list instead. """ - result = [] - async for x in self.iter_drafts(): - result.append(x) - return result + return await self.iter_drafts().collect() def conversation( self, entity, diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 2ad8d4b2..e87d80a0 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -1,14 +1,282 @@ -import asyncio import itertools -import time - -from async_generator import async_generator, yield_ from .messageparse import MessageParseMethods from .uploads import UploadMethods from .buttons import ButtonMethods from .. import helpers, utils, errors from ..tl import types, functions +from ..requestiter import RequestIter + + +# TODO Maybe RequestIter could rather have the update offset here? +# Maybe init should return the request to be used and it be +# called automatically? And another method to just process it. +class _MessagesIter(RequestIter): + """ + Common factor for all requests that need to iterate over messages. + """ + async def _init( + self, entity, offset_id, min_id, max_id, from_user, + batch_size, offset_date, add_offset, filter, search + ): + # Note that entity being ``None`` will perform a global search. + if entity: + self.entity = await self.client.get_input_entity(entity) + else: + self.entity = None + if self.reverse: + raise ValueError('Cannot reverse global search') + + # Telegram doesn't like min_id/max_id. If these IDs are low enough + # (starting from last_id - 100), the request will return nothing. + # + # We can emulate their behaviour locally by setting offset = max_id + # and simply stopping once we hit a message with ID <= min_id. + if self.reverse: + offset_id = max(offset_id, min_id) + if offset_id and max_id: + if max_id - offset_id <= 1: + raise StopAsyncIteration + + if not max_id: + max_id = float('inf') + else: + offset_id = max(offset_id, max_id) + if offset_id and min_id: + if offset_id - min_id <= 1: + raise StopAsyncIteration + + if self.reverse: + if offset_id: + offset_id += 1 + else: + offset_id = 1 + + if from_user: + from_user = await self.client.get_input_entity(from_user) + if not isinstance(from_user, ( + types.InputPeerUser, types.InputPeerSelf)): + from_user = None # Ignore from_user unless it's a user + + if from_user: + self.from_id = await self.client.get_peer_id(from_user) + else: + self.from_id = None + + if not self.entity: + self.request = functions.messages.SearchGlobalRequest( + q=search or '', + offset_date=offset_date, + offset_peer=types.InputPeerEmpty(), + offset_id=offset_id, + limit=1 + ) + elif search is not None or filter or from_user: + if filter is None: + filter = types.InputMessagesFilterEmpty() + + # Telegram completely ignores `from_id` in private chats + if isinstance( + self.entity, (types.InputPeerUser, types.InputPeerSelf)): + # Don't bother sending `from_user` (it's ignored anyway), + # but keep `from_id` defined above to check it locally. + from_user = None + else: + # Do send `from_user` to do the filtering server-side, + # and set `from_id` to None to avoid checking it locally. + self.from_id = None + + self.request = functions.messages.SearchRequest( + peer=self.entity, + q=search or '', + filter=filter() if isinstance(filter, type) else filter, + min_date=None, + max_date=offset_date, + offset_id=offset_id, + add_offset=add_offset, + limit=0, # Search actually returns 0 items if we ask it to + max_id=0, + min_id=0, + hash=0, + from_id=from_user + ) + else: + self.request = functions.messages.GetHistoryRequest( + peer=self.entity, + limit=1, + offset_date=offset_date, + offset_id=offset_id, + min_id=0, + max_id=0, + add_offset=add_offset, + hash=0 + ) + + if self.limit == 0: + # No messages, but we still need to know the total message count + result = await self.client(self.request) + if isinstance(result, types.messages.MessagesNotModified): + self.total = result.count + else: + self.total = getattr(result, 'count', len(result.messages)) + raise StopAsyncIteration + + if self.wait_time is None: + self.wait_time = 1 if self.limit > 3000 else 0 + + # Telegram has a hard limit of 100. + # We don't need to fetch 100 if the limit is less. + self.batch_size = min(max(batch_size, 1), min(100, self.limit)) + + # When going in reverse we need an offset of `-limit`, but we + # also want to respect what the user passed, so add them together. + if self.reverse: + self.request.add_offset -= self.batch_size + + self.add_offset = add_offset + self.max_id = max_id + self.min_id = min_id + self.last_id = 0 if self.reverse else float('inf') + + async def _load_next_chunk(self): + self.request.limit = min(self.left, self.batch_size) + if self.reverse and self.request.limit != self.batch_size: + # Remember that we need -limit when going in reverse + self.request.add_offset = self.add_offset - self.request.limit + + r = await self.client(self.request) + self.total = getattr(r, 'count', len(r.messages)) + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + messages = reversed(r.messages) if self.reverse else r.messages + for message in messages: + if (isinstance(message, types.MessageEmpty) + or self.from_id and message.from_id != self.from_id): + continue + + if not self._message_in_range(message): + return True + + # There has been reports that on bad connections this method + # was returning duplicated IDs sometimes. Using ``last_id`` + # is an attempt to avoid these duplicates, since the message + # IDs are returned in descending order (or asc if reverse). + self.last_id = message.id + message._finish_init(self.client, entities, self.entity) + self.buffer.append(message) + + if len(r.messages) < self.request.limit: + return True + + # Get the last message that's not empty (in some rare cases + # it can happen that the last message is :tl:`MessageEmpty`) + if self.buffer: + self._update_offset(self.buffer[-1]) + else: + # There are some cases where all the messages we get start + # being empty. This can happen on migrated mega-groups if + # the history was cleared, and we're using search. Telegram + # acts incredibly weird sometimes. Messages are returned but + # only "empty", not their contents. If this is the case we + # should just give up since there won't be any new Message. + return True + + def _message_in_range(self, message): + """ + Determine whether the given message is in the range or + it should be ignored (and avoid loading more chunks). + """ + # No entity means message IDs between chats may vary + if self.entity: + if self.reverse: + if message.id <= self.last_id or message.id >= self.max_id: + return False + else: + if message.id >= self.last_id or message.id <= self.min_id: + return False + + return True + + def _update_offset(self, last_message): + """ + After making the request, update its offset with the last message. + """ + self.request.offset_id = last_message.id + if self.reverse: + # We want to skip the one we already have + self.request.offset_id += 1 + + if isinstance(self.request, functions.messages.SearchRequest): + # Unlike getHistory and searchGlobal that use *offset* date, + # this is *max* date. This means that doing a search in reverse + # will break it. Since it's not really needed once we're going + # (only for the first request), it's safe to just clear it off. + self.request.max_date = None + else: + # getHistory and searchGlobal call it offset_date + self.request.offset_date = last_message.date + + if isinstance(self.request, functions.messages.SearchGlobalRequest): + self.request.offset_peer = last_message.input_chat + + +class _IDsIter(RequestIter): + async def _init(self, entity, ids): + # TODO We never actually split IDs in chunks, but maybe we should + if not utils.is_list_like(ids): + self.ids = [ids] + elif not ids: + raise StopAsyncIteration + elif self.reverse: + self.ids = list(reversed(ids)) + else: + self.ids = ids + + if entity: + entity = await self.client.get_input_entity(entity) + + self.total = len(ids) + + from_id = None # By default, no need to validate from_id + if isinstance(entity, (types.InputChannel, types.InputPeerChannel)): + try: + r = await self.client( + functions.channels.GetMessagesRequest(entity, ids)) + except errors.MessageIdsEmptyError: + # All IDs were invalid, use a dummy result + r = types.messages.MessagesNotModified(len(ids)) + else: + r = await self.client(functions.messages.GetMessagesRequest(ids)) + if entity: + from_id = await self.client.get_peer_id(entity) + + if isinstance(r, types.messages.MessagesNotModified): + self.buffer.extend(None for _ in ids) + return + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + # Telegram seems to return the messages in the order in which + # we asked them for, so we don't need to check it ourselves, + # unless some messages were invalid in which case Telegram + # may decide to not send them at all. + # + # The passed message IDs may not belong to the desired entity + # since the user can enter arbitrary numbers which can belong to + # arbitrary chats. Validate these unless ``from_id is None``. + for message in r.messages: + if isinstance(message, types.MessageEmpty) or ( + from_id and message.chat_id != from_id): + self.buffer.append(None) + else: + message._finish_init(self.client, entities, entity) + self.buffer.append(message) + + async def _load_next_chunk(self): + return True # no next chunk, all done in init class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): @@ -17,8 +285,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # region Message retrieval - @async_generator - async def iter_messages( + def iter_messages( self, entity, limit=None, *, offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0, search=None, filter=None, from_user=None, batch_size=100, wait_time=None, ids=None, @@ -133,208 +400,26 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): an higher limit, so you're free to set the ``batch_size`` that you think may be good. """ - # Note that entity being ``None`` is intended to get messages by - # ID under no specific chat, and also to request a global search. - if entity: - entity = await self.get_input_entity(entity) - if ids: - if not utils.is_list_like(ids): - ids = (ids,) - if reverse: - ids = list(reversed(ids)) - async for x in self._iter_ids(entity, ids, total=_total): - await yield_(x) - return + if ids is not None: + return _IDsIter(self, limit, entity=entity, ids=ids) - # Telegram doesn't like min_id/max_id. If these IDs are low enough - # (starting from last_id - 100), the request will return nothing. - # - # We can emulate their behaviour locally by setting offset = max_id - # and simply stopping once we hit a message with ID <= min_id. - if reverse: - offset_id = max(offset_id, min_id) - if offset_id and max_id: - if max_id - offset_id <= 1: - return - - if not max_id: - max_id = float('inf') - else: - offset_id = max(offset_id, max_id) - if offset_id and min_id: - if offset_id - min_id <= 1: - return - - if reverse: - if offset_id: - offset_id += 1 - else: - offset_id = 1 - - if from_user: - from_user = await self.get_input_entity(from_user) - if not isinstance(from_user, ( - types.InputPeerUser, types.InputPeerSelf)): - from_user = None # Ignore from_user unless it's a user - - from_id = (await self.get_peer_id(from_user)) if from_user else None - - limit = float('inf') if limit is None else int(limit) - if not entity: - if reverse: - raise ValueError('Cannot reverse global search') - - reverse = None - request = functions.messages.SearchGlobalRequest( - q=search or '', - offset_date=offset_date, - offset_peer=types.InputPeerEmpty(), - offset_id=offset_id, - limit=1 - ) - elif search is not None or filter or from_user: - if filter is None: - filter = types.InputMessagesFilterEmpty() - - # Telegram completely ignores `from_id` in private chats - if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)): - # Don't bother sending `from_user` (it's ignored anyway), - # but keep `from_id` defined above to check it locally. - from_user = None - else: - # Do send `from_user` to do the filtering server-side, - # and set `from_id` to None to avoid checking it locally. - from_id = None - - request = functions.messages.SearchRequest( - peer=entity, - q=search or '', - filter=filter() if isinstance(filter, type) else filter, - min_date=None, - max_date=offset_date, - offset_id=offset_id, - add_offset=add_offset, - limit=0, # Search actually returns 0 items if we ask it to - max_id=0, - min_id=0, - hash=0, - from_id=from_user - ) - else: - request = functions.messages.GetHistoryRequest( - peer=entity, - limit=1, - offset_date=offset_date, - offset_id=offset_id, - min_id=0, - max_id=0, - add_offset=add_offset, - hash=0 - ) - - if limit == 0: - if not _total: - return - # No messages, but we still need to know the total message count - result = await self(request) - if isinstance(result, types.messages.MessagesNotModified): - _total[0] = result.count - else: - _total[0] = getattr(result, 'count', len(result.messages)) - return - - if wait_time is None: - wait_time = 1 if limit > 3000 else 0 - - have = 0 - last_id = 0 if reverse else float('inf') - - # Telegram has a hard limit of 100. - # We don't need to fetch 100 if the limit is less. - batch_size = min(max(batch_size, 1), min(100, limit)) - - # When going in reverse we need an offset of `-limit`, but we - # also want to respect what the user passed, so add them together. - if reverse: - request.add_offset -= batch_size - - while have < limit: - start = time.time() - - request.limit = min(limit - have, batch_size) - if reverse and request.limit != batch_size: - # Remember that we need -limit when going in reverse - request.add_offset = add_offset - request.limit - - r = await self(request) - if _total: - _total[0] = getattr(r, 'count', len(r.messages)) - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - messages = reversed(r.messages) if reverse else r.messages - for message in messages: - if (isinstance(message, types.MessageEmpty) - or from_id and message.from_id != from_id): - continue - - if reverse is None: - pass - elif reverse: - if message.id <= last_id or message.id >= max_id: - return - else: - if message.id >= last_id or message.id <= min_id: - return - - # There has been reports that on bad connections this method - # was returning duplicated IDs sometimes. Using ``last_id`` - # is an attempt to avoid these duplicates, since the message - # IDs are returned in descending order (or asc if reverse). - last_id = message.id - - message._finish_init(self, entities, entity) - await yield_(message) - have += 1 - - if len(r.messages) < request.limit: - break - - # Find the first message that's not empty (in some rare cases - # it can happen that the last message is :tl:`MessageEmpty`) - last_message = None - messages = r.messages if reverse else reversed(r.messages) - for m in messages: - if not isinstance(m, types.MessageEmpty): - last_message = m - break - - if last_message is None: - # There are some cases where all the messages we get start - # being empty. This can happen on migrated mega-groups if - # the history was cleared, and we're using search. Telegram - # acts incredibly weird sometimes. Messages are returned but - # only "empty", not their contents. If this is the case we - # should just give up since there won't be any new Message. - break - else: - request.offset_id = last_message.id - if isinstance(request, functions.messages.SearchRequest): - request.max_date = last_message.date - else: - # getHistory and searchGlobal call it offset_date - request.offset_date = last_message.date - - if isinstance(request, functions.messages.SearchGlobalRequest): - request.offset_peer = last_message.input_chat - elif reverse: - # We want to skip the one we already have - request.offset_id += 1 - - await asyncio.sleep( - max(wait_time - (time.time() - start), 0), loop=self._loop) + return _MessagesIter( + client=self, + reverse=reverse, + wait_time=wait_time, + limit=limit, + entity=entity, + offset_id=offset_id, + min_id=min_id, + max_id=max_id, + from_user=from_user, + batch_size=batch_size, + offset_date=offset_date, + add_offset=add_offset, + filter=filter, + search=search + ) async def get_messages(self, *args, **kwargs): """ @@ -353,23 +438,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): a single `Message ` will be returned for convenience instead of a list. """ - total = [0] - kwargs['_total'] = total if len(args) == 1 and 'limit' not in kwargs: if 'min_id' in kwargs and 'max_id' in kwargs: kwargs['limit'] = None else: kwargs['limit'] = 1 - msgs = helpers.TotalList() - async for x in self.iter_messages(*args, **kwargs): - msgs.append(x) - msgs.total = total[0] - if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']): - # Check for empty list to handle InputMessageReplyTo - return msgs[0] if msgs else None + it = self.iter_messages(*args, **kwargs) - return msgs + ids = kwargs.get('ids') + if ids and not utils.is_list_like(ids): + async for message in it: + return message + else: + # Iterator exhausted = empty, to handle InputMessageReplyTo + return None + + return await it.collect() # endregion @@ -799,52 +884,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods): # endregion # endregion - - # region Private methods - - @async_generator - async def _iter_ids(self, entity, ids, total): - """ - Special case for `iter_messages` when it should only fetch some IDs. - """ - if total: - total[0] = len(ids) - - from_id = None # By default, no need to validate from_id - if isinstance(entity, (types.InputChannel, types.InputPeerChannel)): - try: - r = await self( - functions.channels.GetMessagesRequest(entity, ids)) - except errors.MessageIdsEmptyError: - # All IDs were invalid, use a dummy result - r = types.messages.MessagesNotModified(len(ids)) - else: - r = await self(functions.messages.GetMessagesRequest(ids)) - if entity: - from_id = utils.get_peer_id(entity) - - if isinstance(r, types.messages.MessagesNotModified): - for _ in ids: - await yield_(None) - return - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - # Telegram seems to return the messages in the order in which - # we asked them for, so we don't need to check it ourselves, - # unless some messages were invalid in which case Telegram - # may decide to not send them at all. - # - # The passed message IDs may not belong to the desired entity - # since the user can enter arbitrary numbers which can belong to - # arbitrary chats. Validate these unless ``from_id is None``. - for message in r.messages: - if isinstance(message, types.MessageEmpty) or ( - from_id and message.chat_id != from_id): - await yield_(None) - else: - message._finish_init(self, entities, entity) - await yield_(message) - - # endregion diff --git a/telethon/requestiter.py b/telethon/requestiter.py new file mode 100644 index 00000000..111d4509 --- /dev/null +++ b/telethon/requestiter.py @@ -0,0 +1,130 @@ +import abc +import asyncio +import time + +from . import helpers + + +# TODO There are two types of iterators for requests. +# One has a limit of items to retrieve, and the +# other has a list that must be called in chunks. +# Make classes for both here so it's easy to use. +class RequestIter(abc.ABC): + """ + Helper class to deal with requests that need offsets to iterate. + + It has some facilities, such as automatically sleeping a desired + amount of time between requests if needed (but not more). + + Can be used synchronously if the event loop is not running and + as an asynchronous iterator otherwise. + + `limit` is the total amount of items that the iterator should return. + This is handled on this base class, and will be always ``>= 0``. + + `left` will be reset every time the iterator is used and will indicate + the amount of items that should be emitted left, so that subclasses can + be more efficient and fetch only as many items as they need. + + Iterators may be used with ``reversed``, and their `reverse` flag will + be set to ``True`` if that's the case. Note that if this flag is set, + `buffer` should be filled in reverse too. + """ + def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs): + self.client = client + self.reverse = reverse + self.wait_time = wait_time + self.kwargs = kwargs + self.limit = max(float('inf') if limit is None else limit, 0) + self.left = None + self.buffer = None + self.index = None + self.total = None + self.last_load = None + + async def _init(self, **kwargs): + """ + Called when asynchronous initialization is necessary. All keyword + arguments passed to `__init__` will be forwarded here, and it's + preferable to use named arguments in the subclasses without defaults + to avoid forgetting or misspelling any of them. + + This method may ``raise StopAsyncIteration`` if it cannot continue. + + This method may actually fill the initial buffer if it needs to. + """ + + async def __anext__(self): + if self.buffer is None: + self.buffer = [] + await self._init(**self.kwargs) + + if self.left <= 0: # <= 0 because subclasses may change it + raise StopAsyncIteration + + if self.index == len(self.buffer): + # asyncio will handle times <= 0 to sleep 0 seconds + if self.wait_time: + await asyncio.sleep( + self.wait_time - (time.time() - self.last_load), + loop=self.client.loop + ) + self.last_load = time.time() + + self.index = 0 + self.buffer = [] + if await self._load_next_chunk(): + self.left = len(self.buffer) + + if not self.buffer: + raise StopAsyncIteration + + result = self.buffer[self.index] + self.left -= 1 + self.index += 1 + return result + + def __aiter__(self): + self.buffer = None + self.index = 0 + self.last_load = 0 + self.left = self.limit + return self + + def __iter__(self): + if self.client.loop.is_running(): + raise RuntimeError( + 'You must use "async for" if the event loop ' + 'is running (i.e. you are inside an "async def")' + ) + + return self.__aiter__() + + async def collect(self): + """ + Create a `self` iterator and collect it into a `TotalList` + (a normal list with a `.total` attribute). + """ + result = helpers.TotalList() + async for message in self: + result.append(message) + + result.total = self.total + return result + + @abc.abstractmethod + async def _load_next_chunk(self): + """ + Called when the next chunk is necessary. + + It should extend the `buffer` with new items. + + It should return ``True`` if it's the last chunk, + after which moment the method won't be called again + during the same iteration. + """ + raise NotImplementedError + + def __reversed__(self): + self.reverse = not self.reverse + return self # __aiter__ will be called after, too diff --git a/telethon/sync.py b/telethon/sync.py index 925b2046..c5f0f623 100644 --- a/telethon/sync.py +++ b/telethon/sync.py @@ -14,8 +14,6 @@ import asyncio import functools import inspect -from async_generator import isasyncgenfunction - from .client.telegramclient import TelegramClient from .tl.custom import ( Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation @@ -24,22 +22,7 @@ from .tl.custom.chatgetter import ChatGetter from .tl.custom.sendergetter import SenderGetter -class _SyncGen: - def __init__(self, gen): - self.gen = gen - - def __iter__(self): - return self - - def __next__(self): - try: - return asyncio.get_event_loop() \ - .run_until_complete(self.gen.__anext__()) - except StopAsyncIteration: - raise StopIteration from None - - -def _syncify_wrap(t, method_name, gen): +def _syncify_wrap(t, method_name): method = getattr(t, method_name) @functools.wraps(method) @@ -48,8 +31,6 @@ def _syncify_wrap(t, method_name, gen): loop = asyncio.get_event_loop() if loop.is_running(): return coro - elif gen: - return _SyncGen(coro) else: return loop.run_until_complete(coro) @@ -64,13 +45,14 @@ def syncify(*types): into synchronous, which return either the coroutine or the result based on whether ``asyncio's`` event loop is running. """ + # Our asynchronous generators all are `RequestIter`, which already + # provide a synchronous iterator variant, so we don't need to worry + # about asyncgenfunction's here. for t in types: for name in dir(t): if not name.startswith('_') or name == '__call__': if inspect.iscoroutinefunction(getattr(t, name)): - _syncify_wrap(t, name, gen=False) - elif isasyncgenfunction(getattr(t, name)): - _syncify_wrap(t, name, gen=True) + _syncify_wrap(t, name) syncify(TelegramClient, Draft, Dialog, MessageButton,