Improve getting edits in a Conversation

This commit is contained in:
Lonami Exo 2018-08-04 15:35:51 +02:00
parent 5f73482d29
commit 396b1a4177

View File

@ -56,7 +56,7 @@ class Conversation(ChatGetter):
else:
self._reply_indices = {}
self._edit_indices = {}
self._edit_dates = {}
async def send_message(self, *args, **kwargs):
"""
@ -99,16 +99,6 @@ class Conversation(ChatGetter):
lambda x, y: x.reply_to_msg_id == y
)
async def get_edit(self, message=None, *, timeout=None):
"""
Awaits for an edit after the last message to arrive.
The arguments are the same as those for `get_response`.
"""
return await self._get_message(
message, self._reply_indices, self._pending_edits, timeout,
lambda x, y: x.edit_date
)
async def _get_message(
self, target_message, indices, pending, timeout, condition):
"""
@ -135,7 +125,6 @@ class Conversation(ChatGetter):
message is a valid response.
"""
now = time.time()
future = asyncio.Future()
target_id = self._get_message_id(target_message)
# If there is no last-chosen ID, make sure to pick one *after*
@ -157,6 +146,7 @@ class Conversation(ChatGetter):
return incoming
# Otherwise the next incoming response will be the one to use
future = asyncio.Future()
pending[target_id] = future
done, pending = await asyncio.wait(
[future, self._sleep(now, timeout)],
@ -170,6 +160,42 @@ class Conversation(ChatGetter):
else:
return future.result()
async def get_edit(self, message=None, *, timeout=None):
"""
Awaits for an edit after the last message to arrive.
The arguments are the same as those for `get_response`.
"""
now = time.time()
target_id = self._get_message_id(message)
target_date = self._edit_dates.get(target_id, 0)
earliest_edit = min(
(x for x in self._incoming
if x.id > target_id and x.date.timestamp() > target_date),
key=lambda x: x.date,
default=None
)
if earliest_edit and earliest_edit.date.timestamp() > target_date:
self._edit_dates[target_id] = earliest_edit.timestamp()
return earliest_edit
# Otherwise the next incoming response will be the one to use
future = asyncio.Future()
self._pending_edits[target_id] = future
done, pending = await asyncio.wait(
[future, self._sleep(now, timeout)],
return_when=asyncio.FIRST_COMPLETED
)
if future in pending:
for future in pending:
future.cancel()
raise asyncio.TimeoutError()
else:
return future.result()
async def wait_read(self, message=None, *, timeout=None):
"""
Awaits for the sent message to be read. Note that receiving
@ -213,21 +239,23 @@ class Conversation(ChatGetter):
return
self._incoming.append(response)
for msg_id, pending in self._pending_responses.items():
found = []
for msg_id in self._pending_responses:
found.append(msg_id)
self._response_indices[msg_id] = len(self._incoming)
pending.set_result(response)
self._pending_responses.clear()
for msg_id in found:
self._pending_responses.pop(msg_id).set_result(response)
remove_replies = []
for msg_id, pending in self._pending_replies.items():
found.clear()
for msg_id in self._pending_replies:
if msg_id == response.reply_to_msg_id:
remove_replies.append(msg_id)
found.append(msg_id)
self._reply_indices[msg_id] = len(self._incoming)
pending.set_result(response)
for to_remove in remove_replies:
del self._reply_indices[to_remove]
for msg_id in found:
self._pending_replies.pop(msg_id).set_result(response)
# TODO Edits are different since they work by date not indices
# That is, we need to scan all incoming messages and detect if
@ -236,20 +264,16 @@ class Conversation(ChatGetter):
if message.chat_id != self.chat_id or message.out:
return
for i, msg in enumerate(self._incoming):
if msg.id == message.id:
self._incoming[i] = msg
break
remove_edits = []
for msg_id, pending in self._pending_replies.items():
found = []
for msg_id, pending in self._pending_edits.items():
if msg_id == message.id:
remove_edits.append(msg_id)
self._edit_indices[msg_id] = len(self._incoming)
pending.set_result(message)
found.append(msg_id)
self._edit_dates[msg_id] = message.date.timestamp()
for to_remove in remove_edits:
del self._edit_indices[to_remove]
for msg_id in found:
self._pending_edits.pop(msg_id).set_result(message)
# TODO Support custom events in a comfortable way
def _on_read(self, event):
if event.chat_id != self.chat_id or event.inbox: