diff --git a/telethon/client/chats.py b/telethon/client/chats.py index d0b0d2e9..56031f19 100644 --- a/telethon/client/chats.py +++ b/telethon/client/chats.py @@ -1,5 +1,7 @@ from collections import UserList +from async_generator import async_generator, yield_ + from .users import UserMethods from .. import utils from ..tl import types, functions @@ -9,6 +11,7 @@ class ChatMethods(UserMethods): # region Public methods + @async_generator async def iter_participants( self, entity, limit=None, search='', filter=None, aggressive=False, _total=None): @@ -130,7 +133,7 @@ class ChatMethods(UserMethods): seen.add(participant.user_id) user = users[participant.user_id] user.participant = participant - yield user + await yield_(user) if len(seen) >= limit: return @@ -159,7 +162,7 @@ class ChatMethods(UserMethods): else: user = users[participant.user_id] user.participant = participant - yield user + await yield_(user) else: if _total: _total[0] = 1 @@ -167,7 +170,7 @@ class ChatMethods(UserMethods): user = await self.get_entity(entity) if filter_entity(user): user.participant = None - yield user + await yield_(user) async def get_participants(self, *args, **kwargs): """ diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index a28432d9..d42d1c49 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -1,14 +1,17 @@ import itertools from collections import UserList +from async_generator import async_generator, yield_ + from .users import UserMethods -from ..tl import types, functions, custom from .. import utils +from ..tl import types, functions, custom class DialogMethods(UserMethods): # region Public methods + @async_generator async def iter_dialogs( self, limit=None, offset_date=None, offset_id=0, offset_peer=types.InputPeerEmpty(), _total=None): @@ -80,7 +83,7 @@ class DialogMethods(UserMethods): peer_id = utils.get_peer_id(d.peer) if peer_id not in seen: seen.add(peer_id) - yield custom.Dialog(self, d, entities, messages) + await yield_(custom.Dialog(self, d, entities, messages)) if len(r.dialogs) < req.limit\ or not isinstance(r, types.messages.DialogsSlice): @@ -115,7 +118,7 @@ class DialogMethods(UserMethods): """ r = await self(functions.messages.GetAllDraftsRequest()) for update in r.updates: - yield custom.Draft._from_update(self, update) + await yield_(custom.Draft._from_update(self, update)) async def get_drafts(self): """ diff --git a/telethon/client/messages.py b/telethon/client/messages.py index 2e9d9bc1..13d98820 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -5,6 +5,8 @@ import time import warnings from collections import UserList +from async_generator import async_generator, yield_ + from .messageparse import MessageParseMethods from .uploads import UploadMethods from .. import utils @@ -19,6 +21,7 @@ class MessageMethods(UploadMethods, MessageParseMethods): # region Message retrieval + @async_generator async def iter_messages( self, entity, limit=None, offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0, search=None, filter=None, @@ -114,7 +117,7 @@ class MessageMethods(UploadMethods, MessageParseMethods): if not utils.is_list_like(ids): ids = (ids,) async for x in self._iter_ids(entity, ids, total=_total): - yield x + await yield_(x) return # Telegram doesn't like min_id/max_id. If these IDs are low enough @@ -202,7 +205,7 @@ class MessageMethods(UploadMethods, MessageParseMethods): # IDs are returned in descending order. last_id = message.id - yield custom.Message(self, message, entities, entity) + await yield_(custom.Message(self, message, entities, entity)) have += 1 if len(r.messages) < request.limit: @@ -620,6 +623,7 @@ class MessageMethods(UploadMethods, MessageParseMethods): # region Private methods + @async_generator async def _iter_ids(self, entity, ids, total): """ Special case for `iter_messages` when it should only fetch some IDs. @@ -634,7 +638,7 @@ class MessageMethods(UploadMethods, MessageParseMethods): if isinstance(r, types.messages.MessagesNotModified): for _ in ids: - yield None + await yield_(None) return entities = {utils.get_peer_id(x): x @@ -644,8 +648,8 @@ class MessageMethods(UploadMethods, MessageParseMethods): # we asked them for, so we don't need to check it ourselves. for message in r.messages: if isinstance(message, types.MessageEmpty): - yield None + await yield_(None) else: - yield custom.Message(self, message, entities, entity) + await yield_(custom.Message(self, message, entities, entity)) # endregion