mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-26 03:13:45 +03:00
Let Connection._disconnected be a proper Future
This means that awaiting on disconnect will properly raise errors, allowing to differentiate clean disconnects from faulty ones.
This commit is contained in:
parent
542d0f539b
commit
6d30a38316
|
@ -26,9 +26,8 @@ class Connection(abc.ABC):
|
|||
self._proxy = proxy
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._disconnected = asyncio.Event(loop=loop)
|
||||
self._disconnected.set()
|
||||
self._disconnected_future = None
|
||||
self._disconnected = self._loop.create_future()
|
||||
self._disconnected.set_result(None)
|
||||
self._send_task = None
|
||||
self._recv_task = None
|
||||
self._send_queue = asyncio.Queue(1)
|
||||
|
@ -76,8 +75,7 @@ class Connection(abc.ABC):
|
|||
self._reader, self._writer = \
|
||||
await asyncio.open_connection(sock=s, loop=self._loop)
|
||||
|
||||
self._disconnected.clear()
|
||||
self._disconnected_future = None
|
||||
self._disconnected = self._loop.create_future()
|
||||
self._send_task = self._loop.create_task(self._send_loop())
|
||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||
|
||||
|
@ -86,7 +84,13 @@ class Connection(abc.ABC):
|
|||
Disconnects from the server, and clears
|
||||
pending outgoing and incoming messages.
|
||||
"""
|
||||
self._disconnected.set()
|
||||
self._disconnect(error=None)
|
||||
|
||||
def _disconnect(self, error):
|
||||
if error:
|
||||
self._disconnected.set_exception(error)
|
||||
else:
|
||||
self._disconnected.set_result(None)
|
||||
|
||||
while not self._send_queue.empty():
|
||||
self._send_queue.get_nowait()
|
||||
|
@ -105,10 +109,7 @@ class Connection(abc.ABC):
|
|||
|
||||
@property
|
||||
def disconnected(self):
|
||||
if not self._disconnected_future:
|
||||
self._disconnected_future = \
|
||||
self._loop.create_task(self._disconnected.wait())
|
||||
return self._disconnected_future
|
||||
return asyncio.shield(self._disconnected, loop=self._loop)
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
|
@ -122,7 +123,7 @@ class Connection(abc.ABC):
|
|||
|
||||
This method returns a coroutine.
|
||||
"""
|
||||
if self._disconnected.is_set():
|
||||
if self._disconnected.done():
|
||||
raise ConnectionError('Not connected')
|
||||
|
||||
return self._send_queue.put(data)
|
||||
|
@ -133,7 +134,7 @@ class Connection(abc.ABC):
|
|||
|
||||
This method returns a coroutine.
|
||||
"""
|
||||
if self._disconnected.is_set():
|
||||
if self._disconnected.done():
|
||||
raise ConnectionError('Not connected')
|
||||
|
||||
result = await self._recv_queue.get()
|
||||
|
@ -147,33 +148,36 @@ class Connection(abc.ABC):
|
|||
This loop is constantly popping items off the queue to send them.
|
||||
"""
|
||||
try:
|
||||
while not self._disconnected.is_set():
|
||||
while not self._disconnected.done():
|
||||
self._send(await self._send_queue.get())
|
||||
await self._writer.drain()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception('Unhandled exception in the send loop')
|
||||
self.disconnect()
|
||||
msg = 'Unexpected exception in the send loop'
|
||||
logging.exception(msg)
|
||||
self._disconnect(ConnectionError(msg))
|
||||
|
||||
async def _recv_loop(self):
|
||||
"""
|
||||
This loop is constantly putting items on the queue as they're read.
|
||||
"""
|
||||
try:
|
||||
while not self._disconnected.is_set():
|
||||
while not self._disconnected.done():
|
||||
data = await self._recv()
|
||||
await self._recv_queue.put(data)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
if isinstance(e, asyncio.IncompleteReadError):
|
||||
logging.info('The server closed the connection')
|
||||
msg = 'The server closed the connection'
|
||||
logging.info(msg)
|
||||
else:
|
||||
logging.exception('Unhandled exception in the receive loop')
|
||||
msg = 'Unexpected exception in the receive loop'
|
||||
logging.exception(msg)
|
||||
|
||||
await self._recv_queue.put(None)
|
||||
self.disconnect()
|
||||
self._disconnect(ConnectionError(msg))
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send(self, data):
|
||||
|
|
Loading…
Reference in New Issue
Block a user