From 785ef7676f162d2aa89d0380afa4ad88e4524300 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Fri, 3 Aug 2018 17:51:56 +0200 Subject: [PATCH] Create a basic custom.Conversation --- telethon/client/dialogs.py | 76 +++++++ telethon/client/telegrambaseclient.py | 1 + telethon/client/updates.py | 14 ++ telethon/tl/custom/__init__.py | 1 + telethon/tl/custom/conversation.py | 311 ++++++++++++++++++++++++++ 5 files changed, 403 insertions(+) create mode 100644 telethon/tl/custom/conversation.py diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index fe92fa82..5d953fd5 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -157,4 +157,80 @@ class DialogMethods(UserMethods): result.append(x) return result + def conversation( + self, entity, + *, timeout=None, total_timeout=60, max_messages=100, + replies_are_responses=True): + """ + Returns an iterator over the dialogs, yielding 'limit' at most. + Dialogs are the open "chats" or conversations with other people, + groups you have joined, or channels you are subscribed to. + + Args: + entity (`entity`): + The entity with which a new conversation should be opened. + + timeout (`int` | `float`, optional): + The default timeout *per action* to be used. You + can override this on each action. By default there + is no per-action time limit but there is still a + `total_timeout` for the entire conversation. + + total_timeout (`int` | `float`, optional): + The total timeout to use for the whole conversation. + After these many seconds pass, subsequent actions + will result in ``asyncio.TimeoutError``. + + max_messages (`int`, optional): + The maximum amount of messages this conversation will + remember. After these many messages arrive in the + specified chat, subsequent actions will result in + ``ValueError``. + + replies_are_responses (`bool`, optional): + Whether replies should be treated as responses or not. + + If the setting is enabled, calls to `conv.get_response + ` + and a subsequent call to `conv.get_reply + ` + will return different messages, otherwise they may return + the same message. + + Consider the following scenario with one outgoing message, + 1, and two incoming messages, the second one replying:: + + Hello! <1 + 2> (reply to 1) Hi! + 3> (reply to 1) How are you? + + And the following code: + + .. code-block:: python + + async with client.conversation(chat) as conv: + msg1 = await conv.send_message('Hello!') + msg2 = await conv.get_response() + msg3 = await conv.get_reply() + + With the setting enabled, ``msg2`` will be ``'Hi!'`` and + ``msg3`` be ``'How are you?'`` since replies are also + responses, and a response was already returned. + + With the setting disabled, both ``msg2`` and ``msg3`` will + be ``'Hi!'`` since one is a response and also a reply. + + Returns: + A `Conversation `. + """ + return custom.Conversation( + self, + entity, + timeout=timeout, + total_timeout=total_timeout, + max_messages=max_messages, + replies_are_responses=replies_are_responses + + ) + # endregion diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index f261539d..00a499d8 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -272,6 +272,7 @@ class TelegramBaseClient(abc.ABC): self._event_builders = [] self._events_pending_resolve = [] self._event_resolve_lock = asyncio.Lock() + self._conversations = {} # Keep track of how many event builders there are for # each type {type: count}. If there's at least one then diff --git a/telethon/client/updates.py b/telethon/client/updates.py index f568cf27..408caf12 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -269,6 +269,20 @@ class UpdateMethods(UserMethods): built = {builder: builder.build(update) for builder in self._event_builders_count} + if self._conversations: + for ev_type in (events.NewMessage, events.MessageEdited, + events.MessageRead): + if ev_type not in built: + built[ev_type] = ev_type.build(update) + + for conv in self._conversations.values(): + if built[events.NewMessage]: + conv._on_new_message(built[events.NewMessage]) + if built[events.MessageEdited]: + conv._on_edit(built[events.MessageEdited]) + if built[events.MessageRead]: + conv._on_read(built[events.MessageRead]) + for builder, callback in self._event_builders: event = built[type(builder)] if not event or not builder.filter(event): diff --git a/telethon/tl/custom/__init__.py b/telethon/tl/custom/__init__.py index 7511c176..8db2d53f 100644 --- a/telethon/tl/custom/__init__.py +++ b/telethon/tl/custom/__init__.py @@ -8,3 +8,4 @@ from .button import Button from .inline import InlineBuilder from .inlineresult import InlineResult from .inlineresults import InlineResults +from .conversation import Conversation diff --git a/telethon/tl/custom/conversation.py b/telethon/tl/custom/conversation.py new file mode 100644 index 00000000..c5262bed --- /dev/null +++ b/telethon/tl/custom/conversation.py @@ -0,0 +1,311 @@ +import asyncio +import itertools +import time + +from .chatgetter import ChatGetter +from ... import utils + + +class Conversation(ChatGetter): + """ + Represents a conversation inside an specific chat. + + A conversation keeps track of new messages since it was + created until its exit and easily lets you query the + current state. + + If you need a conversation across two or more chats, + you should use two conversations and synchronize them + as you better see fit. + """ + _id_counter = 0 + + def __init__(self, client, input_chat, + *, timeout, total_timeout, max_messages, + replies_are_responses): + self._id = Conversation._id_counter + Conversation._id_counter += 1 + + self._client = client + self._chat = None + self._input_chat = input_chat + self._chat_peer = None + self._broadcast = None + + self._timeout = timeout + if total_timeout: + self._total_due = time.time() + total_timeout + else: + self._total_due = float('inf') + + self._outgoing = set() + self._last_outgoing = 0 + self._incoming = [] + self._last_incoming = 0 + self._max_incoming = max_messages + self._last_read = None + + self._pending_responses = {} + self._pending_replies = {} + self._pending_edits = {} + self._pending_reads = {} + + # The user is able to expect two responses for the same message. + # {desired message ID: next incoming index} + self._response_indices = {} + if replies_are_responses: + self._reply_indices = self._response_indices + else: + self._reply_indices = {} + + self._edit_indices = {} + + async def send_message(self, *args, **kwargs): + """ + Sends a message in the context of this conversation. Shorthand + for `telethon.client.messages.MessageMethods.send_message` with + ``entity`` already set. + """ + message = await self._client.send_message( + self._input_chat, *args, **kwargs) + + self._outgoing.add(message.id) + self._last_outgoing = message.id + return message + + async def get_response(self, message=None, *, timeout=None): + """ + Awaits for a response to arrive. + + Args: + message (:tl:`Message` | `int`, optional): + The message (or the message ID) for which a response + is expected. By default this is the last sent message. + + timeout (`int` | `float`, optional): + If present, this `timeout` will override the + per-action timeout defined for the conversation. + """ + return await self._get_message( + message, self._response_indices, self._pending_responses, timeout, + lambda x, y: True + ) + + async def get_reply(self, message=None, *, timeout=None): + """ + Awaits for a reply (that is, a message being a reply) to arrive. + The arguments are the same as those for `get_response`. + """ + return await self._get_message( + message, self._reply_indices, self._pending_replies, timeout, + lambda x, y: x.reply_to_msg_id == y + ) + + async def get_edit(self, message=None, *, timeout=None): + """ + Awaits for an edit after the last message to arrive. + The arguments are the same as those for `get_response`. + """ + return await self._get_message( + message, self._reply_indices, self._pending_edits, timeout, + lambda x, y: x.edit_date + ) + + async def _get_message( + self, target_message, indices, pending, timeout, condition): + """ + Gets the next desired message under the desired condition. + + Args: + target_message (`object`): + The target message for which we want to find another + response that applies based on `condition`. + + indices (`dict`): + This dictionary remembers the last ID chosen for the + input `target_message`. + + pending (`dict`): + This dictionary remembers {msg_id: Future} to be set + once `condition` is met. + + timeout (`int`): + The timeout override to use for this operation. + + condition (`callable`): + The condition callable that checks if an incoming + message is a valid response. + """ + now = time.time() + future = asyncio.Future() + target_id = self._get_message_id(target_message) + + # If there is no last-chosen ID, make sure to pick one *after* + # the input message, since we don't want responses back in time + if target_id not in indices: + for i, incoming in self._incoming: + if incoming.id > target_id: + indices[target_id] = i + break + else: + indices[target_id] = 0 + + # If there are enough responses saved return the next one + last_idx = indices[target_id] + if last_idx < len(self._incoming): + incoming = self._incoming[last_idx] + if condition(incoming, target_id): + indices[target_id] += 1 + return incoming + + # Otherwise the next incoming response will be the one to use + pending[target_id] = future + done, pending = await asyncio.wait( + [future, self._sleep(now, timeout)], + return_when=asyncio.FIRST_COMPLETED + ) + if future in pending: + for future in pending: + future.cancel() + + raise asyncio.TimeoutError() + else: + return future.result() + + async def wait_read(self, message=None, *, timeout=None): + """ + Awaits for the sent message to be read. Note that receiving + a response doesn't imply the message was read, and this action + will also trigger even without a response. + """ + now = time.time() + future = asyncio.Future() + target_id = self._get_message_id(message) + + if self._last_read is None: + self._last_read = target_id - 1 + + if self._last_read >= target_id: + return + + self._pending_reads[target_id] = future + done, pending = await asyncio.wait( + [future, self._sleep(now, timeout)], + return_when=asyncio.FIRST_COMPLETED + ) + if future in pending: + for future in pending: + future.cancel() + + raise asyncio.TimeoutError() + else: + return future.result() + + def _on_new_message(self, response): + if response.chat_id != self.chat_id or response.out: + return + + if len(self._incoming) == self._max_incoming: + too_many = ValueError('Too many incoming messages') + for pending in itertools.chain( + self._pending_responses.values(), + self._pending_replies.values(), + self._pending_edits): + pending.set_exception(too_many) + return + + self._incoming.append(response) + for msg_id, pending in self._pending_responses.items(): + self._response_indices[msg_id] = len(self._incoming) + pending.set_result(response) + + self._pending_responses.clear() + + remove_replies = [] + for msg_id, pending in self._pending_replies.items(): + if msg_id == response.reply_to_msg_id: + remove_replies.append(msg_id) + self._reply_indices[msg_id] = len(self._incoming) + pending.set_result(response) + + for to_remove in remove_replies: + del self._reply_indices[to_remove] + + # TODO Edits are different since they work by date not indices + # That is, we need to scan all incoming messages and detect if + # the last used edit date is different from the one we knew. + def _on_edit(self, message): + if message.chat_id != self.chat_id or message.out: + return + + for i, msg in enumerate(self._incoming): + if msg.id == message.id: + self._incoming[i] = msg + break + + remove_edits = [] + for msg_id, pending in self._pending_replies.items(): + if msg_id == message.id: + remove_edits.append(msg_id) + self._edit_indices[msg_id] = len(self._incoming) + pending.set_result(message) + + for to_remove in remove_edits: + del self._edit_indices[to_remove] + + def _on_read(self, event): + if event.chat_id != self.chat_id or event.inbox: + return + + self._last_read = event.max_id + + remove_reads = [] + for msg_id, pending in self._pending_reads.items(): + if msg_id >= self._last_read: + remove_reads.append(msg_id) + pending.set_result(True) + + for to_remove in remove_reads: + del self._pending_reads[to_remove] + + def _get_message_id(self, message): + if message: + return message if isinstance(message, int) else message.id + elif self._last_outgoing: + return self._last_outgoing + else: + raise ValueError('No message was sent previously') + + async def _sleep(self, now, timeout): + due = self._total_due + if timeout is None: + timeout = self._timeout + + if timeout is not None: + due = min(due, now + timeout) + + try: + if due == float('inf'): + while True: + await asyncio.sleep(60) + elif due > now: + await asyncio.sleep(due - now) + except asyncio.CancelledError: + pass + + async def __aenter__(self): + self._client._conversations[self._id] = self + self._input_chat = \ + await self._client.get_input_entity(self._input_chat) + + self._chat_peer = utils.get_peer(self._input_chat) + self._outgoing.clear() + self._last_outgoing = 0 + self._incoming.clear() + self._last_incoming = 0 + self._pending_responses.clear() + self._response_indices.clear() + return self + + async def __aexit__(self, *args): + del self._client._conversations[self._id]