mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-25 10:53:44 +03:00
Revert disconnect() to be async again (#1133)
It's the only way to properly clean all background tasks, which the library makes heavy use for in MTProto/Connection send and receive loops. Some parts of the code even relied on the fact that it was asynchronous (it used to return a future so you could await it and not be breaking). It's automatically syncified to reduce the damage of being a breaking change.
This commit is contained in:
parent
8f302bcdb0
commit
04ba2e1fc7
|
@ -440,7 +440,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
|
|||
self._state = types.updates.State(
|
||||
0, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
|
||||
|
||||
self.disconnect()
|
||||
await self.disconnect()
|
||||
self.session.delete()
|
||||
return True
|
||||
|
||||
|
@ -550,9 +550,9 @@ class AuthMethods(MessageParseMethods, UserMethods):
|
|||
return await self.start()
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.disconnect()
|
||||
self.disconnect() # It's also syncified, like start()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.disconnect()
|
||||
await self.disconnect()
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -281,7 +281,7 @@ class DownloadMethods(UserMethods):
|
|||
if exported:
|
||||
await self._return_exported_sender(sender)
|
||||
elif sender != self._sender:
|
||||
sender.disconnect()
|
||||
await sender.disconnect()
|
||||
if isinstance(file, str) or in_memory:
|
||||
f.close()
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import sys
|
|||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from .. import version, __name__ as __base_name__
|
||||
from .. import version, helpers, __name__ as __base_name__
|
||||
from ..crypto import rsa
|
||||
from ..extensions import markdown
|
||||
from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy
|
||||
|
@ -376,28 +376,29 @@ class TelegramBaseClient(abc.ABC):
|
|||
"""
|
||||
Disconnects from Telegram.
|
||||
|
||||
Returns a dummy completed future with ``None`` as a result so
|
||||
you can ``await`` this method just like every other method for
|
||||
consistency or compatibility.
|
||||
If the event loop is already running, this method returns a
|
||||
coroutine that you should await on your own code; otherwise
|
||||
the loop is ran until said coroutine completes.
|
||||
"""
|
||||
self._disconnect()
|
||||
if self._loop.is_running():
|
||||
return self._disconnect_coro()
|
||||
else:
|
||||
self._loop.run_until_complete(self._disconnect_coro())
|
||||
|
||||
async def _disconnect_coro(self):
|
||||
await self._disconnect()
|
||||
self.session.set_update_state(0, self._state)
|
||||
self.session.close()
|
||||
|
||||
result = self._loop.create_future()
|
||||
result.set_result(None)
|
||||
return result
|
||||
|
||||
def _disconnect(self):
|
||||
async def _disconnect(self):
|
||||
"""
|
||||
Disconnect only, without closing the session. Used in reconnections
|
||||
to different data centers, where we don't want to close the session
|
||||
file; user disconnects however should close it since it means that
|
||||
their job with the client is complete and we should clean it up all.
|
||||
"""
|
||||
self._sender.disconnect()
|
||||
if self._updates_handle:
|
||||
self._updates_handle.cancel()
|
||||
await self._sender.disconnect()
|
||||
await helpers._cancel(self._log, updates_handle=self._updates_handle)
|
||||
|
||||
async def _switch_dc(self, new_dc):
|
||||
"""
|
||||
|
@ -412,7 +413,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
self._sender.auth_key.key = None
|
||||
self.session.auth_key = None
|
||||
self.session.save()
|
||||
self._disconnect()
|
||||
await self._disconnect()
|
||||
return await self.connect()
|
||||
|
||||
def _auth_key_callback(self, auth_key):
|
||||
|
@ -515,7 +516,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
if not n:
|
||||
self._log[__name__].info(
|
||||
'Disconnecting borrowed sender for DC %d', dc_id)
|
||||
sender.disconnect()
|
||||
await sender.disconnect()
|
||||
|
||||
async def _get_cdn_client(self, cdn_redirect):
|
||||
"""Similar to ._borrow_exported_client, but for CDNs"""
|
||||
|
|
|
@ -16,7 +16,7 @@ class UpdateMethods(UserMethods):
|
|||
try:
|
||||
await self.disconnected
|
||||
except KeyboardInterrupt:
|
||||
self.disconnect()
|
||||
await self.disconnect()
|
||||
|
||||
def run_until_disconnected(self):
|
||||
"""
|
||||
|
@ -33,7 +33,7 @@ class UpdateMethods(UserMethods):
|
|||
try:
|
||||
return self.loop.run_until_complete(self.disconnected)
|
||||
except KeyboardInterrupt:
|
||||
self.disconnect()
|
||||
self.loop.run_until_complete(self.disconnect())
|
||||
|
||||
def on(self, event):
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Various helpers not related to the Telegram API itself"""
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
from hashlib import sha1, sha256
|
||||
|
@ -85,6 +86,22 @@ def retry_range(retries):
|
|||
yield 1 + attempt
|
||||
|
||||
|
||||
async def _cancel(log, **tasks):
|
||||
"""
|
||||
Helper to cancel one or more tasks gracefully, logging exceptions.
|
||||
"""
|
||||
for name, task in tasks.items():
|
||||
if not task:
|
||||
continue
|
||||
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
log.exception('Unhandled exception from %s after cancel', name)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Cryptographic related utils
|
||||
|
|
|
@ -4,6 +4,7 @@ import socket
|
|||
import ssl as ssl_mod
|
||||
|
||||
from ...errors import InvalidChecksumError
|
||||
from ... import helpers
|
||||
|
||||
|
||||
class Connection(abc.ABC):
|
||||
|
@ -92,18 +93,18 @@ class Connection(abc.ABC):
|
|||
self._send_task = self._loop.create_task(self._send_loop())
|
||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||
|
||||
def disconnect(self):
|
||||
async def disconnect(self):
|
||||
"""
|
||||
Disconnects from the server, and clears
|
||||
pending outgoing and incoming messages.
|
||||
"""
|
||||
self._connected = False
|
||||
|
||||
if self._send_task:
|
||||
self._send_task.cancel()
|
||||
|
||||
if self._recv_task:
|
||||
self._recv_task.cancel()
|
||||
await helpers._cancel(
|
||||
self._log,
|
||||
send_task=self._send_task,
|
||||
recv_task=self._recv_task
|
||||
)
|
||||
|
||||
if self._writer:
|
||||
self._writer.close()
|
||||
|
@ -148,7 +149,7 @@ class Connection(abc.ABC):
|
|||
else:
|
||||
self._log.exception('Unexpected exception in the send loop')
|
||||
|
||||
self.disconnect()
|
||||
await self.disconnect()
|
||||
|
||||
async def _recv_loop(self):
|
||||
"""
|
||||
|
@ -170,7 +171,7 @@ class Connection(abc.ABC):
|
|||
msg = 'Unexpected exception in the receive loop'
|
||||
self._log.exception(msg)
|
||||
|
||||
self.disconnect()
|
||||
await self.disconnect()
|
||||
|
||||
# Add a sentinel value to unstuck recv
|
||||
if self._recv_queue.empty():
|
||||
|
|
|
@ -111,7 +111,7 @@ class TcpMTProxy(ObfuscatedConnection):
|
|||
await asyncio.sleep(2)
|
||||
|
||||
if self._reader.at_eof():
|
||||
self.disconnect()
|
||||
await self.disconnect()
|
||||
raise ConnectionError(
|
||||
'Proxy closed the connection after sending initial payload')
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import functools
|
||||
|
||||
from . import authenticator
|
||||
from ..extensions.messagepacker import MessagePacker
|
||||
|
@ -8,7 +7,7 @@ from .mtprotoplainsender import MTProtoPlainSender
|
|||
from .requeststate import RequestState
|
||||
from .mtprotostate import MTProtoState
|
||||
from ..tl.tlobject import TLRequest
|
||||
from .. import utils
|
||||
from .. import helpers, utils
|
||||
from ..errors import (
|
||||
BadMessageError, InvalidBufferError, SecurityError,
|
||||
TypeNotFoundError, rpc_message_to_error
|
||||
|
@ -25,23 +24,6 @@ from ..crypto import AuthKey
|
|||
from ..helpers import retry_range
|
||||
|
||||
|
||||
def _cancellable(func):
|
||||
"""
|
||||
Silences `asyncio.CancelledError` for an entire function.
|
||||
|
||||
This way the function can be cancelled without the task ending
|
||||
with a exception, and without the function body requiring another
|
||||
indent level for the try/except.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
return wrapped
|
||||
|
||||
|
||||
class MTProtoSender:
|
||||
"""
|
||||
MTProto Mobile Protocol sender
|
||||
|
@ -143,12 +125,12 @@ class MTProtoSender:
|
|||
def is_connected(self):
|
||||
return self._user_connected
|
||||
|
||||
def disconnect(self):
|
||||
async def disconnect(self):
|
||||
"""
|
||||
Cleanly disconnects the instance from the network, cancels
|
||||
all pending requests, and closes the send and receive loops.
|
||||
"""
|
||||
self._disconnect()
|
||||
await self._disconnect()
|
||||
|
||||
def send(self, request, ordered=False):
|
||||
"""
|
||||
|
@ -251,7 +233,7 @@ class MTProtoSender:
|
|||
else:
|
||||
e = ConnectionError('auth_key generation failed {} time(s)'
|
||||
.format(attempt))
|
||||
self._disconnect(error=e)
|
||||
await self._disconnect(error=e)
|
||||
raise e
|
||||
|
||||
self._log.debug('Starting send loop')
|
||||
|
@ -268,12 +250,12 @@ class MTProtoSender:
|
|||
|
||||
self._log.info('Connection to %s complete!', self._connection)
|
||||
|
||||
def _disconnect(self, error=None):
|
||||
async def _disconnect(self, error=None):
|
||||
self._log.info('Disconnecting from %s...', self._connection)
|
||||
self._user_connected = False
|
||||
try:
|
||||
self._log.debug('Closing current connection...')
|
||||
self._connection.disconnect()
|
||||
await self._connection.disconnect()
|
||||
finally:
|
||||
self._log.debug('Cancelling {} pending message(s)...'
|
||||
.format(len(self._pending_state)))
|
||||
|
@ -286,14 +268,11 @@ class MTProtoSender:
|
|||
self._pending_state.clear()
|
||||
self._pending_ack.clear()
|
||||
self._last_ack = None
|
||||
|
||||
if self._send_loop_handle:
|
||||
self._log.debug('Cancelling the send loop...')
|
||||
self._send_loop_handle.cancel()
|
||||
|
||||
if self._recv_loop_handle:
|
||||
self._log.debug('Cancelling the receive loop...')
|
||||
self._recv_loop_handle.cancel()
|
||||
await helpers._cancel(
|
||||
self._log,
|
||||
send_loop_handle=self._send_loop_handle,
|
||||
recv_loop_handle=self._recv_loop_handle
|
||||
)
|
||||
|
||||
self._log.info('Disconnection from %s complete!', self._connection)
|
||||
if self._disconnected and not self._disconnected.done():
|
||||
|
@ -309,13 +288,13 @@ class MTProtoSender:
|
|||
self._reconnecting = True
|
||||
|
||||
self._log.debug('Closing current connection...')
|
||||
self._connection.disconnect()
|
||||
await self._connection.disconnect()
|
||||
|
||||
self._log.debug('Cancelling the send loop...')
|
||||
self._send_loop_handle.cancel()
|
||||
|
||||
self._log.debug('Cancelling the receive loop...')
|
||||
self._recv_loop_handle.cancel()
|
||||
await helpers._cancel(
|
||||
self._log,
|
||||
send_loop_handle=self._send_loop_handle,
|
||||
recv_loop_handle=self._recv_loop_handle
|
||||
)
|
||||
|
||||
self._reconnecting = False
|
||||
|
||||
|
@ -347,7 +326,7 @@ class MTProtoSender:
|
|||
else:
|
||||
self._log.error('Automatic reconnection failed {} time(s)'
|
||||
.format(attempt))
|
||||
self._disconnect(error=ConnectionError())
|
||||
await self._disconnect(error=ConnectionError())
|
||||
|
||||
def _start_reconnect(self):
|
||||
"""Starts a reconnection in the background."""
|
||||
|
@ -356,7 +335,6 @@ class MTProtoSender:
|
|||
|
||||
# Loops
|
||||
|
||||
@_cancellable
|
||||
async def _send_loop(self):
|
||||
"""
|
||||
This loop is responsible for popping items off the send
|
||||
|
@ -402,7 +380,6 @@ class MTProtoSender:
|
|||
|
||||
self._log.debug('Encrypted messages put in a queue to be sent')
|
||||
|
||||
@_cancellable
|
||||
async def _recv_loop(self):
|
||||
"""
|
||||
This loop is responsible for reading all incoming responses
|
||||
|
|
|
@ -365,7 +365,7 @@ async def main(loop, interval=0.05):
|
|||
if 'application has been destroyed' not in e.args[0]:
|
||||
raise
|
||||
finally:
|
||||
app.cl.disconnect()
|
||||
await app.cl.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue
Block a user