Cancel tasks on reconnect instead of awaiting them

This prevents us from locking forever on any task that doesn't
rely on cancellation tokens, in this case, Connection.recv()'s
_recv_queue.get() would never complete after the server closed
the connection.

Additionally, working with cancellation tokens in asyncio is
somewhat annoying to do.

Last but not least removing the Connection._disconnected future
avoids the need to use its state (if an exception was set it
should be retrieved) to prevent asyncio from complaining, which
it was before.
This commit is contained in:
Lonami Exo 2018-10-21 16:20:05 +02:00
parent f2e77f4030
commit 7dece209a0
3 changed files with 14 additions and 36 deletions

View File

@ -38,7 +38,7 @@ class MessagePacker:
self._deque.extend(states)
self._ready.set()
async def get(self, cancellation):
async def get(self):
"""
Returns (batch, data) if one or more items could be retrieved.
@ -47,19 +47,7 @@ class MessagePacker:
"""
if not self._deque:
self._ready.clear()
ready = self._loop.create_task(self._ready.wait())
try:
done, pending = await asyncio.wait(
[ready, cancellation],
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
except asyncio.CancelledError:
done = [cancellation]
if cancellation in done:
ready.cancel()
return None, None
await self._ready.wait()
buffer = io.BytesIO()
batch = []

View File

@ -28,8 +28,7 @@ class Connection(abc.ABC):
self._proxy = proxy
self._reader = None
self._writer = None
self._disconnected = self._loop.create_future()
self._disconnected.set_result(None)
self._connected = False
self._send_task = None
self._recv_task = None
self._send_queue = asyncio.Queue(1)
@ -77,7 +76,7 @@ class Connection(abc.ABC):
self._reader, self._writer = \
await asyncio.open_connection(sock=s, loop=self._loop)
self._disconnected = self._loop.create_future()
self._connected = True
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop())
@ -89,11 +88,7 @@ class Connection(abc.ABC):
self._disconnect(error=None)
def _disconnect(self, error):
if not self._disconnected.done():
if error:
self._disconnected.set_exception(error)
else:
self._disconnected.set_result(None)
self._connected = False
while not self._send_queue.empty():
self._send_queue.get_nowait()
@ -110,10 +105,6 @@ class Connection(abc.ABC):
if self._writer:
self._writer.close()
@property
def disconnected(self):
return asyncio.shield(self._disconnected, loop=self._loop)
def clone(self):
"""
Creates a clone of the connection.
@ -126,7 +117,7 @@ class Connection(abc.ABC):
This method returns a coroutine.
"""
if self._disconnected.done():
if not self._connected:
raise ConnectionError('Not connected')
return self._send_queue.put(data)
@ -137,7 +128,7 @@ class Connection(abc.ABC):
This method returns a coroutine.
"""
if self._disconnected.done():
if not self._connected:
raise ConnectionError('Not connected')
result = await self._recv_queue.get()
@ -151,7 +142,7 @@ class Connection(abc.ABC):
This loop is constantly popping items off the queue to send them.
"""
try:
while not self._disconnected.done():
while self._connected:
self._send(await self._send_queue.get())
await self._writer.drain()
except asyncio.CancelledError:
@ -166,7 +157,7 @@ class Connection(abc.ABC):
This loop is constantly putting items on the queue as they're read.
"""
try:
while not self._disconnected.done():
while self._connected:
data = await self._recv()
await self._recv_queue.put(data)
except asyncio.CancelledError:

View File

@ -279,11 +279,11 @@ class MTProtoSender:
__log__.debug('Closing current connection...')
self._connection.disconnect()
__log__.debug('Awaiting for the send loop before reconnecting...')
await self._send_loop_handle
__log__.debug('Cancelling the send loop...')
self._send_loop_handle.cancel()
__log__.debug('Awaiting for the receive loop before reconnecting...')
await self._recv_loop_handle
__log__.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel()
self._reconnecting = False
@ -334,8 +334,7 @@ class MTProtoSender:
# This means that while it's not empty we can wait for
# more messages to be added to the send queue.
try:
batch, data = await self._send_queue.get(
self._connection.disconnected)
batch, data = await self._send_queue.get()
except asyncio.CancelledError:
return