diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index 58a28d6f..9a209ec2 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -103,15 +103,27 @@ class _DialogsIter(RequestIter): class _DraftsIter(RequestIter): - async def _init(self, **kwargs): - r = await self.client(functions.messages.GetAllDraftsRequest()) + async def _init(self, entities, **kwargs): + if not entities: + r = await self.client(functions.messages.GetAllDraftsRequest()) + items = r.updates + else: + peers = [] + for entity in entities: + peers.append(types.InputDialogPeer( + await self.client.get_input_entity(entity))) + + r = await self.client(functions.messages.GetPeerDialogsRequest(peers)) + items = r.dialogs # TODO Maybe there should be a helper method for this? entities = {utils.get_peer_id(x): x for x in itertools.chain(r.users, r.chats)} - self.buffer.extend(custom.Draft._from_update(self.client, u, entities) - for u in r.updates) + self.buffer.extend( + custom.Draft(self.client, entities[utils.get_peer_id(d.peer)], d.draft) + for d in items + ) async def _load_next_chunk(self): return [] @@ -236,12 +248,20 @@ class DialogMethods: """ return await self.iter_dialogs(*args, **kwargs).collect() - def iter_drafts(self: 'TelegramClient') -> _DraftsIter: + def iter_drafts( + self: 'TelegramClient', + entity: 'hints.EntitiesLike' = None + ) -> _DraftsIter: """ - Iterator over all open draft messages. + Iterator over draft messages. The order is unspecified. + Arguments + entity (`hints.EntitiesLike`, optional): + The entity or entities for which to fetch the draft messages. + If left unspecified, all draft messages will be returned. + Yields Instances of `Draft `. @@ -251,11 +271,21 @@ class DialogMethods: # Clear all drafts for draft in client.get_drafts(): draft.delete() - """ - # TODO Passing a limit here makes no sense - return _DraftsIter(self, None) - async def get_drafts(self: 'TelegramClient') -> 'hints.TotalList': + # Getting the drafts with 'bot1' and 'bot2' + for draft in client.iter_drafts(['bot1', 'bot2']): + print(draft.text) + """ + if entity and not utils.is_list_like(entity): + entity = (entity,) + + # TODO Passing a limit here makes no sense + return _DraftsIter(self, None, entities=entity) + + async def get_drafts( + self: 'TelegramClient', + entity: 'hints.EntitiesLike' = None + ) -> 'hints.TotalList': """ Same as `iter_drafts()`, but returns a list instead. @@ -265,8 +295,16 @@ class DialogMethods: # Get drafts, print the text of the first drafts = client.get_drafts() print(drafts[0].text) + + # Get the draft in your chat + draft = client.get_drafts('me') + print(drafts.text) """ - return await self.iter_drafts().collect() + items = await self.iter_drafts(entity).collect() + if not entity or utils.is_list_like(entity): + return items + else: + return items[0] async def edit_folder( self: 'TelegramClient', diff --git a/telethon/tl/custom/dialog.py b/telethon/tl/custom/dialog.py index c434e12e..b71b4c3f 100644 --- a/telethon/tl/custom/dialog.py +++ b/telethon/tl/custom/dialog.py @@ -87,7 +87,7 @@ class Dialog: self.unread_count = dialog.unread_count self.unread_mentions_count = dialog.unread_mentions_count - self.draft = Draft._from_dialog(client, self) + self.draft = Draft(client, self.entity, self.dialog.draft) self.is_user = isinstance(self.entity, types.User) self.is_group = ( diff --git a/telethon/tl/custom/draft.py b/telethon/tl/custom/draft.py index 84289c58..a44986b7 100644 --- a/telethon/tl/custom/draft.py +++ b/telethon/tl/custom/draft.py @@ -2,10 +2,10 @@ import datetime from .. import TLObject from ..functions.messages import SaveDraftRequest -from ..types import UpdateDraftMessage, DraftMessage +from ..types import DraftMessage from ...errors import RPCError from ...extensions import markdown -from ...utils import get_peer_id, get_input_peer +from ...utils import get_input_peer, get_peer class Draft: @@ -24,9 +24,9 @@ class Draft: reply_to_msg_id (`int`): The message ID that the draft will reply to. """ - def __init__(self, client, peer, draft, entity): + def __init__(self, client, entity, draft): self._client = client - self._peer = peer + self._peer = get_peer(entity) self._entity = entity self._input_entity = get_input_peer(entity) if entity else None @@ -39,17 +39,6 @@ class Draft: self.link_preview = not draft.no_webpage self.reply_to_msg_id = draft.reply_to_msg_id - @classmethod - def _from_dialog(cls, client, dialog): - return cls(client=client, peer=dialog.dialog.peer, - draft=dialog.dialog.draft, entity=dialog.entity) - - @classmethod - def _from_update(cls, client, update, entities): - assert isinstance(update, UpdateDraftMessage) - return cls(client=client, peer=update.peer, draft=update.draft, - entity=entities.get(get_peer_id(update.peer))) - @property def entity(self): """