Factor out clearing items from pending in conversations

This should prevent bugs and ease reasoning, since
now everything is removed from a single place.
This commit is contained in:
Lonami Exo 2019-05-27 14:23:42 +02:00
parent 0b41454b01
commit 1354bf68a8

View File

@ -138,7 +138,7 @@ class Conversation(ChatGetter):
lambda x, y: x.reply_to_msg_id == y
)
async def _get_message(
def _get_message(
self, target_message, indices, pending, timeout, condition):
"""
Gets the next desired message under the desired condition.
@ -192,11 +192,12 @@ class Conversation(ChatGetter):
return future
# Otherwise the next incoming response will be the one to use
#
# Note how we fill "pending" before giving control back to the
# event loop through "await". We want to register it as soon as
# possible, since any other task switch may arrive with the result.
pending[target_id] = future
try:
return await self._get_result(future, start_time, timeout)
finally:
pending.pop(target_id, None)
return self._get_result(future, start_time, timeout, pending, target_id)
async def get_edit(self, message=None, *, timeout=None):
"""
@ -222,12 +223,9 @@ class Conversation(ChatGetter):
return earliest_edit
# Otherwise the next incoming response will be the one to use
future = asyncio.Future(loop=self._client.loop)
future = self._client.loop.create_future()
self._pending_edits[target_id] = future
try:
return await self._get_result(future, start_time, timeout)
finally:
self._pending_edits.pop(target_id, None)
return await self._get_result(future, start_time, timeout, self._pending_edits, target_id)
async def wait_read(self, message=None, *, timeout=None):
"""
@ -246,10 +244,7 @@ class Conversation(ChatGetter):
return
self._pending_reads[target_id] = future
try:
return await self._get_result(future, start_time, timeout)
finally:
self._pending_reads.pop(target_id, None)
return await self._get_result(future, start_time, timeout, self._pending_reads, target_id)
async def wait_event(self, event, *, timeout=None):
"""
@ -284,20 +279,9 @@ class Conversation(ChatGetter):
counter = Conversation._custom_counter
Conversation._custom_counter += 1
future = asyncio.Future(loop=self._client.loop)
# We need the `async def` here because we want to block on the future
# from `_get_result` by using `await` on it. If we returned the future
# immediately we would `del` from `_custom` too early.
async def result():
try:
return await self._get_result(future, start_time, timeout)
finally:
del self._custom[counter]
future = self._client.loop.create_future()
self._custom[counter] = (event, future)
return await result()
return await self._get_result(future, start_time, timeout, self._custom, counter)
async def _check_custom(self, built):
for i, (ev, fut) in self._custom.items():
@ -317,32 +301,23 @@ class Conversation(ChatGetter):
self._incoming.append(response)
found = []
for msg_id in self._pending_responses:
found.append(msg_id)
# Note: we don't remove from pending here, that's done on get result
for msg_id, future in self._pending_responses.items():
self._response_indices[msg_id] = len(self._incoming)
future.set_result(response)
for msg_id in found:
self._pending_responses.pop(msg_id).set_result(response)
found.clear()
for msg_id in self._pending_replies:
for msg_id, future in self._pending_replies.items():
if msg_id == response.reply_to_msg_id:
found.append(msg_id)
self._reply_indices[msg_id] = len(self._incoming)
for msg_id in found:
self._pending_replies.pop(msg_id).set_result(response)
future.set_result(response)
def _on_edit(self, message):
message = message.message
if message.chat_id != self.chat_id or message.out:
return
found = []
for msg_id, pending in self._pending_edits.items():
for msg_id, future in self._pending_edits.items():
if msg_id < message.id:
found.append(msg_id)
edit_ts = message.edit_date.timestamp()
# We compare <= because edit_ts resolution is always to
@ -353,8 +328,7 @@ class Conversation(ChatGetter):
else:
self._edit_dates[msg_id] = message.edit_date.timestamp()
for msg_id in found:
self._pending_edits.pop(msg_id).set_result(message)
future.set_result(message)
def _on_read(self, event):
if event.chat_id != self.chat_id or event.inbox:
@ -379,7 +353,7 @@ class Conversation(ChatGetter):
else:
raise ValueError('No message was sent previously')
def _get_result(self, future, start_time, timeout):
async def _get_result(self, future, start_time, timeout, pending, target_id):
due = self._total_due
if timeout is None:
timeout = self._timeout
@ -387,11 +361,14 @@ class Conversation(ChatGetter):
if timeout is not None:
due = min(due, start_time + timeout)
return asyncio.wait_for(
future,
timeout=None if due == float('inf') else due - time.time(),
loop=self._client.loop
)
try:
return await asyncio.wait_for(
future,
timeout=None if due == float('inf') else due - time.time(),
loop=self._client.loop
)
finally:
del pending[target_id]
def _cancel_all(self, exception=None):
for pending in itertools.chain(