Revert disconnect() to be async again (#1133)

It's the only way to properly clean all background tasks,
which the library makes heavy use for in MTProto/Connection
send and receive loops.

Some parts of the code even relied on the fact that it was
asynchronous (it used to return a future so you could await
it and not be breaking).

It's automatically syncified to reduce the damage of being
a breaking change.
This commit is contained in:
Lonami Exo 2019-03-21 12:21:00 +01:00
parent 8f302bcdb0
commit 04ba2e1fc7
9 changed files with 68 additions and 72 deletions

View File

@ -440,7 +440,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._state = types.updates.State(
0, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
self.disconnect()
await self.disconnect()
self.session.delete()
return True
@ -550,9 +550,9 @@ class AuthMethods(MessageParseMethods, UserMethods):
return await self.start()
def __exit__(self, *args):
self.disconnect()
self.disconnect() # It's also syncified, like start()
async def __aexit__(self, *args):
self.disconnect()
await self.disconnect()
# endregion

View File

@ -281,7 +281,7 @@ class DownloadMethods(UserMethods):
if exported:
await self._return_exported_sender(sender)
elif sender != self._sender:
sender.disconnect()
await sender.disconnect()
if isinstance(file, str) or in_memory:
f.close()

View File

@ -6,7 +6,7 @@ import sys
import time
from datetime import datetime, timezone
from .. import version, __name__ as __base_name__
from .. import version, helpers, __name__ as __base_name__
from ..crypto import rsa
from ..extensions import markdown
from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy
@ -376,28 +376,29 @@ class TelegramBaseClient(abc.ABC):
"""
Disconnects from Telegram.
Returns a dummy completed future with ``None`` as a result so
you can ``await`` this method just like every other method for
consistency or compatibility.
If the event loop is already running, this method returns a
coroutine that you should await on your own code; otherwise
the loop is ran until said coroutine completes.
"""
self._disconnect()
if self._loop.is_running():
return self._disconnect_coro()
else:
self._loop.run_until_complete(self._disconnect_coro())
async def _disconnect_coro(self):
await self._disconnect()
self.session.set_update_state(0, self._state)
self.session.close()
result = self._loop.create_future()
result.set_result(None)
return result
def _disconnect(self):
async def _disconnect(self):
"""
Disconnect only, without closing the session. Used in reconnections
to different data centers, where we don't want to close the session
file; user disconnects however should close it since it means that
their job with the client is complete and we should clean it up all.
"""
self._sender.disconnect()
if self._updates_handle:
self._updates_handle.cancel()
await self._sender.disconnect()
await helpers._cancel(self._log, updates_handle=self._updates_handle)
async def _switch_dc(self, new_dc):
"""
@ -412,7 +413,7 @@ class TelegramBaseClient(abc.ABC):
self._sender.auth_key.key = None
self.session.auth_key = None
self.session.save()
self._disconnect()
await self._disconnect()
return await self.connect()
def _auth_key_callback(self, auth_key):
@ -515,7 +516,7 @@ class TelegramBaseClient(abc.ABC):
if not n:
self._log[__name__].info(
'Disconnecting borrowed sender for DC %d', dc_id)
sender.disconnect()
await sender.disconnect()
async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._borrow_exported_client, but for CDNs"""

View File

@ -16,7 +16,7 @@ class UpdateMethods(UserMethods):
try:
await self.disconnected
except KeyboardInterrupt:
self.disconnect()
await self.disconnect()
def run_until_disconnected(self):
"""
@ -33,7 +33,7 @@ class UpdateMethods(UserMethods):
try:
return self.loop.run_until_complete(self.disconnected)
except KeyboardInterrupt:
self.disconnect()
self.loop.run_until_complete(self.disconnect())
def on(self, event):
"""

View File

@ -1,4 +1,5 @@
"""Various helpers not related to the Telegram API itself"""
import asyncio
import os
import struct
from hashlib import sha1, sha256
@ -85,6 +86,22 @@ def retry_range(retries):
yield 1 + attempt
async def _cancel(log, **tasks):
"""
Helper to cancel one or more tasks gracefully, logging exceptions.
"""
for name, task in tasks.items():
if not task:
continue
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except Exception:
log.exception('Unhandled exception from %s after cancel', name)
# endregion
# region Cryptographic related utils

View File

@ -4,6 +4,7 @@ import socket
import ssl as ssl_mod
from ...errors import InvalidChecksumError
from ... import helpers
class Connection(abc.ABC):
@ -92,18 +93,18 @@ class Connection(abc.ABC):
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop())
def disconnect(self):
async def disconnect(self):
"""
Disconnects from the server, and clears
pending outgoing and incoming messages.
"""
self._connected = False
if self._send_task:
self._send_task.cancel()
if self._recv_task:
self._recv_task.cancel()
await helpers._cancel(
self._log,
send_task=self._send_task,
recv_task=self._recv_task
)
if self._writer:
self._writer.close()
@ -148,7 +149,7 @@ class Connection(abc.ABC):
else:
self._log.exception('Unexpected exception in the send loop')
self.disconnect()
await self.disconnect()
async def _recv_loop(self):
"""
@ -170,7 +171,7 @@ class Connection(abc.ABC):
msg = 'Unexpected exception in the receive loop'
self._log.exception(msg)
self.disconnect()
await self.disconnect()
# Add a sentinel value to unstuck recv
if self._recv_queue.empty():

View File

@ -111,7 +111,7 @@ class TcpMTProxy(ObfuscatedConnection):
await asyncio.sleep(2)
if self._reader.at_eof():
self.disconnect()
await self.disconnect()
raise ConnectionError(
'Proxy closed the connection after sending initial payload')

View File

@ -1,6 +1,5 @@
import asyncio
import collections
import functools
from . import authenticator
from ..extensions.messagepacker import MessagePacker
@ -8,7 +7,7 @@ from .mtprotoplainsender import MTProtoPlainSender
from .requeststate import RequestState
from .mtprotostate import MTProtoState
from ..tl.tlobject import TLRequest
from .. import utils
from .. import helpers, utils
from ..errors import (
BadMessageError, InvalidBufferError, SecurityError,
TypeNotFoundError, rpc_message_to_error
@ -25,23 +24,6 @@ from ..crypto import AuthKey
from ..helpers import retry_range
def _cancellable(func):
"""
Silences `asyncio.CancelledError` for an entire function.
This way the function can be cancelled without the task ending
with a exception, and without the function body requiring another
indent level for the try/except.
"""
@functools.wraps(func)
def wrapped(*args, **kwargs):
try:
return func(*args, **kwargs)
except asyncio.CancelledError:
pass
return wrapped
class MTProtoSender:
"""
MTProto Mobile Protocol sender
@ -143,12 +125,12 @@ class MTProtoSender:
def is_connected(self):
return self._user_connected
def disconnect(self):
async def disconnect(self):
"""
Cleanly disconnects the instance from the network, cancels
all pending requests, and closes the send and receive loops.
"""
self._disconnect()
await self._disconnect()
def send(self, request, ordered=False):
"""
@ -251,7 +233,7 @@ class MTProtoSender:
else:
e = ConnectionError('auth_key generation failed {} time(s)'
.format(attempt))
self._disconnect(error=e)
await self._disconnect(error=e)
raise e
self._log.debug('Starting send loop')
@ -268,12 +250,12 @@ class MTProtoSender:
self._log.info('Connection to %s complete!', self._connection)
def _disconnect(self, error=None):
async def _disconnect(self, error=None):
self._log.info('Disconnecting from %s...', self._connection)
self._user_connected = False
try:
self._log.debug('Closing current connection...')
self._connection.disconnect()
await self._connection.disconnect()
finally:
self._log.debug('Cancelling {} pending message(s)...'
.format(len(self._pending_state)))
@ -286,14 +268,11 @@ class MTProtoSender:
self._pending_state.clear()
self._pending_ack.clear()
self._last_ack = None
if self._send_loop_handle:
self._log.debug('Cancelling the send loop...')
self._send_loop_handle.cancel()
if self._recv_loop_handle:
self._log.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel()
await helpers._cancel(
self._log,
send_loop_handle=self._send_loop_handle,
recv_loop_handle=self._recv_loop_handle
)
self._log.info('Disconnection from %s complete!', self._connection)
if self._disconnected and not self._disconnected.done():
@ -309,13 +288,13 @@ class MTProtoSender:
self._reconnecting = True
self._log.debug('Closing current connection...')
self._connection.disconnect()
await self._connection.disconnect()
self._log.debug('Cancelling the send loop...')
self._send_loop_handle.cancel()
self._log.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel()
await helpers._cancel(
self._log,
send_loop_handle=self._send_loop_handle,
recv_loop_handle=self._recv_loop_handle
)
self._reconnecting = False
@ -347,7 +326,7 @@ class MTProtoSender:
else:
self._log.error('Automatic reconnection failed {} time(s)'
.format(attempt))
self._disconnect(error=ConnectionError())
await self._disconnect(error=ConnectionError())
def _start_reconnect(self):
"""Starts a reconnection in the background."""
@ -356,7 +335,6 @@ class MTProtoSender:
# Loops
@_cancellable
async def _send_loop(self):
"""
This loop is responsible for popping items off the send
@ -402,7 +380,6 @@ class MTProtoSender:
self._log.debug('Encrypted messages put in a queue to be sent')
@_cancellable
async def _recv_loop(self):
"""
This loop is responsible for reading all incoming responses

View File

@ -365,7 +365,7 @@ async def main(loop, interval=0.05):
if 'application has been destroyed' not in e.args[0]:
raise
finally:
app.cl.disconnect()
await app.cl.disconnect()
if __name__ == "__main__":