Use concurrent futures and threads

This commit is contained in:
Lonami Exo 2018-06-28 09:29:41 +02:00
parent 62c6565189
commit 268e43d5c3
10 changed files with 79 additions and 160 deletions

View File

@ -105,8 +105,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
max_attempts=max_attempts max_attempts=max_attempts
) )
return ( return (
coro if self.loop.is_running() coro
else self.loop.run_until_complete(coro)
) )
def _start( def _start(

View File

@ -1,13 +1,12 @@
import abc import abc
import asyncio
import inspect
import logging import logging
import platform import platform
import sys import sys
import threading
import time import time
from datetime import timedelta, datetime from datetime import timedelta, datetime
from .. import version from .. import version, syncio
from ..crypto import rsa from ..crypto import rsa
from ..extensions import markdown from ..extensions import markdown
from ..network import MTProtoSender, ConnectionTcpFull from ..network import MTProtoSender, ConnectionTcpFull
@ -154,7 +153,6 @@ class TelegramBaseClient(abc.ABC):
"Refer to telethon.rtfd.io for more information.") "Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6 self._use_ipv6 = use_ipv6
self._loop = loop or asyncio.get_event_loop()
# Determine what session object we have # Determine what session object we have
if isinstance(session, str) or session is None: if isinstance(session, str) or session is None:
@ -184,7 +182,7 @@ class TelegramBaseClient(abc.ABC):
if isinstance(connection, type): if isinstance(connection, type):
connection = connection( connection = connection(
proxy=proxy, timeout=timeout, loop=self._loop) proxy=proxy, timeout=timeout)
# Used on connection. Capture the variables in a lambda since # Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayerRequest. # exporting clients need to create this InvokeWithLayerRequest.
@ -205,7 +203,7 @@ class TelegramBaseClient(abc.ABC):
state = MTProtoState(self.session.auth_key) state = MTProtoState(self.session.auth_key)
self._connection = connection self._connection = connection
self._sender = MTProtoSender( self._sender = MTProtoSender(
state, connection, self._loop, state, connection,
retries=self._connection_retries, retries=self._connection_retries,
auto_reconnect=self._auto_reconnect, auto_reconnect=self._auto_reconnect,
update_callback=self._handle_update, update_callback=self._handle_update,
@ -235,7 +233,7 @@ class TelegramBaseClient(abc.ABC):
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
self._events_pending_resolve = [] self._events_pending_resolve = []
self._event_resolve_lock = asyncio.Lock() self._event_resolve_lock = threading.Lock()
# Default parse mode # Default parse mode
self._parse_mode = markdown self._parse_mode = markdown
@ -253,10 +251,6 @@ class TelegramBaseClient(abc.ABC):
# region Properties # region Properties
@property
def loop(self):
return self._loop
@property @property
def disconnected(self): def disconnected(self):
""" """
@ -279,7 +273,7 @@ class TelegramBaseClient(abc.ABC):
self._sender.send(self._init_with( self._sender.send(self._init_with(
functions.help.GetConfigRequest())) functions.help.GetConfigRequest()))
self._updates_handle = self._loop.create_task(self._update_loop()) self._updates_handle = syncio.create_task(self._update_loop)
def is_connected(self): def is_connected(self):
""" """
@ -307,22 +301,14 @@ class TelegramBaseClient(abc.ABC):
if getattr(self, '_sender', None): if getattr(self, '_sender', None):
self._sender.disconnect() self._sender.disconnect()
if getattr(self, '_updates_handle', None): if getattr(self, '_updates_handle', None):
self._updates_handle if threading.current_thread() != self._updates_handle:
self._updates_handle.join()
def __del__(self): def __del__(self):
if not self.is_connected() or self.loop.is_closed(): if not self.is_connected():
return return
# Python 3.5.2's ``asyncio`` mod seems to have a bug where it's not self.disconnect()
# able to close the pending tasks properly, and letting the script
# complete without calling disconnect causes the script to trigger
# 100% CPU load. Call disconnect to make sure it doesn't happen.
if not inspect.iscoroutinefunction(self.disconnect):
self.disconnect()
elif self._loop.is_running():
self._loop.create_task(self.disconnect())
else:
self._loop.run_until_complete(self.disconnect())
def _switch_dc(self, new_dc): def _switch_dc(self, new_dc):
""" """
@ -384,7 +370,7 @@ class TelegramBaseClient(abc.ABC):
# #
# If one were to do that, Telegram would reset the connection # If one were to do that, Telegram would reset the connection
# with no further clues. # with no further clues.
sender = MTProtoSender(state, self._connection.clone(), self._loop) sender = MTProtoSender(state, self._connection.clone())
sender.connect(dc.ip_address, dc.port) sender.connect(dc.ip_address, dc.port)
if not auth: if not auth:
__log__.info('Exporting authorization for data center %s', dc) __log__.info('Exporting authorization for data center %s', dc)

View File

@ -25,12 +25,12 @@ class UserMethods(TelegramBaseClient):
if isinstance(future, list): if isinstance(future, list):
results = [] results = []
for f in future: for f in future:
result = f result = f.result()
self.session.process_entities(result) self.session.process_entities(result)
results.append(result) results.append(result)
return results return results
else: else:
result = future result = future.result()
self.session.process_entities(result) self.session.process_entities(result)
return result return result
except (errors.ServerError, errors.RpcCallFailError) as e: except (errors.ServerError, errors.RpcCallFailError) as e:
@ -198,17 +198,14 @@ class UserMethods(TelegramBaseClient):
its job** and don't worry about getting the input entity first, but its job** and don't worry about getting the input entity first, but
if you're going to use an entity often, consider making the call: if you're going to use an entity often, consider making the call:
>>> import asyncio
>>> rc = asyncio.get_event_loop().run_until_complete
>>>
>>> from telethon import TelegramClient >>> from telethon import TelegramClient
>>> client = TelegramClient(...) >>> client = TelegramClient(...)
>>> # If you're going to use "username" often in your code >>> # If you're going to use "username" often in your code
>>> # (make a lot of calls), consider getting its input entity >>> # (make a lot of calls), consider getting its input entity
>>> # once, and then using the "user" everywhere instead. >>> # once, and then using the "user" everywhere instead.
>>> user = rc(client.get_input_entity('username')) >>> user = client.get_input_entity('username')
>>> # The same applies to IDs, chats or channels. >>> # The same applies to IDs, chats or channels.
>>> chat = rc(client.get_input_entity(-123456789)) >>> chat = client.get_input_entity(-123456789)
entity (`str` | `int` | :tl:`Peer` | :tl:`InputPeer`): entity (`str` | `int` | :tl:`Peer` | :tl:`InputPeer`):
If an username is given, **the library will use the cache**. This If an username is given, **the library will use the cache**. This

View File

@ -7,10 +7,10 @@ may be ``await``'ed before being able to return the exact byte count.
This class is also not concerned about disconnections or retries of This class is also not concerned about disconnections or retries of
any sort, nor any other kind of errors such as connecting twice. any sort, nor any other kind of errors such as connecting twice.
""" """
import asyncio
import errno import errno
import logging import logging
import socket import socket
import threading
from io import BytesIO from io import BytesIO
CONN_RESET_ERRNOS = { CONN_RESET_ERRNOS = {
@ -37,17 +37,17 @@ class TcpClient:
class SocketClosed(ConnectionError): class SocketClosed(ConnectionError):
pass pass
def __init__(self, *, loop, timeout, proxy=None): def __init__(self, *, timeout, proxy=None):
""" """
Initializes the TCP client. Initializes the TCP client.
:param proxy: the proxy to be used, if any. :param proxy: the proxy to be used, if any.
:param timeout: the timeout for connect, read and write operations. :param timeout: the timeout for connect, read and write operations.
""" """
self._loop = loop
self.proxy = proxy self.proxy = proxy
self._socket = None self._socket = None
self._closed = asyncio.Event(loop=self._loop)
self._closed = threading.Event()
self._closed.set() self._closed.set()
if isinstance(timeout, (int, float)): if isinstance(timeout, (int, float)):
@ -88,11 +88,8 @@ class TcpClient:
if self._socket is None: if self._socket is None:
self._socket = self._create_socket(mode, self.proxy) self._socket = self._create_socket(mode, self.proxy)
asyncio.wait_for( self._socket.settimeout(self.timeout)
self._loop.sock_connect(self._socket, address), self._socket.connect(address)
timeout=self.timeout,
loop=self._loop
)
self._closed.clear() self._closed.clear()
except OSError as e: except OSError as e:
if e.errno in CONN_RESET_ERRNOS: if e.errno in CONN_RESET_ERRNOS:
@ -107,7 +104,6 @@ class TcpClient:
def close(self): def close(self):
"""Closes the connection.""" """Closes the connection."""
fd = None
try: try:
if self._socket is not None: if self._socket is not None:
fd = self._socket.fileno() fd = self._socket.fileno()
@ -119,27 +115,6 @@ class TcpClient:
finally: finally:
self._socket = None self._socket = None
self._closed.set() self._closed.set()
if fd:
self._loop.remove_reader(fd)
def _wait_timeout_or_close(self, coro):
"""
Waits for the given coroutine to complete unless
the socket is closed or `self.timeout` expires.
"""
done, running = asyncio.wait(
[coro, self._closed.wait()],
timeout=self.timeout,
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
for r in running:
r.cancel()
if not self.is_connected:
raise self.SocketClosed()
if not done:
raise asyncio.TimeoutError()
return done.pop().result()
def write(self, data): def write(self, data):
""" """
@ -150,7 +125,7 @@ class TcpClient:
raise ConnectionResetError('Not connected') raise ConnectionResetError('Not connected')
try: try:
self._wait_timeout_or_close(self.sock_sendall(data)) self.sock_sendall(data)
except OSError as e: except OSError as e:
if e.errno in CONN_RESET_ERRNOS: if e.errno in CONN_RESET_ERRNOS:
raise ConnectionResetError() from e raise ConnectionResetError() from e
@ -171,10 +146,8 @@ class TcpClient:
bytes_left = size bytes_left = size
while bytes_left != 0: while bytes_left != 0:
try: try:
partial = self._wait_timeout_or_close( partial = self.sock_recv(bytes_left)
self.sock_recv(bytes_left) except socket.timeout:
)
except asyncio.TimeoutError:
if bytes_left < size: if bytes_left < size:
__log__.warning( __log__.warning(
'Timeout when partial %d/%d had been received', 'Timeout when partial %d/%d had been received',
@ -195,55 +168,8 @@ class TcpClient:
return buffer.getvalue() return buffer.getvalue()
# Due to recent https://github.com/python/cpython/pull/4386
# Credit to @andr-04 for his original implementation
def sock_recv(self, n): def sock_recv(self, n):
fut = self._loop.create_future() return self._socket.recv(n)
self._sock_recv(fut, None, n)
return fut
def _sock_recv(self, fut, registered_fd, n):
if registered_fd is not None:
self._loop.remove_reader(registered_fd)
if fut.cancelled() or self._socket is None:
return
try:
data = self._socket.recv(n)
except (BlockingIOError, InterruptedError):
fd = self._socket.fileno()
self._loop.add_reader(fd, self._sock_recv, fut, fd, n)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(data)
def sock_sendall(self, data): def sock_sendall(self, data):
fut = self._loop.create_future() return self._socket.sendall(data)
if data:
self._sock_sendall(fut, None, data)
else:
fut.set_result(None)
return fut
def _sock_sendall(self, fut, registered_fd, data):
if registered_fd:
self._loop.remove_writer(registered_fd)
if fut.cancelled() or self._socket is None:
return
try:
n = self._socket.send(data)
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
fut.set_exception(exc)
return
if n == len(data):
fut.set_result(None)
else:
if n:
data = data[n:]
fd = self._socket.fileno()
self._loop.add_writer(fd, self._sock_sendall, fut, fd, data)

View File

@ -20,15 +20,13 @@ class Connection(abc.ABC):
Subclasses should implement the actual protocol Subclasses should implement the actual protocol
being used when encoding/decoding messages. being used when encoding/decoding messages.
""" """
def __init__(self, *, loop, timeout, proxy=None): def __init__(self, *, timeout, proxy=None):
""" """
Initializes a new connection. Initializes a new connection.
:param loop: the event loop to be used.
:param timeout: timeout to be used for all operations. :param timeout: timeout to be used for all operations.
:param proxy: whether to use a proxy or not. :param proxy: whether to use a proxy or not.
""" """
self._loop = loop
self._proxy = proxy self._proxy = proxy
self._timeout = timeout self._timeout = timeout
@ -58,7 +56,6 @@ class Connection(abc.ABC):
def clone(self): def clone(self):
"""Creates a copy of this Connection.""" """Creates a copy of this Connection."""
return self.__class__( return self.__class__(
loop=self._loop,
proxy=self._proxy, proxy=self._proxy,
timeout=self._timeout timeout=self._timeout
) )

View File

@ -12,11 +12,11 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself. needs to calculate the CRC value of the packet itself.
""" """
def __init__(self, *, loop, timeout, proxy=None): def __init__(self, *, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy) super().__init__(timeout=timeout, proxy=proxy)
self._send_counter = 0 self._send_counter = 0
self.conn = TcpClient( self.conn = TcpClient(
timeout=self._timeout, loop=self._loop, proxy=self._proxy timeout=self._timeout, proxy=self._proxy
) )
self.read = self.conn.read self.read = self.conn.read
self.write = self.conn.write self.write = self.conn.write

View File

@ -11,8 +11,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern. AES-CTR mode so the packets are harder to discern.
""" """
def __init__(self, *, loop, timeout, proxy=None): def __init__(self, *, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy) super().__init__(timeout=timeout, proxy=proxy)
self._aes_encrypt, self._aes_decrypt = None, None self._aes_encrypt, self._aes_decrypt = None, None
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s)) self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d)) self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))

View File

@ -1,8 +1,11 @@
import asyncio import concurrent.futures
import logging import logging
import queue
import socket
import time
from . import MTProtoPlainSender, authenticator from . import MTProtoPlainSender, authenticator
from .. import utils from .. import syncio, utils
from ..errors import ( from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError, BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
rpc_message_to_error rpc_message_to_error
@ -40,12 +43,11 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other A new authorization key will be generated on connection if no other
key exists yet. key exists yet.
""" """
def __init__(self, state, connection, loop, *, def __init__(self, state, connection, *,
retries=5, auto_reconnect=True, update_callback=None, retries=5, auto_reconnect=True, update_callback=None,
auth_key_callback=None, auto_reconnect_callback=None): auth_key_callback=None, auto_reconnect_callback=None):
self.state = state self.state = state
self._connection = connection self._connection = connection
self._loop = loop
self._ip = None self._ip = None
self._port = None self._port = None
self._retries = retries self._retries = retries
@ -160,11 +162,11 @@ class MTProtoSender:
if self._send_loop_handle: if self._send_loop_handle:
__log__.debug('Cancelling the send loop...') __log__.debug('Cancelling the send loop...')
self._send_loop_handle.cancel() self._send_loop_handle.join()
if self._recv_loop_handle: if self._recv_loop_handle:
__log__.debug('Cancelling the receive loop...') __log__.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel() self._recv_loop_handle.join()
__log__.info('Disconnection from {} complete!'.format(self._ip)) __log__.info('Disconnection from {} complete!'.format(self._ip))
if self._disconnected: if self._disconnected:
@ -224,7 +226,7 @@ class MTProtoSender:
ends, either by user action or in the background. ends, either by user action or in the background.
""" """
if self._disconnected is not None: if self._disconnected is not None:
return asyncio.shield(self._disconnected, loop=self._loop) return self._disconnected
else: else:
raise ConnectionError('Sender was never connected') raise ConnectionError('Sender was never connected')
@ -241,7 +243,7 @@ class MTProtoSender:
try: try:
__log__.debug('Connection attempt {}...'.format(retry)) __log__.debug('Connection attempt {}...'.format(retry))
self._connection.connect(self._ip, self._port) self._connection.connect(self._ip, self._port)
except (asyncio.TimeoutError, OSError) as e: except (socket.timeout, OSError) as e:
__log__.warning('Attempt {} at connecting failed: {}: {}' __log__.warning('Attempt {} at connecting failed: {}: {}'
.format(retry, type(e).__name__, e)) .format(retry, type(e).__name__, e))
else: else:
@ -273,14 +275,14 @@ class MTProtoSender:
raise e raise e
__log__.debug('Starting send loop') __log__.debug('Starting send loop')
self._send_loop_handle = self._loop.create_task(self._send_loop()) self._send_loop_handle = syncio.create_task(self._send_loop)
__log__.debug('Starting receive loop') __log__.debug('Starting receive loop')
self._recv_loop_handle = self._loop.create_task(self._recv_loop()) self._recv_loop_handle = syncio.create_task(self._recv_loop)
# First connection or manual reconnection after a failure # First connection or manual reconnection after a failure
if self._disconnected is None or self._disconnected.done(): if self._disconnected is None or self._disconnected.done():
self._disconnected = asyncio.Future() self._disconnected = concurrent.futures.Future()
__log__.info('Connection to {} complete!'.format(self._ip)) __log__.info('Connection to {} complete!'.format(self._ip))
def _reconnect(self): def _reconnect(self):
@ -291,10 +293,10 @@ class MTProtoSender:
self._send_queue.put_nowait(_reconnect_sentinel) self._send_queue.put_nowait(_reconnect_sentinel)
__log__.debug('Awaiting for the send loop before reconnecting...') __log__.debug('Awaiting for the send loop before reconnecting...')
self._send_loop_handle self._send_loop_handle.join()
__log__.debug('Awaiting for the receive loop before reconnecting...') __log__.debug('Awaiting for the receive loop before reconnecting...')
self._recv_loop_handle self._recv_loop_handle.join()
__log__.debug('Closing current connection...') __log__.debug('Closing current connection...')
self._connection.close() self._connection.close()
@ -309,7 +311,7 @@ class MTProtoSender:
self._send_queue.put_nowait(m) self._send_queue.put_nowait(m)
if self._auto_reconnect_callback: if self._auto_reconnect_callback:
self._loop.create_task(self._auto_reconnect_callback()) syncio.create_task(self._auto_reconnect_callback)
break break
except ConnectionError: except ConnectionError:
@ -321,7 +323,7 @@ class MTProtoSender:
def _start_reconnect(self): def _start_reconnect(self):
"""Starts a reconnection in the background.""" """Starts a reconnection in the background."""
if self._user_connected: if self._user_connected:
self._loop.create_task(self._reconnect()) syncio.create_task(self._reconnect)
def _clean_containers(self, msg_ids): def _clean_containers(self, msg_ids):
""" """
@ -357,7 +359,11 @@ class MTProtoSender:
self._send_queue.put_nowait(self._last_ack) self._send_queue.put_nowait(self._last_ack)
self._pending_ack.clear() self._pending_ack.clear()
messages = self._send_queue.get() try:
messages = self._send_queue.get(timeout=1)
except queue.Empty:
continue
if messages == _reconnect_sentinel: if messages == _reconnect_sentinel:
if self._reconnecting: if self._reconnecting:
break break
@ -383,9 +389,9 @@ class MTProtoSender:
__log__.debug('Sending {} bytes...'.format(len(body))) __log__.debug('Sending {} bytes...'.format(len(body)))
self._connection.send(body) self._connection.send(body)
break break
except asyncio.TimeoutError: except socket.timeout:
continue continue
except asyncio.CancelledError: except concurrent.futures.CancelledError:
return return
except Exception as e: except Exception as e:
if isinstance(e, ConnectionError): if isinstance(e, ConnectionError):
@ -394,7 +400,7 @@ class MTProtoSender:
__log__.warning('OSError while sending %s', e) __log__.warning('OSError while sending %s', e)
else: else:
__log__.exception('Unhandled exception while receiving') __log__.exception('Unhandled exception while receiving')
asyncio.sleep(1) time.sleep(1)
self._start_reconnect() self._start_reconnect()
break break
@ -422,9 +428,9 @@ class MTProtoSender:
try: try:
__log__.debug('Receiving items from the network...') __log__.debug('Receiving items from the network...')
body = self._connection.recv() body = self._connection.recv()
except asyncio.TimeoutError: except socket.timeout:
continue continue
except asyncio.CancelledError: except concurrent.futures.CancelledError:
return return
except Exception as e: except Exception as e:
if isinstance(e, ConnectionError): if isinstance(e, ConnectionError):
@ -433,7 +439,7 @@ class MTProtoSender:
__log__.warning('OSError while receiving %s', e) __log__.warning('OSError while receiving %s', e)
else: else:
__log__.exception('Unhandled exception while receiving') __log__.exception('Unhandled exception while receiving')
asyncio.sleep(1) time.sleep(1)
self._start_reconnect() self._start_reconnect()
break break
@ -469,16 +475,16 @@ class MTProtoSender:
continue continue
except: except:
__log__.exception('Unhandled exception while unpacking') __log__.exception('Unhandled exception while unpacking')
asyncio.sleep(1) time.sleep(1)
else: else:
try: try:
self._process_message(message) self._process_message(message)
except asyncio.CancelledError: except concurrent.futures.CancelledError:
return return
except: except:
__log__.exception('Unhandled exception while ' __log__.exception('Unhandled exception while '
'processing %s', message) 'processing %s', message)
asyncio.sleep(1) time.sleep(1)
# Response Handlers # Response Handlers
@ -730,19 +736,19 @@ class MTProtoSender:
""" """
class _ContainerQueue(asyncio.Queue): class _ContainerQueue(queue.Queue):
""" """
An asyncio queue that's aware of `MessageContainer` instances. A queue.Queue that's aware of `MessageContainer` instances.
The `get` method returns either a single `TLMessage` or a list The `get` method returns either a single `TLMessage` or a list
of them that should be turned into a new `MessageContainer`. of them that should be turned into a new `MessageContainer`.
Instances of this class can be replaced with the simpler Instances of this class can be replaced with the simpler
``asyncio.Queue`` when needed for testing purposes, and ``queue.Queue`` when needed for testing purposes, and
a list won't be returned in said case. a list won't be returned in said case.
""" """
def get(self): def get(self, block=True, timeout=None):
result = super().get() result = super().get(block=block, timeout=timeout)
if self.empty() or result == _reconnect_sentinel or\ if self.empty() or result == _reconnect_sentinel or\
isinstance(result.obj, MessageContainer): isinstance(result.obj, MessageContainer):
return result return result

8
telethon/syncio.py Normal file
View File

@ -0,0 +1,8 @@
import threading
def create_task(method, *args, **kwargs):
thread = threading.Thread(target=method, daemon=True,
args=args, kwargs=kwargs)
thread.start()
return thread

View File

@ -1,4 +1,4 @@
import asyncio import concurrent.futures
import struct import struct
from .gzippacked import GzipPacked from .gzippacked import GzipPacked
@ -26,7 +26,7 @@ class TLMessage(TLObject):
self.seq_no = seq_no self.seq_no = seq_no
self.obj = obj self.obj = obj
self.container_msg_id = None self.container_msg_id = None
self.future = asyncio.Future() self.future = concurrent.futures.Future()
# After which message ID this one should run. We do this so # After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can # InvokeAfterMsgRequest is transparent to the user and we can