From 04ba2e1fc7932fb6f0136c1e681d585fb452ce92 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 21 Mar 2019 12:21:00 +0100 Subject: [PATCH] Revert disconnect() to be async again (#1133) It's the only way to properly clean all background tasks, which the library makes heavy use for in MTProto/Connection send and receive loops. Some parts of the code even relied on the fact that it was asynchronous (it used to return a future so you could await it and not be breaking). It's automatically syncified to reduce the damage of being a breaking change. --- telethon/client/auth.py | 6 +-- telethon/client/downloads.py | 2 +- telethon/client/telegrambaseclient.py | 31 ++++++------ telethon/client/updates.py | 4 +- telethon/helpers.py | 17 +++++++ telethon/network/connection/connection.py | 17 ++++--- telethon/network/connection/tcpmtproxy.py | 2 +- telethon/network/mtprotosender.py | 59 +++++++---------------- telethon_examples/gui.py | 2 +- 9 files changed, 68 insertions(+), 72 deletions(-) diff --git a/telethon/client/auth.py b/telethon/client/auth.py index d885df35..36afc343 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -440,7 +440,7 @@ class AuthMethods(MessageParseMethods, UserMethods): self._state = types.updates.State( 0, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0) - self.disconnect() + await self.disconnect() self.session.delete() return True @@ -550,9 +550,9 @@ class AuthMethods(MessageParseMethods, UserMethods): return await self.start() def __exit__(self, *args): - self.disconnect() + self.disconnect() # It's also syncified, like start() async def __aexit__(self, *args): - self.disconnect() + await self.disconnect() # endregion diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 8de9016a..ec7e5628 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -281,7 +281,7 @@ class DownloadMethods(UserMethods): if exported: await self._return_exported_sender(sender) elif sender != self._sender: - sender.disconnect() + await sender.disconnect() if isinstance(file, str) or in_memory: f.close() diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index c1179ff0..83648af9 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -6,7 +6,7 @@ import sys import time from datetime import datetime, timezone -from .. import version, __name__ as __base_name__ +from .. import version, helpers, __name__ as __base_name__ from ..crypto import rsa from ..extensions import markdown from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy @@ -376,28 +376,29 @@ class TelegramBaseClient(abc.ABC): """ Disconnects from Telegram. - Returns a dummy completed future with ``None`` as a result so - you can ``await`` this method just like every other method for - consistency or compatibility. + If the event loop is already running, this method returns a + coroutine that you should await on your own code; otherwise + the loop is ran until said coroutine completes. """ - self._disconnect() + if self._loop.is_running(): + return self._disconnect_coro() + else: + self._loop.run_until_complete(self._disconnect_coro()) + + async def _disconnect_coro(self): + await self._disconnect() self.session.set_update_state(0, self._state) self.session.close() - result = self._loop.create_future() - result.set_result(None) - return result - - def _disconnect(self): + async def _disconnect(self): """ Disconnect only, without closing the session. Used in reconnections to different data centers, where we don't want to close the session file; user disconnects however should close it since it means that their job with the client is complete and we should clean it up all. """ - self._sender.disconnect() - if self._updates_handle: - self._updates_handle.cancel() + await self._sender.disconnect() + await helpers._cancel(self._log, updates_handle=self._updates_handle) async def _switch_dc(self, new_dc): """ @@ -412,7 +413,7 @@ class TelegramBaseClient(abc.ABC): self._sender.auth_key.key = None self.session.auth_key = None self.session.save() - self._disconnect() + await self._disconnect() return await self.connect() def _auth_key_callback(self, auth_key): @@ -515,7 +516,7 @@ class TelegramBaseClient(abc.ABC): if not n: self._log[__name__].info( 'Disconnecting borrowed sender for DC %d', dc_id) - sender.disconnect() + await sender.disconnect() async def _get_cdn_client(self, cdn_redirect): """Similar to ._borrow_exported_client, but for CDNs""" diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 6b4cb520..ca42dfed 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -16,7 +16,7 @@ class UpdateMethods(UserMethods): try: await self.disconnected except KeyboardInterrupt: - self.disconnect() + await self.disconnect() def run_until_disconnected(self): """ @@ -33,7 +33,7 @@ class UpdateMethods(UserMethods): try: return self.loop.run_until_complete(self.disconnected) except KeyboardInterrupt: - self.disconnect() + self.loop.run_until_complete(self.disconnect()) def on(self, event): """ diff --git a/telethon/helpers.py b/telethon/helpers.py index 4714a5c4..33276b8f 100644 --- a/telethon/helpers.py +++ b/telethon/helpers.py @@ -1,4 +1,5 @@ """Various helpers not related to the Telegram API itself""" +import asyncio import os import struct from hashlib import sha1, sha256 @@ -85,6 +86,22 @@ def retry_range(retries): yield 1 + attempt +async def _cancel(log, **tasks): + """ + Helper to cancel one or more tasks gracefully, logging exceptions. + """ + for name, task in tasks.items(): + if not task: + continue + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + log.exception('Unhandled exception from %s after cancel', name) + # endregion # region Cryptographic related utils diff --git a/telethon/network/connection/connection.py b/telethon/network/connection/connection.py index 2987bc11..e9a948c3 100644 --- a/telethon/network/connection/connection.py +++ b/telethon/network/connection/connection.py @@ -4,6 +4,7 @@ import socket import ssl as ssl_mod from ...errors import InvalidChecksumError +from ... import helpers class Connection(abc.ABC): @@ -92,18 +93,18 @@ class Connection(abc.ABC): self._send_task = self._loop.create_task(self._send_loop()) self._recv_task = self._loop.create_task(self._recv_loop()) - def disconnect(self): + async def disconnect(self): """ Disconnects from the server, and clears pending outgoing and incoming messages. """ self._connected = False - if self._send_task: - self._send_task.cancel() - - if self._recv_task: - self._recv_task.cancel() + await helpers._cancel( + self._log, + send_task=self._send_task, + recv_task=self._recv_task + ) if self._writer: self._writer.close() @@ -148,7 +149,7 @@ class Connection(abc.ABC): else: self._log.exception('Unexpected exception in the send loop') - self.disconnect() + await self.disconnect() async def _recv_loop(self): """ @@ -170,7 +171,7 @@ class Connection(abc.ABC): msg = 'Unexpected exception in the receive loop' self._log.exception(msg) - self.disconnect() + await self.disconnect() # Add a sentinel value to unstuck recv if self._recv_queue.empty(): diff --git a/telethon/network/connection/tcpmtproxy.py b/telethon/network/connection/tcpmtproxy.py index ac3d2ba5..e1798b61 100644 --- a/telethon/network/connection/tcpmtproxy.py +++ b/telethon/network/connection/tcpmtproxy.py @@ -111,7 +111,7 @@ class TcpMTProxy(ObfuscatedConnection): await asyncio.sleep(2) if self._reader.at_eof(): - self.disconnect() + await self.disconnect() raise ConnectionError( 'Proxy closed the connection after sending initial payload') diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index b36d5e2b..cc858665 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -1,6 +1,5 @@ import asyncio import collections -import functools from . import authenticator from ..extensions.messagepacker import MessagePacker @@ -8,7 +7,7 @@ from .mtprotoplainsender import MTProtoPlainSender from .requeststate import RequestState from .mtprotostate import MTProtoState from ..tl.tlobject import TLRequest -from .. import utils +from .. import helpers, utils from ..errors import ( BadMessageError, InvalidBufferError, SecurityError, TypeNotFoundError, rpc_message_to_error @@ -25,23 +24,6 @@ from ..crypto import AuthKey from ..helpers import retry_range -def _cancellable(func): - """ - Silences `asyncio.CancelledError` for an entire function. - - This way the function can be cancelled without the task ending - with a exception, and without the function body requiring another - indent level for the try/except. - """ - @functools.wraps(func) - def wrapped(*args, **kwargs): - try: - return func(*args, **kwargs) - except asyncio.CancelledError: - pass - return wrapped - - class MTProtoSender: """ MTProto Mobile Protocol sender @@ -143,12 +125,12 @@ class MTProtoSender: def is_connected(self): return self._user_connected - def disconnect(self): + async def disconnect(self): """ Cleanly disconnects the instance from the network, cancels all pending requests, and closes the send and receive loops. """ - self._disconnect() + await self._disconnect() def send(self, request, ordered=False): """ @@ -251,7 +233,7 @@ class MTProtoSender: else: e = ConnectionError('auth_key generation failed {} time(s)' .format(attempt)) - self._disconnect(error=e) + await self._disconnect(error=e) raise e self._log.debug('Starting send loop') @@ -268,12 +250,12 @@ class MTProtoSender: self._log.info('Connection to %s complete!', self._connection) - def _disconnect(self, error=None): + async def _disconnect(self, error=None): self._log.info('Disconnecting from %s...', self._connection) self._user_connected = False try: self._log.debug('Closing current connection...') - self._connection.disconnect() + await self._connection.disconnect() finally: self._log.debug('Cancelling {} pending message(s)...' .format(len(self._pending_state))) @@ -286,14 +268,11 @@ class MTProtoSender: self._pending_state.clear() self._pending_ack.clear() self._last_ack = None - - if self._send_loop_handle: - self._log.debug('Cancelling the send loop...') - self._send_loop_handle.cancel() - - if self._recv_loop_handle: - self._log.debug('Cancelling the receive loop...') - self._recv_loop_handle.cancel() + await helpers._cancel( + self._log, + send_loop_handle=self._send_loop_handle, + recv_loop_handle=self._recv_loop_handle + ) self._log.info('Disconnection from %s complete!', self._connection) if self._disconnected and not self._disconnected.done(): @@ -309,13 +288,13 @@ class MTProtoSender: self._reconnecting = True self._log.debug('Closing current connection...') - self._connection.disconnect() + await self._connection.disconnect() - self._log.debug('Cancelling the send loop...') - self._send_loop_handle.cancel() - - self._log.debug('Cancelling the receive loop...') - self._recv_loop_handle.cancel() + await helpers._cancel( + self._log, + send_loop_handle=self._send_loop_handle, + recv_loop_handle=self._recv_loop_handle + ) self._reconnecting = False @@ -347,7 +326,7 @@ class MTProtoSender: else: self._log.error('Automatic reconnection failed {} time(s)' .format(attempt)) - self._disconnect(error=ConnectionError()) + await self._disconnect(error=ConnectionError()) def _start_reconnect(self): """Starts a reconnection in the background.""" @@ -356,7 +335,6 @@ class MTProtoSender: # Loops - @_cancellable async def _send_loop(self): """ This loop is responsible for popping items off the send @@ -402,7 +380,6 @@ class MTProtoSender: self._log.debug('Encrypted messages put in a queue to be sent') - @_cancellable async def _recv_loop(self): """ This loop is responsible for reading all incoming responses diff --git a/telethon_examples/gui.py b/telethon_examples/gui.py index 7c48142a..949d1eb9 100644 --- a/telethon_examples/gui.py +++ b/telethon_examples/gui.py @@ -365,7 +365,7 @@ async def main(loop, interval=0.05): if 'application has been destroyed' not in e.args[0]: raise finally: - app.cl.disconnect() + await app.cl.disconnect() if __name__ == "__main__":