Use concurrent futures and threads

This commit is contained in:
Lonami Exo 2018-06-28 09:29:41 +02:00
parent 62c6565189
commit 268e43d5c3
10 changed files with 79 additions and 160 deletions

View File

@ -105,8 +105,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
max_attempts=max_attempts
)
return (
coro if self.loop.is_running()
else self.loop.run_until_complete(coro)
coro
)
def _start(

View File

@ -1,13 +1,12 @@
import abc
import asyncio
import inspect
import logging
import platform
import sys
import threading
import time
from datetime import timedelta, datetime
from .. import version
from .. import version, syncio
from ..crypto import rsa
from ..extensions import markdown
from ..network import MTProtoSender, ConnectionTcpFull
@ -154,7 +153,6 @@ class TelegramBaseClient(abc.ABC):
"Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6
self._loop = loop or asyncio.get_event_loop()
# Determine what session object we have
if isinstance(session, str) or session is None:
@ -184,7 +182,7 @@ class TelegramBaseClient(abc.ABC):
if isinstance(connection, type):
connection = connection(
proxy=proxy, timeout=timeout, loop=self._loop)
proxy=proxy, timeout=timeout)
# Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayerRequest.
@ -205,7 +203,7 @@ class TelegramBaseClient(abc.ABC):
state = MTProtoState(self.session.auth_key)
self._connection = connection
self._sender = MTProtoSender(
state, connection, self._loop,
state, connection,
retries=self._connection_retries,
auto_reconnect=self._auto_reconnect,
update_callback=self._handle_update,
@ -235,7 +233,7 @@ class TelegramBaseClient(abc.ABC):
# Some further state for subclasses
self._event_builders = []
self._events_pending_resolve = []
self._event_resolve_lock = asyncio.Lock()
self._event_resolve_lock = threading.Lock()
# Default parse mode
self._parse_mode = markdown
@ -253,10 +251,6 @@ class TelegramBaseClient(abc.ABC):
# region Properties
@property
def loop(self):
return self._loop
@property
def disconnected(self):
"""
@ -279,7 +273,7 @@ class TelegramBaseClient(abc.ABC):
self._sender.send(self._init_with(
functions.help.GetConfigRequest()))
self._updates_handle = self._loop.create_task(self._update_loop())
self._updates_handle = syncio.create_task(self._update_loop)
def is_connected(self):
"""
@ -307,22 +301,14 @@ class TelegramBaseClient(abc.ABC):
if getattr(self, '_sender', None):
self._sender.disconnect()
if getattr(self, '_updates_handle', None):
self._updates_handle
if threading.current_thread() != self._updates_handle:
self._updates_handle.join()
def __del__(self):
if not self.is_connected() or self.loop.is_closed():
if not self.is_connected():
return
# Python 3.5.2's ``asyncio`` mod seems to have a bug where it's not
# able to close the pending tasks properly, and letting the script
# complete without calling disconnect causes the script to trigger
# 100% CPU load. Call disconnect to make sure it doesn't happen.
if not inspect.iscoroutinefunction(self.disconnect):
self.disconnect()
elif self._loop.is_running():
self._loop.create_task(self.disconnect())
else:
self._loop.run_until_complete(self.disconnect())
self.disconnect()
def _switch_dc(self, new_dc):
"""
@ -384,7 +370,7 @@ class TelegramBaseClient(abc.ABC):
#
# If one were to do that, Telegram would reset the connection
# with no further clues.
sender = MTProtoSender(state, self._connection.clone(), self._loop)
sender = MTProtoSender(state, self._connection.clone())
sender.connect(dc.ip_address, dc.port)
if not auth:
__log__.info('Exporting authorization for data center %s', dc)

View File

@ -25,12 +25,12 @@ class UserMethods(TelegramBaseClient):
if isinstance(future, list):
results = []
for f in future:
result = f
result = f.result()
self.session.process_entities(result)
results.append(result)
return results
else:
result = future
result = future.result()
self.session.process_entities(result)
return result
except (errors.ServerError, errors.RpcCallFailError) as e:
@ -198,17 +198,14 @@ class UserMethods(TelegramBaseClient):
its job** and don't worry about getting the input entity first, but
if you're going to use an entity often, consider making the call:
>>> import asyncio
>>> rc = asyncio.get_event_loop().run_until_complete
>>>
>>> from telethon import TelegramClient
>>> client = TelegramClient(...)
>>> # If you're going to use "username" often in your code
>>> # (make a lot of calls), consider getting its input entity
>>> # once, and then using the "user" everywhere instead.
>>> user = rc(client.get_input_entity('username'))
>>> user = client.get_input_entity('username')
>>> # The same applies to IDs, chats or channels.
>>> chat = rc(client.get_input_entity(-123456789))
>>> chat = client.get_input_entity(-123456789)
entity (`str` | `int` | :tl:`Peer` | :tl:`InputPeer`):
If an username is given, **the library will use the cache**. This

View File

@ -7,10 +7,10 @@ may be ``await``'ed before being able to return the exact byte count.
This class is also not concerned about disconnections or retries of
any sort, nor any other kind of errors such as connecting twice.
"""
import asyncio
import errno
import logging
import socket
import threading
from io import BytesIO
CONN_RESET_ERRNOS = {
@ -37,17 +37,17 @@ class TcpClient:
class SocketClosed(ConnectionError):
pass
def __init__(self, *, loop, timeout, proxy=None):
def __init__(self, *, timeout, proxy=None):
"""
Initializes the TCP client.
:param proxy: the proxy to be used, if any.
:param timeout: the timeout for connect, read and write operations.
"""
self._loop = loop
self.proxy = proxy
self._socket = None
self._closed = asyncio.Event(loop=self._loop)
self._closed = threading.Event()
self._closed.set()
if isinstance(timeout, (int, float)):
@ -88,11 +88,8 @@ class TcpClient:
if self._socket is None:
self._socket = self._create_socket(mode, self.proxy)
asyncio.wait_for(
self._loop.sock_connect(self._socket, address),
timeout=self.timeout,
loop=self._loop
)
self._socket.settimeout(self.timeout)
self._socket.connect(address)
self._closed.clear()
except OSError as e:
if e.errno in CONN_RESET_ERRNOS:
@ -107,7 +104,6 @@ class TcpClient:
def close(self):
"""Closes the connection."""
fd = None
try:
if self._socket is not None:
fd = self._socket.fileno()
@ -119,27 +115,6 @@ class TcpClient:
finally:
self._socket = None
self._closed.set()
if fd:
self._loop.remove_reader(fd)
def _wait_timeout_or_close(self, coro):
"""
Waits for the given coroutine to complete unless
the socket is closed or `self.timeout` expires.
"""
done, running = asyncio.wait(
[coro, self._closed.wait()],
timeout=self.timeout,
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
for r in running:
r.cancel()
if not self.is_connected:
raise self.SocketClosed()
if not done:
raise asyncio.TimeoutError()
return done.pop().result()
def write(self, data):
"""
@ -150,7 +125,7 @@ class TcpClient:
raise ConnectionResetError('Not connected')
try:
self._wait_timeout_or_close(self.sock_sendall(data))
self.sock_sendall(data)
except OSError as e:
if e.errno in CONN_RESET_ERRNOS:
raise ConnectionResetError() from e
@ -171,10 +146,8 @@ class TcpClient:
bytes_left = size
while bytes_left != 0:
try:
partial = self._wait_timeout_or_close(
self.sock_recv(bytes_left)
)
except asyncio.TimeoutError:
partial = self.sock_recv(bytes_left)
except socket.timeout:
if bytes_left < size:
__log__.warning(
'Timeout when partial %d/%d had been received',
@ -195,55 +168,8 @@ class TcpClient:
return buffer.getvalue()
# Due to recent https://github.com/python/cpython/pull/4386
# Credit to @andr-04 for his original implementation
def sock_recv(self, n):
fut = self._loop.create_future()
self._sock_recv(fut, None, n)
return fut
def _sock_recv(self, fut, registered_fd, n):
if registered_fd is not None:
self._loop.remove_reader(registered_fd)
if fut.cancelled() or self._socket is None:
return
try:
data = self._socket.recv(n)
except (BlockingIOError, InterruptedError):
fd = self._socket.fileno()
self._loop.add_reader(fd, self._sock_recv, fut, fd, n)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(data)
return self._socket.recv(n)
def sock_sendall(self, data):
fut = self._loop.create_future()
if data:
self._sock_sendall(fut, None, data)
else:
fut.set_result(None)
return fut
def _sock_sendall(self, fut, registered_fd, data):
if registered_fd:
self._loop.remove_writer(registered_fd)
if fut.cancelled() or self._socket is None:
return
try:
n = self._socket.send(data)
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
fut.set_exception(exc)
return
if n == len(data):
fut.set_result(None)
else:
if n:
data = data[n:]
fd = self._socket.fileno()
self._loop.add_writer(fd, self._sock_sendall, fut, fd, data)
return self._socket.sendall(data)

View File

@ -20,15 +20,13 @@ class Connection(abc.ABC):
Subclasses should implement the actual protocol
being used when encoding/decoding messages.
"""
def __init__(self, *, loop, timeout, proxy=None):
def __init__(self, *, timeout, proxy=None):
"""
Initializes a new connection.
:param loop: the event loop to be used.
:param timeout: timeout to be used for all operations.
:param proxy: whether to use a proxy or not.
"""
self._loop = loop
self._proxy = proxy
self._timeout = timeout
@ -58,7 +56,6 @@ class Connection(abc.ABC):
def clone(self):
"""Creates a copy of this Connection."""
return self.__class__(
loop=self._loop,
proxy=self._proxy,
timeout=self._timeout
)

View File

@ -12,11 +12,11 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
def __init__(self, *, timeout, proxy=None):
super().__init__(timeout=timeout, proxy=proxy)
self._send_counter = 0
self.conn = TcpClient(
timeout=self._timeout, loop=self._loop, proxy=self._proxy
timeout=self._timeout, proxy=self._proxy
)
self.read = self.conn.read
self.write = self.conn.write

View File

@ -11,8 +11,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern.
"""
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
def __init__(self, *, timeout, proxy=None):
super().__init__(timeout=timeout, proxy=proxy)
self._aes_encrypt, self._aes_decrypt = None, None
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))

View File

@ -1,8 +1,11 @@
import asyncio
import concurrent.futures
import logging
import queue
import socket
import time
from . import MTProtoPlainSender, authenticator
from .. import utils
from .. import syncio, utils
from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
rpc_message_to_error
@ -40,12 +43,11 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other
key exists yet.
"""
def __init__(self, state, connection, loop, *,
def __init__(self, state, connection, *,
retries=5, auto_reconnect=True, update_callback=None,
auth_key_callback=None, auto_reconnect_callback=None):
self.state = state
self._connection = connection
self._loop = loop
self._ip = None
self._port = None
self._retries = retries
@ -160,11 +162,11 @@ class MTProtoSender:
if self._send_loop_handle:
__log__.debug('Cancelling the send loop...')
self._send_loop_handle.cancel()
self._send_loop_handle.join()
if self._recv_loop_handle:
__log__.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel()
self._recv_loop_handle.join()
__log__.info('Disconnection from {} complete!'.format(self._ip))
if self._disconnected:
@ -224,7 +226,7 @@ class MTProtoSender:
ends, either by user action or in the background.
"""
if self._disconnected is not None:
return asyncio.shield(self._disconnected, loop=self._loop)
return self._disconnected
else:
raise ConnectionError('Sender was never connected')
@ -241,7 +243,7 @@ class MTProtoSender:
try:
__log__.debug('Connection attempt {}...'.format(retry))
self._connection.connect(self._ip, self._port)
except (asyncio.TimeoutError, OSError) as e:
except (socket.timeout, OSError) as e:
__log__.warning('Attempt {} at connecting failed: {}: {}'
.format(retry, type(e).__name__, e))
else:
@ -273,14 +275,14 @@ class MTProtoSender:
raise e
__log__.debug('Starting send loop')
self._send_loop_handle = self._loop.create_task(self._send_loop())
self._send_loop_handle = syncio.create_task(self._send_loop)
__log__.debug('Starting receive loop')
self._recv_loop_handle = self._loop.create_task(self._recv_loop())
self._recv_loop_handle = syncio.create_task(self._recv_loop)
# First connection or manual reconnection after a failure
if self._disconnected is None or self._disconnected.done():
self._disconnected = asyncio.Future()
self._disconnected = concurrent.futures.Future()
__log__.info('Connection to {} complete!'.format(self._ip))
def _reconnect(self):
@ -291,10 +293,10 @@ class MTProtoSender:
self._send_queue.put_nowait(_reconnect_sentinel)
__log__.debug('Awaiting for the send loop before reconnecting...')
self._send_loop_handle
self._send_loop_handle.join()
__log__.debug('Awaiting for the receive loop before reconnecting...')
self._recv_loop_handle
self._recv_loop_handle.join()
__log__.debug('Closing current connection...')
self._connection.close()
@ -309,7 +311,7 @@ class MTProtoSender:
self._send_queue.put_nowait(m)
if self._auto_reconnect_callback:
self._loop.create_task(self._auto_reconnect_callback())
syncio.create_task(self._auto_reconnect_callback)
break
except ConnectionError:
@ -321,7 +323,7 @@ class MTProtoSender:
def _start_reconnect(self):
"""Starts a reconnection in the background."""
if self._user_connected:
self._loop.create_task(self._reconnect())
syncio.create_task(self._reconnect)
def _clean_containers(self, msg_ids):
"""
@ -357,7 +359,11 @@ class MTProtoSender:
self._send_queue.put_nowait(self._last_ack)
self._pending_ack.clear()
messages = self._send_queue.get()
try:
messages = self._send_queue.get(timeout=1)
except queue.Empty:
continue
if messages == _reconnect_sentinel:
if self._reconnecting:
break
@ -383,9 +389,9 @@ class MTProtoSender:
__log__.debug('Sending {} bytes...'.format(len(body)))
self._connection.send(body)
break
except asyncio.TimeoutError:
except socket.timeout:
continue
except asyncio.CancelledError:
except concurrent.futures.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
@ -394,7 +400,7 @@ class MTProtoSender:
__log__.warning('OSError while sending %s', e)
else:
__log__.exception('Unhandled exception while receiving')
asyncio.sleep(1)
time.sleep(1)
self._start_reconnect()
break
@ -422,9 +428,9 @@ class MTProtoSender:
try:
__log__.debug('Receiving items from the network...')
body = self._connection.recv()
except asyncio.TimeoutError:
except socket.timeout:
continue
except asyncio.CancelledError:
except concurrent.futures.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
@ -433,7 +439,7 @@ class MTProtoSender:
__log__.warning('OSError while receiving %s', e)
else:
__log__.exception('Unhandled exception while receiving')
asyncio.sleep(1)
time.sleep(1)
self._start_reconnect()
break
@ -469,16 +475,16 @@ class MTProtoSender:
continue
except:
__log__.exception('Unhandled exception while unpacking')
asyncio.sleep(1)
time.sleep(1)
else:
try:
self._process_message(message)
except asyncio.CancelledError:
except concurrent.futures.CancelledError:
return
except:
__log__.exception('Unhandled exception while '
'processing %s', message)
asyncio.sleep(1)
time.sleep(1)
# Response Handlers
@ -730,19 +736,19 @@ class MTProtoSender:
"""
class _ContainerQueue(asyncio.Queue):
class _ContainerQueue(queue.Queue):
"""
An asyncio queue that's aware of `MessageContainer` instances.
A queue.Queue that's aware of `MessageContainer` instances.
The `get` method returns either a single `TLMessage` or a list
of them that should be turned into a new `MessageContainer`.
Instances of this class can be replaced with the simpler
``asyncio.Queue`` when needed for testing purposes, and
``queue.Queue`` when needed for testing purposes, and
a list won't be returned in said case.
"""
def get(self):
result = super().get()
def get(self, block=True, timeout=None):
result = super().get(block=block, timeout=timeout)
if self.empty() or result == _reconnect_sentinel or\
isinstance(result.obj, MessageContainer):
return result

8
telethon/syncio.py Normal file
View File

@ -0,0 +1,8 @@
import threading
def create_task(method, *args, **kwargs):
thread = threading.Thread(target=method, daemon=True,
args=args, kwargs=kwargs)
thread.start()
return thread

View File

@ -1,4 +1,4 @@
import asyncio
import concurrent.futures
import struct
from .gzippacked import GzipPacked
@ -26,7 +26,7 @@ class TLMessage(TLObject):
self.seq_no = seq_no
self.obj = obj
self.container_msg_id = None
self.future = asyncio.Future()
self.future = concurrent.futures.Future()
# After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can