mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-22 17:36:34 +03:00
Cancel tasks on reconnect instead of awaiting them
This prevents us from locking forever on any task that doesn't rely on cancellation tokens, in this case, Connection.recv()'s _recv_queue.get() would never complete after the server closed the connection. Additionally, working with cancellation tokens in asyncio is somewhat annoying to do. Last but not least removing the Connection._disconnected future avoids the need to use its state (if an exception was set it should be retrieved) to prevent asyncio from complaining, which it was before.
This commit is contained in:
parent
f2e77f4030
commit
7dece209a0
|
@ -38,7 +38,7 @@ class MessagePacker:
|
|||
self._deque.extend(states)
|
||||
self._ready.set()
|
||||
|
||||
async def get(self, cancellation):
|
||||
async def get(self):
|
||||
"""
|
||||
Returns (batch, data) if one or more items could be retrieved.
|
||||
|
||||
|
@ -47,19 +47,7 @@ class MessagePacker:
|
|||
"""
|
||||
if not self._deque:
|
||||
self._ready.clear()
|
||||
ready = self._loop.create_task(self._ready.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
[ready, cancellation],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
loop=self._loop
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
done = [cancellation]
|
||||
|
||||
if cancellation in done:
|
||||
ready.cancel()
|
||||
return None, None
|
||||
await self._ready.wait()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
batch = []
|
||||
|
|
|
@ -28,8 +28,7 @@ class Connection(abc.ABC):
|
|||
self._proxy = proxy
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._disconnected = self._loop.create_future()
|
||||
self._disconnected.set_result(None)
|
||||
self._connected = False
|
||||
self._send_task = None
|
||||
self._recv_task = None
|
||||
self._send_queue = asyncio.Queue(1)
|
||||
|
@ -77,7 +76,7 @@ class Connection(abc.ABC):
|
|||
self._reader, self._writer = \
|
||||
await asyncio.open_connection(sock=s, loop=self._loop)
|
||||
|
||||
self._disconnected = self._loop.create_future()
|
||||
self._connected = True
|
||||
self._send_task = self._loop.create_task(self._send_loop())
|
||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||
|
||||
|
@ -89,11 +88,7 @@ class Connection(abc.ABC):
|
|||
self._disconnect(error=None)
|
||||
|
||||
def _disconnect(self, error):
|
||||
if not self._disconnected.done():
|
||||
if error:
|
||||
self._disconnected.set_exception(error)
|
||||
else:
|
||||
self._disconnected.set_result(None)
|
||||
self._connected = False
|
||||
|
||||
while not self._send_queue.empty():
|
||||
self._send_queue.get_nowait()
|
||||
|
@ -110,10 +105,6 @@ class Connection(abc.ABC):
|
|||
if self._writer:
|
||||
self._writer.close()
|
||||
|
||||
@property
|
||||
def disconnected(self):
|
||||
return asyncio.shield(self._disconnected, loop=self._loop)
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
Creates a clone of the connection.
|
||||
|
@ -126,7 +117,7 @@ class Connection(abc.ABC):
|
|||
|
||||
This method returns a coroutine.
|
||||
"""
|
||||
if self._disconnected.done():
|
||||
if not self._connected:
|
||||
raise ConnectionError('Not connected')
|
||||
|
||||
return self._send_queue.put(data)
|
||||
|
@ -137,7 +128,7 @@ class Connection(abc.ABC):
|
|||
|
||||
This method returns a coroutine.
|
||||
"""
|
||||
if self._disconnected.done():
|
||||
if not self._connected:
|
||||
raise ConnectionError('Not connected')
|
||||
|
||||
result = await self._recv_queue.get()
|
||||
|
@ -151,7 +142,7 @@ class Connection(abc.ABC):
|
|||
This loop is constantly popping items off the queue to send them.
|
||||
"""
|
||||
try:
|
||||
while not self._disconnected.done():
|
||||
while self._connected:
|
||||
self._send(await self._send_queue.get())
|
||||
await self._writer.drain()
|
||||
except asyncio.CancelledError:
|
||||
|
@ -166,7 +157,7 @@ class Connection(abc.ABC):
|
|||
This loop is constantly putting items on the queue as they're read.
|
||||
"""
|
||||
try:
|
||||
while not self._disconnected.done():
|
||||
while self._connected:
|
||||
data = await self._recv()
|
||||
await self._recv_queue.put(data)
|
||||
except asyncio.CancelledError:
|
||||
|
|
|
@ -279,11 +279,11 @@ class MTProtoSender:
|
|||
__log__.debug('Closing current connection...')
|
||||
self._connection.disconnect()
|
||||
|
||||
__log__.debug('Awaiting for the send loop before reconnecting...')
|
||||
await self._send_loop_handle
|
||||
__log__.debug('Cancelling the send loop...')
|
||||
self._send_loop_handle.cancel()
|
||||
|
||||
__log__.debug('Awaiting for the receive loop before reconnecting...')
|
||||
await self._recv_loop_handle
|
||||
__log__.debug('Cancelling the receive loop...')
|
||||
self._recv_loop_handle.cancel()
|
||||
|
||||
self._reconnecting = False
|
||||
|
||||
|
@ -334,8 +334,7 @@ class MTProtoSender:
|
|||
# This means that while it's not empty we can wait for
|
||||
# more messages to be added to the send queue.
|
||||
try:
|
||||
batch, data = await self._send_queue.get(
|
||||
self._connection.disconnected)
|
||||
batch, data = await self._send_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user