Revert disconnect() to be async again (#1133)

It's the only way to properly clean all background tasks,
which the library makes heavy use for in MTProto/Connection
send and receive loops.

Some parts of the code even relied on the fact that it was
asynchronous (it used to return a future so you could await
it and not be breaking).

It's automatically syncified to reduce the damage of being
a breaking change.
This commit is contained in:
Lonami Exo 2019-03-21 12:21:00 +01:00
parent 8f302bcdb0
commit 04ba2e1fc7
9 changed files with 68 additions and 72 deletions

View File

@ -440,7 +440,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._state = types.updates.State( self._state = types.updates.State(
0, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0) 0, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
self.disconnect() await self.disconnect()
self.session.delete() self.session.delete()
return True return True
@ -550,9 +550,9 @@ class AuthMethods(MessageParseMethods, UserMethods):
return await self.start() return await self.start()
def __exit__(self, *args): def __exit__(self, *args):
self.disconnect() self.disconnect() # It's also syncified, like start()
async def __aexit__(self, *args): async def __aexit__(self, *args):
self.disconnect() await self.disconnect()
# endregion # endregion

View File

@ -281,7 +281,7 @@ class DownloadMethods(UserMethods):
if exported: if exported:
await self._return_exported_sender(sender) await self._return_exported_sender(sender)
elif sender != self._sender: elif sender != self._sender:
sender.disconnect() await sender.disconnect()
if isinstance(file, str) or in_memory: if isinstance(file, str) or in_memory:
f.close() f.close()

View File

@ -6,7 +6,7 @@ import sys
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
from .. import version, __name__ as __base_name__ from .. import version, helpers, __name__ as __base_name__
from ..crypto import rsa from ..crypto import rsa
from ..extensions import markdown from ..extensions import markdown
from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy
@ -376,28 +376,29 @@ class TelegramBaseClient(abc.ABC):
""" """
Disconnects from Telegram. Disconnects from Telegram.
Returns a dummy completed future with ``None`` as a result so If the event loop is already running, this method returns a
you can ``await`` this method just like every other method for coroutine that you should await on your own code; otherwise
consistency or compatibility. the loop is ran until said coroutine completes.
""" """
self._disconnect() if self._loop.is_running():
return self._disconnect_coro()
else:
self._loop.run_until_complete(self._disconnect_coro())
async def _disconnect_coro(self):
await self._disconnect()
self.session.set_update_state(0, self._state) self.session.set_update_state(0, self._state)
self.session.close() self.session.close()
result = self._loop.create_future() async def _disconnect(self):
result.set_result(None)
return result
def _disconnect(self):
""" """
Disconnect only, without closing the session. Used in reconnections Disconnect only, without closing the session. Used in reconnections
to different data centers, where we don't want to close the session to different data centers, where we don't want to close the session
file; user disconnects however should close it since it means that file; user disconnects however should close it since it means that
their job with the client is complete and we should clean it up all. their job with the client is complete and we should clean it up all.
""" """
self._sender.disconnect() await self._sender.disconnect()
if self._updates_handle: await helpers._cancel(self._log, updates_handle=self._updates_handle)
self._updates_handle.cancel()
async def _switch_dc(self, new_dc): async def _switch_dc(self, new_dc):
""" """
@ -412,7 +413,7 @@ class TelegramBaseClient(abc.ABC):
self._sender.auth_key.key = None self._sender.auth_key.key = None
self.session.auth_key = None self.session.auth_key = None
self.session.save() self.session.save()
self._disconnect() await self._disconnect()
return await self.connect() return await self.connect()
def _auth_key_callback(self, auth_key): def _auth_key_callback(self, auth_key):
@ -515,7 +516,7 @@ class TelegramBaseClient(abc.ABC):
if not n: if not n:
self._log[__name__].info( self._log[__name__].info(
'Disconnecting borrowed sender for DC %d', dc_id) 'Disconnecting borrowed sender for DC %d', dc_id)
sender.disconnect() await sender.disconnect()
async def _get_cdn_client(self, cdn_redirect): async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._borrow_exported_client, but for CDNs""" """Similar to ._borrow_exported_client, but for CDNs"""

View File

@ -16,7 +16,7 @@ class UpdateMethods(UserMethods):
try: try:
await self.disconnected await self.disconnected
except KeyboardInterrupt: except KeyboardInterrupt:
self.disconnect() await self.disconnect()
def run_until_disconnected(self): def run_until_disconnected(self):
""" """
@ -33,7 +33,7 @@ class UpdateMethods(UserMethods):
try: try:
return self.loop.run_until_complete(self.disconnected) return self.loop.run_until_complete(self.disconnected)
except KeyboardInterrupt: except KeyboardInterrupt:
self.disconnect() self.loop.run_until_complete(self.disconnect())
def on(self, event): def on(self, event):
""" """

View File

@ -1,4 +1,5 @@
"""Various helpers not related to the Telegram API itself""" """Various helpers not related to the Telegram API itself"""
import asyncio
import os import os
import struct import struct
from hashlib import sha1, sha256 from hashlib import sha1, sha256
@ -85,6 +86,22 @@ def retry_range(retries):
yield 1 + attempt yield 1 + attempt
async def _cancel(log, **tasks):
"""
Helper to cancel one or more tasks gracefully, logging exceptions.
"""
for name, task in tasks.items():
if not task:
continue
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except Exception:
log.exception('Unhandled exception from %s after cancel', name)
# endregion # endregion
# region Cryptographic related utils # region Cryptographic related utils

View File

@ -4,6 +4,7 @@ import socket
import ssl as ssl_mod import ssl as ssl_mod
from ...errors import InvalidChecksumError from ...errors import InvalidChecksumError
from ... import helpers
class Connection(abc.ABC): class Connection(abc.ABC):
@ -92,18 +93,18 @@ class Connection(abc.ABC):
self._send_task = self._loop.create_task(self._send_loop()) self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop()) self._recv_task = self._loop.create_task(self._recv_loop())
def disconnect(self): async def disconnect(self):
""" """
Disconnects from the server, and clears Disconnects from the server, and clears
pending outgoing and incoming messages. pending outgoing and incoming messages.
""" """
self._connected = False self._connected = False
if self._send_task: await helpers._cancel(
self._send_task.cancel() self._log,
send_task=self._send_task,
if self._recv_task: recv_task=self._recv_task
self._recv_task.cancel() )
if self._writer: if self._writer:
self._writer.close() self._writer.close()
@ -148,7 +149,7 @@ class Connection(abc.ABC):
else: else:
self._log.exception('Unexpected exception in the send loop') self._log.exception('Unexpected exception in the send loop')
self.disconnect() await self.disconnect()
async def _recv_loop(self): async def _recv_loop(self):
""" """
@ -170,7 +171,7 @@ class Connection(abc.ABC):
msg = 'Unexpected exception in the receive loop' msg = 'Unexpected exception in the receive loop'
self._log.exception(msg) self._log.exception(msg)
self.disconnect() await self.disconnect()
# Add a sentinel value to unstuck recv # Add a sentinel value to unstuck recv
if self._recv_queue.empty(): if self._recv_queue.empty():

View File

@ -111,7 +111,7 @@ class TcpMTProxy(ObfuscatedConnection):
await asyncio.sleep(2) await asyncio.sleep(2)
if self._reader.at_eof(): if self._reader.at_eof():
self.disconnect() await self.disconnect()
raise ConnectionError( raise ConnectionError(
'Proxy closed the connection after sending initial payload') 'Proxy closed the connection after sending initial payload')

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import collections import collections
import functools
from . import authenticator from . import authenticator
from ..extensions.messagepacker import MessagePacker from ..extensions.messagepacker import MessagePacker
@ -8,7 +7,7 @@ from .mtprotoplainsender import MTProtoPlainSender
from .requeststate import RequestState from .requeststate import RequestState
from .mtprotostate import MTProtoState from .mtprotostate import MTProtoState
from ..tl.tlobject import TLRequest from ..tl.tlobject import TLRequest
from .. import utils from .. import helpers, utils
from ..errors import ( from ..errors import (
BadMessageError, InvalidBufferError, SecurityError, BadMessageError, InvalidBufferError, SecurityError,
TypeNotFoundError, rpc_message_to_error TypeNotFoundError, rpc_message_to_error
@ -25,23 +24,6 @@ from ..crypto import AuthKey
from ..helpers import retry_range from ..helpers import retry_range
def _cancellable(func):
"""
Silences `asyncio.CancelledError` for an entire function.
This way the function can be cancelled without the task ending
with a exception, and without the function body requiring another
indent level for the try/except.
"""
@functools.wraps(func)
def wrapped(*args, **kwargs):
try:
return func(*args, **kwargs)
except asyncio.CancelledError:
pass
return wrapped
class MTProtoSender: class MTProtoSender:
""" """
MTProto Mobile Protocol sender MTProto Mobile Protocol sender
@ -143,12 +125,12 @@ class MTProtoSender:
def is_connected(self): def is_connected(self):
return self._user_connected return self._user_connected
def disconnect(self): async def disconnect(self):
""" """
Cleanly disconnects the instance from the network, cancels Cleanly disconnects the instance from the network, cancels
all pending requests, and closes the send and receive loops. all pending requests, and closes the send and receive loops.
""" """
self._disconnect() await self._disconnect()
def send(self, request, ordered=False): def send(self, request, ordered=False):
""" """
@ -251,7 +233,7 @@ class MTProtoSender:
else: else:
e = ConnectionError('auth_key generation failed {} time(s)' e = ConnectionError('auth_key generation failed {} time(s)'
.format(attempt)) .format(attempt))
self._disconnect(error=e) await self._disconnect(error=e)
raise e raise e
self._log.debug('Starting send loop') self._log.debug('Starting send loop')
@ -268,12 +250,12 @@ class MTProtoSender:
self._log.info('Connection to %s complete!', self._connection) self._log.info('Connection to %s complete!', self._connection)
def _disconnect(self, error=None): async def _disconnect(self, error=None):
self._log.info('Disconnecting from %s...', self._connection) self._log.info('Disconnecting from %s...', self._connection)
self._user_connected = False self._user_connected = False
try: try:
self._log.debug('Closing current connection...') self._log.debug('Closing current connection...')
self._connection.disconnect() await self._connection.disconnect()
finally: finally:
self._log.debug('Cancelling {} pending message(s)...' self._log.debug('Cancelling {} pending message(s)...'
.format(len(self._pending_state))) .format(len(self._pending_state)))
@ -286,14 +268,11 @@ class MTProtoSender:
self._pending_state.clear() self._pending_state.clear()
self._pending_ack.clear() self._pending_ack.clear()
self._last_ack = None self._last_ack = None
await helpers._cancel(
if self._send_loop_handle: self._log,
self._log.debug('Cancelling the send loop...') send_loop_handle=self._send_loop_handle,
self._send_loop_handle.cancel() recv_loop_handle=self._recv_loop_handle
)
if self._recv_loop_handle:
self._log.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel()
self._log.info('Disconnection from %s complete!', self._connection) self._log.info('Disconnection from %s complete!', self._connection)
if self._disconnected and not self._disconnected.done(): if self._disconnected and not self._disconnected.done():
@ -309,13 +288,13 @@ class MTProtoSender:
self._reconnecting = True self._reconnecting = True
self._log.debug('Closing current connection...') self._log.debug('Closing current connection...')
self._connection.disconnect() await self._connection.disconnect()
self._log.debug('Cancelling the send loop...') await helpers._cancel(
self._send_loop_handle.cancel() self._log,
send_loop_handle=self._send_loop_handle,
self._log.debug('Cancelling the receive loop...') recv_loop_handle=self._recv_loop_handle
self._recv_loop_handle.cancel() )
self._reconnecting = False self._reconnecting = False
@ -347,7 +326,7 @@ class MTProtoSender:
else: else:
self._log.error('Automatic reconnection failed {} time(s)' self._log.error('Automatic reconnection failed {} time(s)'
.format(attempt)) .format(attempt))
self._disconnect(error=ConnectionError()) await self._disconnect(error=ConnectionError())
def _start_reconnect(self): def _start_reconnect(self):
"""Starts a reconnection in the background.""" """Starts a reconnection in the background."""
@ -356,7 +335,6 @@ class MTProtoSender:
# Loops # Loops
@_cancellable
async def _send_loop(self): async def _send_loop(self):
""" """
This loop is responsible for popping items off the send This loop is responsible for popping items off the send
@ -402,7 +380,6 @@ class MTProtoSender:
self._log.debug('Encrypted messages put in a queue to be sent') self._log.debug('Encrypted messages put in a queue to be sent')
@_cancellable
async def _recv_loop(self): async def _recv_loop(self):
""" """
This loop is responsible for reading all incoming responses This loop is responsible for reading all incoming responses

View File

@ -365,7 +365,7 @@ async def main(loop, interval=0.05):
if 'application has been destroyed' not in e.args[0]: if 'application has been destroyed' not in e.args[0]:
raise raise
finally: finally:
app.cl.disconnect() await app.cl.disconnect()
if __name__ == "__main__": if __name__ == "__main__":