mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-10 19:46:36 +03:00
Revert disconnect() to be async again (#1133)
It's the only way to properly clean all background tasks, which the library makes heavy use for in MTProto/Connection send and receive loops. Some parts of the code even relied on the fact that it was asynchronous (it used to return a future so you could await it and not be breaking). It's automatically syncified to reduce the damage of being a breaking change.
This commit is contained in:
parent
8f302bcdb0
commit
04ba2e1fc7
|
@ -440,7 +440,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
|
||||||
self._state = types.updates.State(
|
self._state = types.updates.State(
|
||||||
0, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
|
0, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
|
||||||
|
|
||||||
self.disconnect()
|
await self.disconnect()
|
||||||
self.session.delete()
|
self.session.delete()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -550,9 +550,9 @@ class AuthMethods(MessageParseMethods, UserMethods):
|
||||||
return await self.start()
|
return await self.start()
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
self.disconnect()
|
self.disconnect() # It's also syncified, like start()
|
||||||
|
|
||||||
async def __aexit__(self, *args):
|
async def __aexit__(self, *args):
|
||||||
self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -281,7 +281,7 @@ class DownloadMethods(UserMethods):
|
||||||
if exported:
|
if exported:
|
||||||
await self._return_exported_sender(sender)
|
await self._return_exported_sender(sender)
|
||||||
elif sender != self._sender:
|
elif sender != self._sender:
|
||||||
sender.disconnect()
|
await sender.disconnect()
|
||||||
if isinstance(file, str) or in_memory:
|
if isinstance(file, str) or in_memory:
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import sys
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from .. import version, __name__ as __base_name__
|
from .. import version, helpers, __name__ as __base_name__
|
||||||
from ..crypto import rsa
|
from ..crypto import rsa
|
||||||
from ..extensions import markdown
|
from ..extensions import markdown
|
||||||
from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy
|
from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy
|
||||||
|
@ -376,28 +376,29 @@ class TelegramBaseClient(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Disconnects from Telegram.
|
Disconnects from Telegram.
|
||||||
|
|
||||||
Returns a dummy completed future with ``None`` as a result so
|
If the event loop is already running, this method returns a
|
||||||
you can ``await`` this method just like every other method for
|
coroutine that you should await on your own code; otherwise
|
||||||
consistency or compatibility.
|
the loop is ran until said coroutine completes.
|
||||||
"""
|
"""
|
||||||
self._disconnect()
|
if self._loop.is_running():
|
||||||
|
return self._disconnect_coro()
|
||||||
|
else:
|
||||||
|
self._loop.run_until_complete(self._disconnect_coro())
|
||||||
|
|
||||||
|
async def _disconnect_coro(self):
|
||||||
|
await self._disconnect()
|
||||||
self.session.set_update_state(0, self._state)
|
self.session.set_update_state(0, self._state)
|
||||||
self.session.close()
|
self.session.close()
|
||||||
|
|
||||||
result = self._loop.create_future()
|
async def _disconnect(self):
|
||||||
result.set_result(None)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _disconnect(self):
|
|
||||||
"""
|
"""
|
||||||
Disconnect only, without closing the session. Used in reconnections
|
Disconnect only, without closing the session. Used in reconnections
|
||||||
to different data centers, where we don't want to close the session
|
to different data centers, where we don't want to close the session
|
||||||
file; user disconnects however should close it since it means that
|
file; user disconnects however should close it since it means that
|
||||||
their job with the client is complete and we should clean it up all.
|
their job with the client is complete and we should clean it up all.
|
||||||
"""
|
"""
|
||||||
self._sender.disconnect()
|
await self._sender.disconnect()
|
||||||
if self._updates_handle:
|
await helpers._cancel(self._log, updates_handle=self._updates_handle)
|
||||||
self._updates_handle.cancel()
|
|
||||||
|
|
||||||
async def _switch_dc(self, new_dc):
|
async def _switch_dc(self, new_dc):
|
||||||
"""
|
"""
|
||||||
|
@ -412,7 +413,7 @@ class TelegramBaseClient(abc.ABC):
|
||||||
self._sender.auth_key.key = None
|
self._sender.auth_key.key = None
|
||||||
self.session.auth_key = None
|
self.session.auth_key = None
|
||||||
self.session.save()
|
self.session.save()
|
||||||
self._disconnect()
|
await self._disconnect()
|
||||||
return await self.connect()
|
return await self.connect()
|
||||||
|
|
||||||
def _auth_key_callback(self, auth_key):
|
def _auth_key_callback(self, auth_key):
|
||||||
|
@ -515,7 +516,7 @@ class TelegramBaseClient(abc.ABC):
|
||||||
if not n:
|
if not n:
|
||||||
self._log[__name__].info(
|
self._log[__name__].info(
|
||||||
'Disconnecting borrowed sender for DC %d', dc_id)
|
'Disconnecting borrowed sender for DC %d', dc_id)
|
||||||
sender.disconnect()
|
await sender.disconnect()
|
||||||
|
|
||||||
async def _get_cdn_client(self, cdn_redirect):
|
async def _get_cdn_client(self, cdn_redirect):
|
||||||
"""Similar to ._borrow_exported_client, but for CDNs"""
|
"""Similar to ._borrow_exported_client, but for CDNs"""
|
||||||
|
|
|
@ -16,7 +16,7 @@ class UpdateMethods(UserMethods):
|
||||||
try:
|
try:
|
||||||
await self.disconnected
|
await self.disconnected
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
def run_until_disconnected(self):
|
def run_until_disconnected(self):
|
||||||
"""
|
"""
|
||||||
|
@ -33,7 +33,7 @@ class UpdateMethods(UserMethods):
|
||||||
try:
|
try:
|
||||||
return self.loop.run_until_complete(self.disconnected)
|
return self.loop.run_until_complete(self.disconnected)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.disconnect()
|
self.loop.run_until_complete(self.disconnect())
|
||||||
|
|
||||||
def on(self, event):
|
def on(self, event):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Various helpers not related to the Telegram API itself"""
|
"""Various helpers not related to the Telegram API itself"""
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from hashlib import sha1, sha256
|
from hashlib import sha1, sha256
|
||||||
|
@ -85,6 +86,22 @@ def retry_range(retries):
|
||||||
yield 1 + attempt
|
yield 1 + attempt
|
||||||
|
|
||||||
|
|
||||||
|
async def _cancel(log, **tasks):
|
||||||
|
"""
|
||||||
|
Helper to cancel one or more tasks gracefully, logging exceptions.
|
||||||
|
"""
|
||||||
|
for name, task in tasks.items():
|
||||||
|
if not task:
|
||||||
|
continue
|
||||||
|
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
log.exception('Unhandled exception from %s after cancel', name)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Cryptographic related utils
|
# region Cryptographic related utils
|
||||||
|
|
|
@ -4,6 +4,7 @@ import socket
|
||||||
import ssl as ssl_mod
|
import ssl as ssl_mod
|
||||||
|
|
||||||
from ...errors import InvalidChecksumError
|
from ...errors import InvalidChecksumError
|
||||||
|
from ... import helpers
|
||||||
|
|
||||||
|
|
||||||
class Connection(abc.ABC):
|
class Connection(abc.ABC):
|
||||||
|
@ -92,18 +93,18 @@ class Connection(abc.ABC):
|
||||||
self._send_task = self._loop.create_task(self._send_loop())
|
self._send_task = self._loop.create_task(self._send_loop())
|
||||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||||
|
|
||||||
def disconnect(self):
|
async def disconnect(self):
|
||||||
"""
|
"""
|
||||||
Disconnects from the server, and clears
|
Disconnects from the server, and clears
|
||||||
pending outgoing and incoming messages.
|
pending outgoing and incoming messages.
|
||||||
"""
|
"""
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
if self._send_task:
|
await helpers._cancel(
|
||||||
self._send_task.cancel()
|
self._log,
|
||||||
|
send_task=self._send_task,
|
||||||
if self._recv_task:
|
recv_task=self._recv_task
|
||||||
self._recv_task.cancel()
|
)
|
||||||
|
|
||||||
if self._writer:
|
if self._writer:
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
|
@ -148,7 +149,7 @@ class Connection(abc.ABC):
|
||||||
else:
|
else:
|
||||||
self._log.exception('Unexpected exception in the send loop')
|
self._log.exception('Unexpected exception in the send loop')
|
||||||
|
|
||||||
self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
async def _recv_loop(self):
|
async def _recv_loop(self):
|
||||||
"""
|
"""
|
||||||
|
@ -170,7 +171,7 @@ class Connection(abc.ABC):
|
||||||
msg = 'Unexpected exception in the receive loop'
|
msg = 'Unexpected exception in the receive loop'
|
||||||
self._log.exception(msg)
|
self._log.exception(msg)
|
||||||
|
|
||||||
self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
# Add a sentinel value to unstuck recv
|
# Add a sentinel value to unstuck recv
|
||||||
if self._recv_queue.empty():
|
if self._recv_queue.empty():
|
||||||
|
|
|
@ -111,7 +111,7 @@ class TcpMTProxy(ObfuscatedConnection):
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
if self._reader.at_eof():
|
if self._reader.at_eof():
|
||||||
self.disconnect()
|
await self.disconnect()
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
'Proxy closed the connection after sending initial payload')
|
'Proxy closed the connection after sending initial payload')
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import functools
|
|
||||||
|
|
||||||
from . import authenticator
|
from . import authenticator
|
||||||
from ..extensions.messagepacker import MessagePacker
|
from ..extensions.messagepacker import MessagePacker
|
||||||
|
@ -8,7 +7,7 @@ from .mtprotoplainsender import MTProtoPlainSender
|
||||||
from .requeststate import RequestState
|
from .requeststate import RequestState
|
||||||
from .mtprotostate import MTProtoState
|
from .mtprotostate import MTProtoState
|
||||||
from ..tl.tlobject import TLRequest
|
from ..tl.tlobject import TLRequest
|
||||||
from .. import utils
|
from .. import helpers, utils
|
||||||
from ..errors import (
|
from ..errors import (
|
||||||
BadMessageError, InvalidBufferError, SecurityError,
|
BadMessageError, InvalidBufferError, SecurityError,
|
||||||
TypeNotFoundError, rpc_message_to_error
|
TypeNotFoundError, rpc_message_to_error
|
||||||
|
@ -25,23 +24,6 @@ from ..crypto import AuthKey
|
||||||
from ..helpers import retry_range
|
from ..helpers import retry_range
|
||||||
|
|
||||||
|
|
||||||
def _cancellable(func):
|
|
||||||
"""
|
|
||||||
Silences `asyncio.CancelledError` for an entire function.
|
|
||||||
|
|
||||||
This way the function can be cancelled without the task ending
|
|
||||||
with a exception, and without the function body requiring another
|
|
||||||
indent level for the try/except.
|
|
||||||
"""
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapped(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
class MTProtoSender:
|
class MTProtoSender:
|
||||||
"""
|
"""
|
||||||
MTProto Mobile Protocol sender
|
MTProto Mobile Protocol sender
|
||||||
|
@ -143,12 +125,12 @@ class MTProtoSender:
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self._user_connected
|
return self._user_connected
|
||||||
|
|
||||||
def disconnect(self):
|
async def disconnect(self):
|
||||||
"""
|
"""
|
||||||
Cleanly disconnects the instance from the network, cancels
|
Cleanly disconnects the instance from the network, cancels
|
||||||
all pending requests, and closes the send and receive loops.
|
all pending requests, and closes the send and receive loops.
|
||||||
"""
|
"""
|
||||||
self._disconnect()
|
await self._disconnect()
|
||||||
|
|
||||||
def send(self, request, ordered=False):
|
def send(self, request, ordered=False):
|
||||||
"""
|
"""
|
||||||
|
@ -251,7 +233,7 @@ class MTProtoSender:
|
||||||
else:
|
else:
|
||||||
e = ConnectionError('auth_key generation failed {} time(s)'
|
e = ConnectionError('auth_key generation failed {} time(s)'
|
||||||
.format(attempt))
|
.format(attempt))
|
||||||
self._disconnect(error=e)
|
await self._disconnect(error=e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
self._log.debug('Starting send loop')
|
self._log.debug('Starting send loop')
|
||||||
|
@ -268,12 +250,12 @@ class MTProtoSender:
|
||||||
|
|
||||||
self._log.info('Connection to %s complete!', self._connection)
|
self._log.info('Connection to %s complete!', self._connection)
|
||||||
|
|
||||||
def _disconnect(self, error=None):
|
async def _disconnect(self, error=None):
|
||||||
self._log.info('Disconnecting from %s...', self._connection)
|
self._log.info('Disconnecting from %s...', self._connection)
|
||||||
self._user_connected = False
|
self._user_connected = False
|
||||||
try:
|
try:
|
||||||
self._log.debug('Closing current connection...')
|
self._log.debug('Closing current connection...')
|
||||||
self._connection.disconnect()
|
await self._connection.disconnect()
|
||||||
finally:
|
finally:
|
||||||
self._log.debug('Cancelling {} pending message(s)...'
|
self._log.debug('Cancelling {} pending message(s)...'
|
||||||
.format(len(self._pending_state)))
|
.format(len(self._pending_state)))
|
||||||
|
@ -286,14 +268,11 @@ class MTProtoSender:
|
||||||
self._pending_state.clear()
|
self._pending_state.clear()
|
||||||
self._pending_ack.clear()
|
self._pending_ack.clear()
|
||||||
self._last_ack = None
|
self._last_ack = None
|
||||||
|
await helpers._cancel(
|
||||||
if self._send_loop_handle:
|
self._log,
|
||||||
self._log.debug('Cancelling the send loop...')
|
send_loop_handle=self._send_loop_handle,
|
||||||
self._send_loop_handle.cancel()
|
recv_loop_handle=self._recv_loop_handle
|
||||||
|
)
|
||||||
if self._recv_loop_handle:
|
|
||||||
self._log.debug('Cancelling the receive loop...')
|
|
||||||
self._recv_loop_handle.cancel()
|
|
||||||
|
|
||||||
self._log.info('Disconnection from %s complete!', self._connection)
|
self._log.info('Disconnection from %s complete!', self._connection)
|
||||||
if self._disconnected and not self._disconnected.done():
|
if self._disconnected and not self._disconnected.done():
|
||||||
|
@ -309,13 +288,13 @@ class MTProtoSender:
|
||||||
self._reconnecting = True
|
self._reconnecting = True
|
||||||
|
|
||||||
self._log.debug('Closing current connection...')
|
self._log.debug('Closing current connection...')
|
||||||
self._connection.disconnect()
|
await self._connection.disconnect()
|
||||||
|
|
||||||
self._log.debug('Cancelling the send loop...')
|
await helpers._cancel(
|
||||||
self._send_loop_handle.cancel()
|
self._log,
|
||||||
|
send_loop_handle=self._send_loop_handle,
|
||||||
self._log.debug('Cancelling the receive loop...')
|
recv_loop_handle=self._recv_loop_handle
|
||||||
self._recv_loop_handle.cancel()
|
)
|
||||||
|
|
||||||
self._reconnecting = False
|
self._reconnecting = False
|
||||||
|
|
||||||
|
@ -347,7 +326,7 @@ class MTProtoSender:
|
||||||
else:
|
else:
|
||||||
self._log.error('Automatic reconnection failed {} time(s)'
|
self._log.error('Automatic reconnection failed {} time(s)'
|
||||||
.format(attempt))
|
.format(attempt))
|
||||||
self._disconnect(error=ConnectionError())
|
await self._disconnect(error=ConnectionError())
|
||||||
|
|
||||||
def _start_reconnect(self):
|
def _start_reconnect(self):
|
||||||
"""Starts a reconnection in the background."""
|
"""Starts a reconnection in the background."""
|
||||||
|
@ -356,7 +335,6 @@ class MTProtoSender:
|
||||||
|
|
||||||
# Loops
|
# Loops
|
||||||
|
|
||||||
@_cancellable
|
|
||||||
async def _send_loop(self):
|
async def _send_loop(self):
|
||||||
"""
|
"""
|
||||||
This loop is responsible for popping items off the send
|
This loop is responsible for popping items off the send
|
||||||
|
@ -402,7 +380,6 @@ class MTProtoSender:
|
||||||
|
|
||||||
self._log.debug('Encrypted messages put in a queue to be sent')
|
self._log.debug('Encrypted messages put in a queue to be sent')
|
||||||
|
|
||||||
@_cancellable
|
|
||||||
async def _recv_loop(self):
|
async def _recv_loop(self):
|
||||||
"""
|
"""
|
||||||
This loop is responsible for reading all incoming responses
|
This loop is responsible for reading all incoming responses
|
||||||
|
|
|
@ -365,7 +365,7 @@ async def main(loop, interval=0.05):
|
||||||
if 'application has been destroyed' not in e.args[0]:
|
if 'application has been destroyed' not in e.args[0]:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
app.cl.disconnect()
|
await app.cl.disconnect()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user