Let Connection._disconnected be a proper Future

This means that awaiting on disconnect will properly raise errors,
allowing to differentiate clean disconnects from faulty ones.
This commit is contained in:
Lonami Exo 2018-10-19 10:45:59 +02:00
parent 542d0f539b
commit 6d30a38316

View File

@ -26,9 +26,8 @@ class Connection(abc.ABC):
self._proxy = proxy self._proxy = proxy
self._reader = None self._reader = None
self._writer = None self._writer = None
self._disconnected = asyncio.Event(loop=loop) self._disconnected = self._loop.create_future()
self._disconnected.set() self._disconnected.set_result(None)
self._disconnected_future = None
self._send_task = None self._send_task = None
self._recv_task = None self._recv_task = None
self._send_queue = asyncio.Queue(1) self._send_queue = asyncio.Queue(1)
@ -76,8 +75,7 @@ class Connection(abc.ABC):
self._reader, self._writer = \ self._reader, self._writer = \
await asyncio.open_connection(sock=s, loop=self._loop) await asyncio.open_connection(sock=s, loop=self._loop)
self._disconnected.clear() self._disconnected = self._loop.create_future()
self._disconnected_future = None
self._send_task = self._loop.create_task(self._send_loop()) self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop()) self._recv_task = self._loop.create_task(self._recv_loop())
@ -86,7 +84,13 @@ class Connection(abc.ABC):
Disconnects from the server, and clears Disconnects from the server, and clears
pending outgoing and incoming messages. pending outgoing and incoming messages.
""" """
self._disconnected.set() self._disconnect(error=None)
def _disconnect(self, error):
if error:
self._disconnected.set_exception(error)
else:
self._disconnected.set_result(None)
while not self._send_queue.empty(): while not self._send_queue.empty():
self._send_queue.get_nowait() self._send_queue.get_nowait()
@ -105,10 +109,7 @@ class Connection(abc.ABC):
@property @property
def disconnected(self): def disconnected(self):
if not self._disconnected_future: return asyncio.shield(self._disconnected, loop=self._loop)
self._disconnected_future = \
self._loop.create_task(self._disconnected.wait())
return self._disconnected_future
def clone(self): def clone(self):
""" """
@ -122,7 +123,7 @@ class Connection(abc.ABC):
This method returns a coroutine. This method returns a coroutine.
""" """
if self._disconnected.is_set(): if self._disconnected.done():
raise ConnectionError('Not connected') raise ConnectionError('Not connected')
return self._send_queue.put(data) return self._send_queue.put(data)
@ -133,7 +134,7 @@ class Connection(abc.ABC):
This method returns a coroutine. This method returns a coroutine.
""" """
if self._disconnected.is_set(): if self._disconnected.done():
raise ConnectionError('Not connected') raise ConnectionError('Not connected')
result = await self._recv_queue.get() result = await self._recv_queue.get()
@ -147,33 +148,36 @@ class Connection(abc.ABC):
This loop is constantly popping items off the queue to send them. This loop is constantly popping items off the queue to send them.
""" """
try: try:
while not self._disconnected.is_set(): while not self._disconnected.done():
self._send(await self._send_queue.get()) self._send(await self._send_queue.get())
await self._writer.drain() await self._writer.drain()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception: except Exception:
logging.exception('Unhandled exception in the send loop') msg = 'Unexpected exception in the send loop'
self.disconnect() logging.exception(msg)
self._disconnect(ConnectionError(msg))
async def _recv_loop(self): async def _recv_loop(self):
""" """
This loop is constantly putting items on the queue as they're read. This loop is constantly putting items on the queue as they're read.
""" """
try: try:
while not self._disconnected.is_set(): while not self._disconnected.done():
data = await self._recv() data = await self._recv()
await self._recv_queue.put(data) await self._recv_queue.put(data)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
if isinstance(e, asyncio.IncompleteReadError): if isinstance(e, asyncio.IncompleteReadError):
logging.info('The server closed the connection') msg = 'The server closed the connection'
logging.info(msg)
else: else:
logging.exception('Unhandled exception in the receive loop') msg = 'Unexpected exception in the receive loop'
logging.exception(msg)
await self._recv_queue.put(None) await self._recv_queue.put(None)
self.disconnect() self._disconnect(ConnectionError(msg))
@abc.abstractmethod @abc.abstractmethod
def _send(self, data): def _send(self, data):