From 6d30a3831698e5f8e8d70b6b83348e73008152cc Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Fri, 19 Oct 2018 10:45:59 +0200 Subject: [PATCH] Let Connection._disconnected be a proper Future This means that awaiting on disconnect will properly raise errors, allowing to differentiate clean disconnects from faulty ones. --- telethon/network/connection/connection.py | 42 +++++++++++++---------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/telethon/network/connection/connection.py b/telethon/network/connection/connection.py index fd6d3aa5..c435d48e 100644 --- a/telethon/network/connection/connection.py +++ b/telethon/network/connection/connection.py @@ -26,9 +26,8 @@ class Connection(abc.ABC): self._proxy = proxy self._reader = None self._writer = None - self._disconnected = asyncio.Event(loop=loop) - self._disconnected.set() - self._disconnected_future = None + self._disconnected = self._loop.create_future() + self._disconnected.set_result(None) self._send_task = None self._recv_task = None self._send_queue = asyncio.Queue(1) @@ -76,8 +75,7 @@ class Connection(abc.ABC): self._reader, self._writer = \ await asyncio.open_connection(sock=s, loop=self._loop) - self._disconnected.clear() - self._disconnected_future = None + self._disconnected = self._loop.create_future() self._send_task = self._loop.create_task(self._send_loop()) self._recv_task = self._loop.create_task(self._recv_loop()) @@ -86,7 +84,13 @@ class Connection(abc.ABC): Disconnects from the server, and clears pending outgoing and incoming messages. """ - self._disconnected.set() + self._disconnect(error=None) + + def _disconnect(self, error): + if error: + self._disconnected.set_exception(error) + else: + self._disconnected.set_result(None) while not self._send_queue.empty(): self._send_queue.get_nowait() @@ -105,10 +109,7 @@ class Connection(abc.ABC): @property def disconnected(self): - if not self._disconnected_future: - self._disconnected_future = \ - self._loop.create_task(self._disconnected.wait()) - return self._disconnected_future + return asyncio.shield(self._disconnected, loop=self._loop) def clone(self): """ @@ -122,7 +123,7 @@ class Connection(abc.ABC): This method returns a coroutine. """ - if self._disconnected.is_set(): + if self._disconnected.done(): raise ConnectionError('Not connected') return self._send_queue.put(data) @@ -133,7 +134,7 @@ class Connection(abc.ABC): This method returns a coroutine. """ - if self._disconnected.is_set(): + if self._disconnected.done(): raise ConnectionError('Not connected') result = await self._recv_queue.get() @@ -147,33 +148,36 @@ class Connection(abc.ABC): This loop is constantly popping items off the queue to send them. """ try: - while not self._disconnected.is_set(): + while not self._disconnected.done(): self._send(await self._send_queue.get()) await self._writer.drain() except asyncio.CancelledError: pass except Exception: - logging.exception('Unhandled exception in the send loop') - self.disconnect() + msg = 'Unexpected exception in the send loop' + logging.exception(msg) + self._disconnect(ConnectionError(msg)) async def _recv_loop(self): """ This loop is constantly putting items on the queue as they're read. """ try: - while not self._disconnected.is_set(): + while not self._disconnected.done(): data = await self._recv() await self._recv_queue.put(data) except asyncio.CancelledError: pass except Exception as e: if isinstance(e, asyncio.IncompleteReadError): - logging.info('The server closed the connection') + msg = 'The server closed the connection' + logging.info(msg) else: - logging.exception('Unhandled exception in the receive loop') + msg = 'Unexpected exception in the receive loop' + logging.exception(msg) await self._recv_queue.put(None) - self.disconnect() + self._disconnect(ConnectionError(msg)) @abc.abstractmethod def _send(self, data):