Factor out clearing items from pending in conversations

This should prevent bugs and ease reasoning, since
now everything is removed from a single place.
This commit is contained in:
Lonami Exo 2019-05-27 14:23:42 +02:00
parent 0b41454b01
commit 1354bf68a8

View File

@ -138,7 +138,7 @@ class Conversation(ChatGetter):
lambda x, y: x.reply_to_msg_id == y lambda x, y: x.reply_to_msg_id == y
) )
async def _get_message( def _get_message(
self, target_message, indices, pending, timeout, condition): self, target_message, indices, pending, timeout, condition):
""" """
Gets the next desired message under the desired condition. Gets the next desired message under the desired condition.
@ -192,11 +192,12 @@ class Conversation(ChatGetter):
return future return future
# Otherwise the next incoming response will be the one to use # Otherwise the next incoming response will be the one to use
#
# Note how we fill "pending" before giving control back to the
# event loop through "await". We want to register it as soon as
# possible, since any other task switch may arrive with the result.
pending[target_id] = future pending[target_id] = future
try: return self._get_result(future, start_time, timeout, pending, target_id)
return await self._get_result(future, start_time, timeout)
finally:
pending.pop(target_id, None)
async def get_edit(self, message=None, *, timeout=None): async def get_edit(self, message=None, *, timeout=None):
""" """
@ -222,12 +223,9 @@ class Conversation(ChatGetter):
return earliest_edit return earliest_edit
# Otherwise the next incoming response will be the one to use # Otherwise the next incoming response will be the one to use
future = asyncio.Future(loop=self._client.loop) future = self._client.loop.create_future()
self._pending_edits[target_id] = future self._pending_edits[target_id] = future
try: return await self._get_result(future, start_time, timeout, self._pending_edits, target_id)
return await self._get_result(future, start_time, timeout)
finally:
self._pending_edits.pop(target_id, None)
async def wait_read(self, message=None, *, timeout=None): async def wait_read(self, message=None, *, timeout=None):
""" """
@ -246,10 +244,7 @@ class Conversation(ChatGetter):
return return
self._pending_reads[target_id] = future self._pending_reads[target_id] = future
try: return await self._get_result(future, start_time, timeout, self._pending_reads, target_id)
return await self._get_result(future, start_time, timeout)
finally:
self._pending_reads.pop(target_id, None)
async def wait_event(self, event, *, timeout=None): async def wait_event(self, event, *, timeout=None):
""" """
@ -284,20 +279,9 @@ class Conversation(ChatGetter):
counter = Conversation._custom_counter counter = Conversation._custom_counter
Conversation._custom_counter += 1 Conversation._custom_counter += 1
future = asyncio.Future(loop=self._client.loop) future = self._client.loop.create_future()
# We need the `async def` here because we want to block on the future
# from `_get_result` by using `await` on it. If we returned the future
# immediately we would `del` from `_custom` too early.
async def result():
try:
return await self._get_result(future, start_time, timeout)
finally:
del self._custom[counter]
self._custom[counter] = (event, future) self._custom[counter] = (event, future)
return await result() return await self._get_result(future, start_time, timeout, self._custom, counter)
async def _check_custom(self, built): async def _check_custom(self, built):
for i, (ev, fut) in self._custom.items(): for i, (ev, fut) in self._custom.items():
@ -317,32 +301,23 @@ class Conversation(ChatGetter):
self._incoming.append(response) self._incoming.append(response)
found = [] # Note: we don't remove from pending here, that's done on get result
for msg_id in self._pending_responses: for msg_id, future in self._pending_responses.items():
found.append(msg_id)
self._response_indices[msg_id] = len(self._incoming) self._response_indices[msg_id] = len(self._incoming)
future.set_result(response)
for msg_id in found: for msg_id, future in self._pending_replies.items():
self._pending_responses.pop(msg_id).set_result(response)
found.clear()
for msg_id in self._pending_replies:
if msg_id == response.reply_to_msg_id: if msg_id == response.reply_to_msg_id:
found.append(msg_id)
self._reply_indices[msg_id] = len(self._incoming) self._reply_indices[msg_id] = len(self._incoming)
future.set_result(response)
for msg_id in found:
self._pending_replies.pop(msg_id).set_result(response)
def _on_edit(self, message): def _on_edit(self, message):
message = message.message message = message.message
if message.chat_id != self.chat_id or message.out: if message.chat_id != self.chat_id or message.out:
return return
found = [] for msg_id, future in self._pending_edits.items():
for msg_id, pending in self._pending_edits.items():
if msg_id < message.id: if msg_id < message.id:
found.append(msg_id)
edit_ts = message.edit_date.timestamp() edit_ts = message.edit_date.timestamp()
# We compare <= because edit_ts resolution is always to # We compare <= because edit_ts resolution is always to
@ -353,8 +328,7 @@ class Conversation(ChatGetter):
else: else:
self._edit_dates[msg_id] = message.edit_date.timestamp() self._edit_dates[msg_id] = message.edit_date.timestamp()
for msg_id in found: future.set_result(message)
self._pending_edits.pop(msg_id).set_result(message)
def _on_read(self, event): def _on_read(self, event):
if event.chat_id != self.chat_id or event.inbox: if event.chat_id != self.chat_id or event.inbox:
@ -379,7 +353,7 @@ class Conversation(ChatGetter):
else: else:
raise ValueError('No message was sent previously') raise ValueError('No message was sent previously')
def _get_result(self, future, start_time, timeout): async def _get_result(self, future, start_time, timeout, pending, target_id):
due = self._total_due due = self._total_due
if timeout is None: if timeout is None:
timeout = self._timeout timeout = self._timeout
@ -387,11 +361,14 @@ class Conversation(ChatGetter):
if timeout is not None: if timeout is not None:
due = min(due, start_time + timeout) due = min(due, start_time + timeout)
return asyncio.wait_for( try:
future, return await asyncio.wait_for(
timeout=None if due == float('inf') else due - time.time(), future,
loop=self._client.loop timeout=None if due == float('inf') else due - time.time(),
) loop=self._client.loop
)
finally:
del pending[target_id]
def _cancel_all(self, exception=None): def _cancel_all(self, exception=None):
for pending in itertools.chain( for pending in itertools.chain(