Let Connection._disconnected be a proper Future

This means that awaiting on disconnect will properly raise errors,
allowing to differentiate clean disconnects from faulty ones.
This commit is contained in:
Lonami Exo 2018-10-19 10:45:59 +02:00
parent 542d0f539b
commit 6d30a38316

View File

@ -26,9 +26,8 @@ class Connection(abc.ABC):
self._proxy = proxy
self._reader = None
self._writer = None
self._disconnected = asyncio.Event(loop=loop)
self._disconnected.set()
self._disconnected_future = None
self._disconnected = self._loop.create_future()
self._disconnected.set_result(None)
self._send_task = None
self._recv_task = None
self._send_queue = asyncio.Queue(1)
@ -76,8 +75,7 @@ class Connection(abc.ABC):
self._reader, self._writer = \
await asyncio.open_connection(sock=s, loop=self._loop)
self._disconnected.clear()
self._disconnected_future = None
self._disconnected = self._loop.create_future()
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop())
@ -86,7 +84,13 @@ class Connection(abc.ABC):
Disconnects from the server, and clears
pending outgoing and incoming messages.
"""
self._disconnected.set()
self._disconnect(error=None)
def _disconnect(self, error):
if error:
self._disconnected.set_exception(error)
else:
self._disconnected.set_result(None)
while not self._send_queue.empty():
self._send_queue.get_nowait()
@ -105,10 +109,7 @@ class Connection(abc.ABC):
@property
def disconnected(self):
if not self._disconnected_future:
self._disconnected_future = \
self._loop.create_task(self._disconnected.wait())
return self._disconnected_future
return asyncio.shield(self._disconnected, loop=self._loop)
def clone(self):
"""
@ -122,7 +123,7 @@ class Connection(abc.ABC):
This method returns a coroutine.
"""
if self._disconnected.is_set():
if self._disconnected.done():
raise ConnectionError('Not connected')
return self._send_queue.put(data)
@ -133,7 +134,7 @@ class Connection(abc.ABC):
This method returns a coroutine.
"""
if self._disconnected.is_set():
if self._disconnected.done():
raise ConnectionError('Not connected')
result = await self._recv_queue.get()
@ -147,33 +148,36 @@ class Connection(abc.ABC):
This loop is constantly popping items off the queue to send them.
"""
try:
while not self._disconnected.is_set():
while not self._disconnected.done():
self._send(await self._send_queue.get())
await self._writer.drain()
except asyncio.CancelledError:
pass
except Exception:
logging.exception('Unhandled exception in the send loop')
self.disconnect()
msg = 'Unexpected exception in the send loop'
logging.exception(msg)
self._disconnect(ConnectionError(msg))
async def _recv_loop(self):
"""
This loop is constantly putting items on the queue as they're read.
"""
try:
while not self._disconnected.is_set():
while not self._disconnected.done():
data = await self._recv()
await self._recv_queue.put(data)
except asyncio.CancelledError:
pass
except Exception as e:
if isinstance(e, asyncio.IncompleteReadError):
logging.info('The server closed the connection')
msg = 'The server closed the connection'
logging.info(msg)
else:
logging.exception('Unhandled exception in the receive loop')
msg = 'Unexpected exception in the receive loop'
logging.exception(msg)
await self._recv_queue.put(None)
self.disconnect()
self._disconnect(ConnectionError(msg))
@abc.abstractmethod
def _send(self, data):