mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-18 12:30:59 +03:00
Let Connection._disconnected be a proper Future
This means that awaiting on disconnect will properly raise errors, allowing to differentiate clean disconnects from faulty ones.
This commit is contained in:
parent
542d0f539b
commit
6d30a38316
|
@ -26,9 +26,8 @@ class Connection(abc.ABC):
|
||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
self._reader = None
|
self._reader = None
|
||||||
self._writer = None
|
self._writer = None
|
||||||
self._disconnected = asyncio.Event(loop=loop)
|
self._disconnected = self._loop.create_future()
|
||||||
self._disconnected.set()
|
self._disconnected.set_result(None)
|
||||||
self._disconnected_future = None
|
|
||||||
self._send_task = None
|
self._send_task = None
|
||||||
self._recv_task = None
|
self._recv_task = None
|
||||||
self._send_queue = asyncio.Queue(1)
|
self._send_queue = asyncio.Queue(1)
|
||||||
|
@ -76,8 +75,7 @@ class Connection(abc.ABC):
|
||||||
self._reader, self._writer = \
|
self._reader, self._writer = \
|
||||||
await asyncio.open_connection(sock=s, loop=self._loop)
|
await asyncio.open_connection(sock=s, loop=self._loop)
|
||||||
|
|
||||||
self._disconnected.clear()
|
self._disconnected = self._loop.create_future()
|
||||||
self._disconnected_future = None
|
|
||||||
self._send_task = self._loop.create_task(self._send_loop())
|
self._send_task = self._loop.create_task(self._send_loop())
|
||||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||||
|
|
||||||
|
@ -86,7 +84,13 @@ class Connection(abc.ABC):
|
||||||
Disconnects from the server, and clears
|
Disconnects from the server, and clears
|
||||||
pending outgoing and incoming messages.
|
pending outgoing and incoming messages.
|
||||||
"""
|
"""
|
||||||
self._disconnected.set()
|
self._disconnect(error=None)
|
||||||
|
|
||||||
|
def _disconnect(self, error):
|
||||||
|
if error:
|
||||||
|
self._disconnected.set_exception(error)
|
||||||
|
else:
|
||||||
|
self._disconnected.set_result(None)
|
||||||
|
|
||||||
while not self._send_queue.empty():
|
while not self._send_queue.empty():
|
||||||
self._send_queue.get_nowait()
|
self._send_queue.get_nowait()
|
||||||
|
@ -105,10 +109,7 @@ class Connection(abc.ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disconnected(self):
|
def disconnected(self):
|
||||||
if not self._disconnected_future:
|
return asyncio.shield(self._disconnected, loop=self._loop)
|
||||||
self._disconnected_future = \
|
|
||||||
self._loop.create_task(self._disconnected.wait())
|
|
||||||
return self._disconnected_future
|
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
"""
|
"""
|
||||||
|
@ -122,7 +123,7 @@ class Connection(abc.ABC):
|
||||||
|
|
||||||
This method returns a coroutine.
|
This method returns a coroutine.
|
||||||
"""
|
"""
|
||||||
if self._disconnected.is_set():
|
if self._disconnected.done():
|
||||||
raise ConnectionError('Not connected')
|
raise ConnectionError('Not connected')
|
||||||
|
|
||||||
return self._send_queue.put(data)
|
return self._send_queue.put(data)
|
||||||
|
@ -133,7 +134,7 @@ class Connection(abc.ABC):
|
||||||
|
|
||||||
This method returns a coroutine.
|
This method returns a coroutine.
|
||||||
"""
|
"""
|
||||||
if self._disconnected.is_set():
|
if self._disconnected.done():
|
||||||
raise ConnectionError('Not connected')
|
raise ConnectionError('Not connected')
|
||||||
|
|
||||||
result = await self._recv_queue.get()
|
result = await self._recv_queue.get()
|
||||||
|
@ -147,33 +148,36 @@ class Connection(abc.ABC):
|
||||||
This loop is constantly popping items off the queue to send them.
|
This loop is constantly popping items off the queue to send them.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
while not self._disconnected.is_set():
|
while not self._disconnected.done():
|
||||||
self._send(await self._send_queue.get())
|
self._send(await self._send_queue.get())
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception('Unhandled exception in the send loop')
|
msg = 'Unexpected exception in the send loop'
|
||||||
self.disconnect()
|
logging.exception(msg)
|
||||||
|
self._disconnect(ConnectionError(msg))
|
||||||
|
|
||||||
async def _recv_loop(self):
|
async def _recv_loop(self):
|
||||||
"""
|
"""
|
||||||
This loop is constantly putting items on the queue as they're read.
|
This loop is constantly putting items on the queue as they're read.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
while not self._disconnected.is_set():
|
while not self._disconnected.done():
|
||||||
data = await self._recv()
|
data = await self._recv()
|
||||||
await self._recv_queue.put(data)
|
await self._recv_queue.put(data)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, asyncio.IncompleteReadError):
|
if isinstance(e, asyncio.IncompleteReadError):
|
||||||
logging.info('The server closed the connection')
|
msg = 'The server closed the connection'
|
||||||
|
logging.info(msg)
|
||||||
else:
|
else:
|
||||||
logging.exception('Unhandled exception in the receive loop')
|
msg = 'Unexpected exception in the receive loop'
|
||||||
|
logging.exception(msg)
|
||||||
|
|
||||||
await self._recv_queue.put(None)
|
await self._recv_queue.put(None)
|
||||||
self.disconnect()
|
self._disconnect(ConnectionError(msg))
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _send(self, data):
|
def _send(self, data):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user