From 7dece209a046559f85d29a48b3ecc7f8a9543d52 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sun, 21 Oct 2018 16:20:05 +0200 Subject: [PATCH] Cancel tasks on reconnect instead of awaiting them This prevents us from locking forever on any task that doesn't rely on cancellation tokens, in this case, Connection.recv()'s _recv_queue.get() would never complete after the server closed the connection. Additionally, working with cancellation tokens in asyncio is somewhat annoying to do. Last but not least removing the Connection._disconnected future avoids the need to use its state (if an exception was set it should be retrieved) to prevent asyncio from complaining, which it was before. --- telethon/extensions/messagepacker.py | 16 ++-------------- telethon/network/connection/connection.py | 23 +++++++---------------- telethon/network/mtprotosender.py | 11 +++++------ 3 files changed, 14 insertions(+), 36 deletions(-) diff --git a/telethon/extensions/messagepacker.py b/telethon/extensions/messagepacker.py index 3944afe6..fe7b1e6b 100644 --- a/telethon/extensions/messagepacker.py +++ b/telethon/extensions/messagepacker.py @@ -38,7 +38,7 @@ class MessagePacker: self._deque.extend(states) self._ready.set() - async def get(self, cancellation): + async def get(self): """ Returns (batch, data) if one or more items could be retrieved. @@ -47,19 +47,7 @@ class MessagePacker: """ if not self._deque: self._ready.clear() - ready = self._loop.create_task(self._ready.wait()) - try: - done, pending = await asyncio.wait( - [ready, cancellation], - return_when=asyncio.FIRST_COMPLETED, - loop=self._loop - ) - except asyncio.CancelledError: - done = [cancellation] - - if cancellation in done: - ready.cancel() - return None, None + await self._ready.wait() buffer = io.BytesIO() batch = [] diff --git a/telethon/network/connection/connection.py b/telethon/network/connection/connection.py index e67e3267..625635e3 100644 --- a/telethon/network/connection/connection.py +++ b/telethon/network/connection/connection.py @@ -28,8 +28,7 @@ class Connection(abc.ABC): self._proxy = proxy self._reader = None self._writer = None - self._disconnected = self._loop.create_future() - self._disconnected.set_result(None) + self._connected = False self._send_task = None self._recv_task = None self._send_queue = asyncio.Queue(1) @@ -77,7 +76,7 @@ class Connection(abc.ABC): self._reader, self._writer = \ await asyncio.open_connection(sock=s, loop=self._loop) - self._disconnected = self._loop.create_future() + self._connected = True self._send_task = self._loop.create_task(self._send_loop()) self._recv_task = self._loop.create_task(self._recv_loop()) @@ -89,11 +88,7 @@ class Connection(abc.ABC): self._disconnect(error=None) def _disconnect(self, error): - if not self._disconnected.done(): - if error: - self._disconnected.set_exception(error) - else: - self._disconnected.set_result(None) + self._connected = False while not self._send_queue.empty(): self._send_queue.get_nowait() @@ -110,10 +105,6 @@ class Connection(abc.ABC): if self._writer: self._writer.close() - @property - def disconnected(self): - return asyncio.shield(self._disconnected, loop=self._loop) - def clone(self): """ Creates a clone of the connection. @@ -126,7 +117,7 @@ class Connection(abc.ABC): This method returns a coroutine. """ - if self._disconnected.done(): + if not self._connected: raise ConnectionError('Not connected') return self._send_queue.put(data) @@ -137,7 +128,7 @@ class Connection(abc.ABC): This method returns a coroutine. """ - if self._disconnected.done(): + if not self._connected: raise ConnectionError('Not connected') result = await self._recv_queue.get() @@ -151,7 +142,7 @@ class Connection(abc.ABC): This loop is constantly popping items off the queue to send them. """ try: - while not self._disconnected.done(): + while self._connected: self._send(await self._send_queue.get()) await self._writer.drain() except asyncio.CancelledError: @@ -166,7 +157,7 @@ class Connection(abc.ABC): This loop is constantly putting items on the queue as they're read. """ try: - while not self._disconnected.done(): + while self._connected: data = await self._recv() await self._recv_queue.put(data) except asyncio.CancelledError: diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 462aa22a..b4347586 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -279,11 +279,11 @@ class MTProtoSender: __log__.debug('Closing current connection...') self._connection.disconnect() - __log__.debug('Awaiting for the send loop before reconnecting...') - await self._send_loop_handle + __log__.debug('Cancelling the send loop...') + self._send_loop_handle.cancel() - __log__.debug('Awaiting for the receive loop before reconnecting...') - await self._recv_loop_handle + __log__.debug('Cancelling the receive loop...') + self._recv_loop_handle.cancel() self._reconnecting = False @@ -334,8 +334,7 @@ class MTProtoSender: # This means that while it's not empty we can wait for # more messages to be added to the send queue. try: - batch, data = await self._send_queue.get( - self._connection.disconnected) + batch, data = await self._send_queue.get() except asyncio.CancelledError: return