Create a basic custom.Conversation

This commit is contained in:
Lonami Exo 2018-08-03 17:51:56 +02:00
parent 60c61181d9
commit 785ef7676f
5 changed files with 403 additions and 0 deletions

View File

@ -157,4 +157,80 @@ class DialogMethods(UserMethods):
result.append(x) result.append(x)
return result return result
def conversation(
self, entity,
*, timeout=None, total_timeout=60, max_messages=100,
replies_are_responses=True):
"""
Returns an iterator over the dialogs, yielding 'limit' at most.
Dialogs are the open "chats" or conversations with other people,
groups you have joined, or channels you are subscribed to.
Args:
entity (`entity`):
The entity with which a new conversation should be opened.
timeout (`int` | `float`, optional):
The default timeout *per action* to be used. You
can override this on each action. By default there
is no per-action time limit but there is still a
`total_timeout` for the entire conversation.
total_timeout (`int` | `float`, optional):
The total timeout to use for the whole conversation.
After these many seconds pass, subsequent actions
will result in ``asyncio.TimeoutError``.
max_messages (`int`, optional):
The maximum amount of messages this conversation will
remember. After these many messages arrive in the
specified chat, subsequent actions will result in
``ValueError``.
replies_are_responses (`bool`, optional):
Whether replies should be treated as responses or not.
If the setting is enabled, calls to `conv.get_response
<telethon.tl.custom.conversation.Conversation.get_response>`
and a subsequent call to `conv.get_reply
<telethon.tl.custom.conversation.Conversation.get_reply>`
will return different messages, otherwise they may return
the same message.
Consider the following scenario with one outgoing message,
1, and two incoming messages, the second one replying::
Hello! <1
2> (reply to 1) Hi!
3> (reply to 1) How are you?
And the following code:
.. code-block:: python
async with client.conversation(chat) as conv:
msg1 = await conv.send_message('Hello!')
msg2 = await conv.get_response()
msg3 = await conv.get_reply()
With the setting enabled, ``msg2`` will be ``'Hi!'`` and
``msg3`` be ``'How are you?'`` since replies are also
responses, and a response was already returned.
With the setting disabled, both ``msg2`` and ``msg3`` will
be ``'Hi!'`` since one is a response and also a reply.
Returns:
A `Conversation <telethon.tl.custom.conversation.Conversation>`.
"""
return custom.Conversation(
self,
entity,
timeout=timeout,
total_timeout=total_timeout,
max_messages=max_messages,
replies_are_responses=replies_are_responses
)
# endregion # endregion

View File

@ -272,6 +272,7 @@ class TelegramBaseClient(abc.ABC):
self._event_builders = [] self._event_builders = []
self._events_pending_resolve = [] self._events_pending_resolve = []
self._event_resolve_lock = asyncio.Lock() self._event_resolve_lock = asyncio.Lock()
self._conversations = {}
# Keep track of how many event builders there are for # Keep track of how many event builders there are for
# each type {type: count}. If there's at least one then # each type {type: count}. If there's at least one then

View File

@ -269,6 +269,20 @@ class UpdateMethods(UserMethods):
built = {builder: builder.build(update) built = {builder: builder.build(update)
for builder in self._event_builders_count} for builder in self._event_builders_count}
if self._conversations:
for ev_type in (events.NewMessage, events.MessageEdited,
events.MessageRead):
if ev_type not in built:
built[ev_type] = ev_type.build(update)
for conv in self._conversations.values():
if built[events.NewMessage]:
conv._on_new_message(built[events.NewMessage])
if built[events.MessageEdited]:
conv._on_edit(built[events.MessageEdited])
if built[events.MessageRead]:
conv._on_read(built[events.MessageRead])
for builder, callback in self._event_builders: for builder, callback in self._event_builders:
event = built[type(builder)] event = built[type(builder)]
if not event or not builder.filter(event): if not event or not builder.filter(event):

View File

@ -8,3 +8,4 @@ from .button import Button
from .inline import InlineBuilder from .inline import InlineBuilder
from .inlineresult import InlineResult from .inlineresult import InlineResult
from .inlineresults import InlineResults from .inlineresults import InlineResults
from .conversation import Conversation

View File

@ -0,0 +1,311 @@
import asyncio
import itertools
import time
from .chatgetter import ChatGetter
from ... import utils
class Conversation(ChatGetter):
"""
Represents a conversation inside an specific chat.
A conversation keeps track of new messages since it was
created until its exit and easily lets you query the
current state.
If you need a conversation across two or more chats,
you should use two conversations and synchronize them
as you better see fit.
"""
_id_counter = 0
def __init__(self, client, input_chat,
*, timeout, total_timeout, max_messages,
replies_are_responses):
self._id = Conversation._id_counter
Conversation._id_counter += 1
self._client = client
self._chat = None
self._input_chat = input_chat
self._chat_peer = None
self._broadcast = None
self._timeout = timeout
if total_timeout:
self._total_due = time.time() + total_timeout
else:
self._total_due = float('inf')
self._outgoing = set()
self._last_outgoing = 0
self._incoming = []
self._last_incoming = 0
self._max_incoming = max_messages
self._last_read = None
self._pending_responses = {}
self._pending_replies = {}
self._pending_edits = {}
self._pending_reads = {}
# The user is able to expect two responses for the same message.
# {desired message ID: next incoming index}
self._response_indices = {}
if replies_are_responses:
self._reply_indices = self._response_indices
else:
self._reply_indices = {}
self._edit_indices = {}
async def send_message(self, *args, **kwargs):
"""
Sends a message in the context of this conversation. Shorthand
for `telethon.client.messages.MessageMethods.send_message` with
``entity`` already set.
"""
message = await self._client.send_message(
self._input_chat, *args, **kwargs)
self._outgoing.add(message.id)
self._last_outgoing = message.id
return message
async def get_response(self, message=None, *, timeout=None):
"""
Awaits for a response to arrive.
Args:
message (:tl:`Message` | `int`, optional):
The message (or the message ID) for which a response
is expected. By default this is the last sent message.
timeout (`int` | `float`, optional):
If present, this `timeout` will override the
per-action timeout defined for the conversation.
"""
return await self._get_message(
message, self._response_indices, self._pending_responses, timeout,
lambda x, y: True
)
async def get_reply(self, message=None, *, timeout=None):
"""
Awaits for a reply (that is, a message being a reply) to arrive.
The arguments are the same as those for `get_response`.
"""
return await self._get_message(
message, self._reply_indices, self._pending_replies, timeout,
lambda x, y: x.reply_to_msg_id == y
)
async def get_edit(self, message=None, *, timeout=None):
"""
Awaits for an edit after the last message to arrive.
The arguments are the same as those for `get_response`.
"""
return await self._get_message(
message, self._reply_indices, self._pending_edits, timeout,
lambda x, y: x.edit_date
)
async def _get_message(
self, target_message, indices, pending, timeout, condition):
"""
Gets the next desired message under the desired condition.
Args:
target_message (`object`):
The target message for which we want to find another
response that applies based on `condition`.
indices (`dict`):
This dictionary remembers the last ID chosen for the
input `target_message`.
pending (`dict`):
This dictionary remembers {msg_id: Future} to be set
once `condition` is met.
timeout (`int`):
The timeout override to use for this operation.
condition (`callable`):
The condition callable that checks if an incoming
message is a valid response.
"""
now = time.time()
future = asyncio.Future()
target_id = self._get_message_id(target_message)
# If there is no last-chosen ID, make sure to pick one *after*
# the input message, since we don't want responses back in time
if target_id not in indices:
for i, incoming in self._incoming:
if incoming.id > target_id:
indices[target_id] = i
break
else:
indices[target_id] = 0
# If there are enough responses saved return the next one
last_idx = indices[target_id]
if last_idx < len(self._incoming):
incoming = self._incoming[last_idx]
if condition(incoming, target_id):
indices[target_id] += 1
return incoming
# Otherwise the next incoming response will be the one to use
pending[target_id] = future
done, pending = await asyncio.wait(
[future, self._sleep(now, timeout)],
return_when=asyncio.FIRST_COMPLETED
)
if future in pending:
for future in pending:
future.cancel()
raise asyncio.TimeoutError()
else:
return future.result()
async def wait_read(self, message=None, *, timeout=None):
"""
Awaits for the sent message to be read. Note that receiving
a response doesn't imply the message was read, and this action
will also trigger even without a response.
"""
now = time.time()
future = asyncio.Future()
target_id = self._get_message_id(message)
if self._last_read is None:
self._last_read = target_id - 1
if self._last_read >= target_id:
return
self._pending_reads[target_id] = future
done, pending = await asyncio.wait(
[future, self._sleep(now, timeout)],
return_when=asyncio.FIRST_COMPLETED
)
if future in pending:
for future in pending:
future.cancel()
raise asyncio.TimeoutError()
else:
return future.result()
def _on_new_message(self, response):
if response.chat_id != self.chat_id or response.out:
return
if len(self._incoming) == self._max_incoming:
too_many = ValueError('Too many incoming messages')
for pending in itertools.chain(
self._pending_responses.values(),
self._pending_replies.values(),
self._pending_edits):
pending.set_exception(too_many)
return
self._incoming.append(response)
for msg_id, pending in self._pending_responses.items():
self._response_indices[msg_id] = len(self._incoming)
pending.set_result(response)
self._pending_responses.clear()
remove_replies = []
for msg_id, pending in self._pending_replies.items():
if msg_id == response.reply_to_msg_id:
remove_replies.append(msg_id)
self._reply_indices[msg_id] = len(self._incoming)
pending.set_result(response)
for to_remove in remove_replies:
del self._reply_indices[to_remove]
# TODO Edits are different since they work by date not indices
# That is, we need to scan all incoming messages and detect if
# the last used edit date is different from the one we knew.
def _on_edit(self, message):
if message.chat_id != self.chat_id or message.out:
return
for i, msg in enumerate(self._incoming):
if msg.id == message.id:
self._incoming[i] = msg
break
remove_edits = []
for msg_id, pending in self._pending_replies.items():
if msg_id == message.id:
remove_edits.append(msg_id)
self._edit_indices[msg_id] = len(self._incoming)
pending.set_result(message)
for to_remove in remove_edits:
del self._edit_indices[to_remove]
def _on_read(self, event):
if event.chat_id != self.chat_id or event.inbox:
return
self._last_read = event.max_id
remove_reads = []
for msg_id, pending in self._pending_reads.items():
if msg_id >= self._last_read:
remove_reads.append(msg_id)
pending.set_result(True)
for to_remove in remove_reads:
del self._pending_reads[to_remove]
def _get_message_id(self, message):
if message:
return message if isinstance(message, int) else message.id
elif self._last_outgoing:
return self._last_outgoing
else:
raise ValueError('No message was sent previously')
async def _sleep(self, now, timeout):
due = self._total_due
if timeout is None:
timeout = self._timeout
if timeout is not None:
due = min(due, now + timeout)
try:
if due == float('inf'):
while True:
await asyncio.sleep(60)
elif due > now:
await asyncio.sleep(due - now)
except asyncio.CancelledError:
pass
async def __aenter__(self):
self._client._conversations[self._id] = self
self._input_chat = \
await self._client.get_input_entity(self._input_chat)
self._chat_peer = utils.get_peer(self._input_chat)
self._outgoing.clear()
self._last_outgoing = 0
self._incoming.clear()
self._last_incoming = 0
self._pending_responses.clear()
self._response_indices.clear()
return self
async def __aexit__(self, *args):
del self._client._conversations[self._id]