Merge branch 'constant_read'

Since the secondary thread for constant read is not part of the
TelegramClient anymore, there is no need to restart it. It will
be run when connecting again.
This commit is contained in:
Lonami Exo 2017-09-02 21:51:11 +02:00
commit 69d182815f
6 changed files with 178 additions and 337 deletions

View File

@ -10,14 +10,15 @@ from ..errors import ReadCancelledError
class TcpClient:
def __init__(self, proxy=None):
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
self._proxy = proxy
self._socket = None
# Support for multi-threading advantages and safety
self.cancelled = Event() # Has the read operation been cancelled?
self.delay = 0.1 # Read delay when there was no data available
self._lock = Lock()
if isinstance(timeout, timedelta):
self._timeout = timeout.seconds
elif isinstance(timeout, int) or isinstance(timeout, float):
self._timeout = float(timeout)
else:
raise ValueError('Invalid timeout type', type(timeout))
def _recreate_socket(self, mode):
if self._proxy is None:
@ -30,20 +31,19 @@ class TcpClient:
else: # tuple, list, etc.
self._socket.set_proxy(*self._proxy)
def connect(self, ip, port, timeout):
def connect(self, ip, port):
"""Connects to the specified IP and port number.
'timeout' must be given in seconds
"""
if not self.connected:
if ':' in ip: # IPv6
self._recreate_socket(socket.AF_INET6)
self._socket.settimeout(timeout)
self._socket.connect((ip, port, 0, 0))
mode, address = socket.AF_INET6, (ip, port, 0, 0)
else:
self._recreate_socket(socket.AF_INET)
self._socket.settimeout(timeout)
self._socket.connect((ip, port))
self._socket.setblocking(False)
mode, address = socket.AF_INET, (ip, port)
self._recreate_socket(mode)
self._socket.settimeout(self._timeout)
self._socket.connect(address)
def _get_connected(self):
return self._socket is not None
@ -65,27 +65,15 @@ class TcpClient:
def write(self, data):
"""Writes (sends) the specified bytes to the connected peer"""
# Ensure that only one thread can send data at once
with self._lock:
try:
view = memoryview(data)
total_sent, total = 0, len(data)
while total_sent < total:
try:
sent = self._socket.send(view[total_sent:])
if sent == 0:
self.close()
raise ConnectionResetError(
'The server has closed the connection.')
total_sent += sent
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data.
try:
self._socket.sendall(data)
except BrokenPipeError:
self.close()
raise
except BlockingIOError:
time.sleep(self.delay)
except BrokenPipeError:
self.close()
raise
def read(self, size, timeout=timedelta(seconds=5)):
def read(self, size):
"""Reads (receives) a whole block of 'size' bytes
from the connected peer.
@ -94,50 +82,19 @@ class TcpClient:
and it's waiting for more, the timeout will NOT cancel the
operation. Set to None for no timeout
"""
# TODO Remove the timeout from this method, always use previous one
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size
while bytes_left != 0:
partial = self._socket.recv(bytes_left)
if len(partial) == 0:
self.close()
raise ConnectionResetError(
'The server has closed the connection.')
# Ensure that only one thread can receive data at once
with self._lock:
# Ensure it is not cancelled at first, so we can enter the loop
self.cancelled.clear()
buffer.write(partial)
bytes_left -= len(partial)
# Set the starting time so we can
# calculate whether the timeout should fire
start_time = datetime.now() if timeout is not None else None
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size
while bytes_left != 0:
# Only do cancel if no data was read yet
# Otherwise, carry on reading and finish
if self.cancelled.is_set() and bytes_left == size:
raise ReadCancelledError()
try:
partial = self._socket.recv(bytes_left)
if len(partial) == 0:
self.close()
raise ConnectionResetError(
'The server has closed the connection.')
buffer.write(partial)
bytes_left -= len(partial)
except BlockingIOError as error:
# No data available yet, sleep a bit
time.sleep(self.delay)
# Check if the timeout finished
if timeout is not None:
time_passed = datetime.now() - start_time
if time_passed > timeout:
raise TimeoutError(
'The read operation exceeded the timeout.') from error
# If everything went fine, return the read bytes
buffer.flush()
return buffer.raw.getvalue()
def cancel_read(self):
"""Cancels the read operation IF it hasn't yet
started, raising a ReadCancelledError"""
self.cancelled.set()
# If everything went fine, return the read bytes
buffer.flush()
return buffer.raw.getvalue()

View File

@ -22,13 +22,12 @@ class Connection:
self.ip = ip
self.port = port
self._mode = mode
self.timeout = timeout
self._send_counter = 0
self._aes_encrypt, self._aes_decrypt = None, None
# TODO Rename "TcpClient" as some sort of generic socket?
self.conn = TcpClient(proxy=proxy)
self.conn = TcpClient(proxy=proxy, timeout=timeout)
# Sending messages
if mode == 'tcp_full':
@ -53,8 +52,7 @@ class Connection:
def connect(self):
self._send_counter = 0
self.conn.connect(self.ip, self.port,
timeout=round(self.timeout.seconds))
self.conn.connect(self.ip, self.port)
if self._mode == 'tcp_abridged':
self.conn.write(b'\xef')
@ -96,24 +94,18 @@ class Connection:
def close(self):
self.conn.close()
def cancel_receive(self):
"""Cancels (stops) trying to receive from the
remote peer and raises a ReadCancelledError"""
self.conn.cancel_read()
def get_client_delay(self):
"""Gets the client read delay"""
return self.conn.delay
# region Receive message implementations
def recv(self, **kwargs):
def recv(self):
"""Receives and unpacks a message"""
# TODO Don't ignore kwargs['timeout']?
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + self._mode)
def _recv_tcp_full(self, **kwargs):
def _recv_tcp_full(self):
packet_length_bytes = self.read(4)
packet_length = int.from_bytes(packet_length_bytes, 'little')
@ -129,10 +121,10 @@ class Connection:
return body
def _recv_intermediate(self, **kwargs):
def _recv_intermediate(self):
return self.read(int.from_bytes(self.read(4), 'little'))
def _recv_abridged(self, **kwargs):
def _recv_abridged(self):
length = int.from_bytes(self.read(1), 'little')
if length >= 127:
length = int.from_bytes(self.read(3) + b'\0', 'little')
@ -185,11 +177,11 @@ class Connection:
raise ValueError('Invalid connection mode specified: ' + self._mode)
def _read_plain(self, length):
return self.conn.read(length, timeout=self.timeout)
return self.conn.read(length)
def _read_obfuscated(self, length):
return self._aes_decrypt.encrypt(
self.conn.read(length, timeout=self.timeout)
self.conn.read(length)
)
# endregion

View File

@ -1,6 +1,5 @@
import gzip
from datetime import timedelta
from threading import RLock
from threading import RLock, Thread
from .. import helpers as utils
from ..crypto import AES
@ -14,9 +13,22 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
class MtProtoSender:
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
"""MTProto Mobile Protocol sender
(https://core.telegram.org/mtproto/description)
"""
def __init__(self, connection, session):
def __init__(self, connection, session, constant_read):
"""Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'.
If 'constant_read' is set to True, another thread will be
created and started upon connection to constantly read
from the other end. Otherwise, manual calls to .receive()
must be performed. The MtProtoSender cannot be connected,
or an error will be thrown.
This way, sending and receiving will be completely independent.
"""
self.connection = connection
self.session = session
self._logger = logging.getLogger(__name__)
@ -31,16 +43,45 @@ class MtProtoSender:
# TODO There might be a better way to handle msgs_ack requests
self.logging_out = False
# Will create a new _recv_thread when connecting if set
self._constant_read = constant_read
self._recv_thread = None
# Every unhandled result gets passed to these callbacks, which
# should be functions accepting a single parameter: a TLObject.
# This should only be Update(s), although it can actually be any type.
#
# The thread from which these callbacks are called can be any.
#
# The creator of the MtProtoSender is responsible for setting this
# to point to the list wherever their callbacks reside.
self.unhandled_callbacks = None
def connect(self):
"""Connects to the server"""
self.connection.connect()
if not self.is_connected():
self.connection.connect()
if self._constant_read:
self._recv_thread = Thread(
name='ReadThread', daemon=True,
target=self._recv_thread_impl
)
self._recv_thread.start()
def is_connected(self):
return self.connection.is_connected()
def disconnect(self):
"""Disconnects from the server"""
self.connection.close()
if self.is_connected():
self.connection.close()
if self._constant_read:
# The existing thread will close eventually, since it's
# only running while the MtProtoSender.is_connected()
self._recv_thread = None
def is_constant_read(self):
return self._constant_read
# region Send and receive
@ -76,57 +117,31 @@ class MtProtoSender:
del self._need_confirmation[:]
def receive(self, request=None, updates=None, **kwargs):
"""Receives the specified MTProtoRequest ("fills in it"
the received data). This also restores the updates thread.
def _recv_thread_impl(self):
while self.is_connected():
try:
self.receive()
except TimeoutError:
# No problem.
pass
An optional named parameter 'timeout' can be specified if
one desires to override 'self.connection.timeout'.
def receive(self):
"""Receives a single message from the connected endpoint.
If 'request' is None, a single item will be read into
the 'updates' list (which cannot be None).
If 'request' is not None, any update received before
reading the request's result will be put there unless
it's None, in which case updates will be ignored.
This method returns nothing, and will only affect other parts
of the MtProtoSender such as the updates callback being fired
or a pending request being confirmed.
"""
if request is None and updates is None:
raise ValueError('Both the "request" and "updates"'
'parameters cannot be None at the same time.')
# TODO Don't ignore updates
self._logger.debug('Receiving a message...')
body = self.connection.recv()
message, remote_msg_id, remote_seq = self._decode_msg(body)
with self._lock:
self._logger.debug('receive() acquired the lock')
# Don't stop trying to receive until we get the request we wanted
# or, if there is no request, until we read an update
while (request and not request.confirm_received) or \
(not request and not updates):
self._logger.debug('Trying to .receive() the request result...')
body = self.connection.recv(**kwargs)
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
self._process_msg(
remote_msg_id, remote_seq, reader, updates=None)
with BinaryReader(message) as reader:
self._process_msg(
remote_msg_id, remote_seq, reader, updates)
# We're done receiving, remove the request from pending, if any
if request:
try:
self._pending_receive.remove(request)
except ValueError: pass
self._logger.debug('Request result received')
self._logger.debug('receive() released the lock')
def receive_updates(self, **kwargs):
"""Wrapper for .receive(request=None, updates=[])"""
updates = []
self.receive(updates=updates, **kwargs)
return updates
def cancel_receive(self):
"""Cancels any pending receive operation
by raising a ReadCancelledError"""
self.connection.cancel_receive()
self._logger.debug('Received message.')
# endregion
@ -230,20 +245,19 @@ class MtProtoSender:
if self.logging_out:
self._logger.debug('Message ack confirmed a request')
r.confirm_received = True
r.confirm_received.set()
return True
# If the code is not parsed manually, then it was parsed by the code generator!
# In this case, we will simply treat the incoming TLObject as an Update,
# if we can first find a matching TLObject
# If the code is not parsed manually then it should be a TLObject.
if code in tlobjects:
result = reader.tgread_object()
if updates is None:
self._logger.debug('Ignored update for %s', repr(result))
if self.unhandled_callbacks:
self._logger.debug('Passing TLObject to callbacks %s', repr(result))
for callback in self.unhandled_callbacks:
callback(result)
else:
self._logger.debug('Read update for %s', repr(result))
updates.append(result)
self._logger.debug('Ignoring unhandled TLObject %s', repr(result))
return True
@ -264,7 +278,7 @@ class MtProtoSender:
if r.request_msg_id == received_msg_id)
self._logger.debug('Pong confirmed a request')
request.confirm_received = True
request.confirm_received.set()
except StopIteration: pass
return True
@ -338,8 +352,6 @@ class MtProtoSender:
try:
request = next(r for r in self._pending_receive
if r.request_msg_id == request_id)
request.confirm_received = True
except StopIteration:
request = None
@ -358,13 +370,12 @@ class MtProtoSender:
self._need_confirmation.append(request_id)
self._send_acknowledges()
if request:
request.error = error
request.confirm_received.set()
# else TODO Where should this error be reported?
# Read may be async. Can an error not-belong to a request?
self._logger.debug('Read RPC error: %s', str(error))
if isinstance(error, InvalidDCError):
# Must resend this request, if any
if request:
request.confirm_received = False
raise error
else:
if request:
self._logger.debug('Reading request response')
@ -376,6 +387,7 @@ class MtProtoSender:
reader.seek(-4)
request.on_response(reader)
request.confirm_received.set()
return True
else:
# If it's really a result for RPC from previous connection

View File

@ -1,5 +1,5 @@
import logging
import pyaes
from time import sleep
from datetime import timedelta
from hashlib import md5
from os import path
@ -83,6 +83,12 @@ class TelegramBareClient:
# the time since it's a (somewhat expensive) process.
self._cached_clients = {}
# Update callbacks (functions accepting a single TLObject) go here
#
# Note that changing the list to which this variable points
# will not be reflected on the existing senders.
self._update_callbacks = []
# These will be set later
self.dc_options = None
self._sender = None
@ -91,7 +97,8 @@ class TelegramBareClient:
# region Connecting
def connect(self, exported_auth=None, initial_query=None):
def connect(self, exported_auth=None, initial_query=None,
constant_read=False):
"""Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which
@ -103,6 +110,9 @@ class TelegramBareClient:
If 'initial_query' is not None, it will override the default
'GetConfigRequest()', and its result will be returned ONLY
if the client wasn't connected already.
The 'constant_read' parameter will be used when creating
the MtProtoSender. Refer to it for more information.
"""
if self._sender and self._sender.is_connected():
# Try sending a ping to make sure we're connected already
@ -129,7 +139,10 @@ class TelegramBareClient:
self.session.save()
self._sender = MtProtoSender(connection, self.session)
self._sender = MtProtoSender(
connection, self.session, constant_read=constant_read
)
self._sender.unhandled_callbacks = self._update_callbacks
self._sender.connect()
# Now it's time to send an InitConnectionRequest
@ -204,30 +217,6 @@ class TelegramBareClient:
# endregion
# region Properties
def set_timeout(self, timeout):
if timeout is None:
self._timeout = None
elif isinstance(timeout, int) or isinstance(timeout, float):
self._timeout = timedelta(seconds=timeout)
elif isinstance(timeout, timedelta):
self._timeout = timeout
else:
raise ValueError(
'{} is not a valid type for a timeout'.format(type(timeout))
)
if self._sender:
self._sender.transport.timeout = self._timeout
def get_timeout(self):
return self._timeout
timeout = property(get_timeout, set_timeout)
# endregion
# region Working with different Data Centers
def _get_dc(self, dc_id, ipv6=False, cdn=False):
@ -318,7 +307,18 @@ class TelegramBareClient:
try:
self._sender.send(request)
self._sender.receive(request, updates=updates)
if self._sender.is_constant_read():
# TODO This will be slightly troublesome if we allow
# switching between constant read or not on the fly.
# Must also watch out for calling .read() from two places,
# in which case a Lock would be required for .receive().
request.confirm_received.wait() # TODO Optional timeout here?
else:
while not request.confirm_received.is_set():
self._sender.receive()
if request.rpc_error:
raise request.rpc_error
return request.result
except ConnectionResetError:

View File

@ -98,14 +98,6 @@ class TelegramClient(TelegramBareClient):
# Safety across multiple threads (for the updates thread)
self._lock = RLock()
# Updates-related members
self._update_handlers = []
self._updates_thread_running = Event()
self._updates_thread_receiving = Event()
self._next_ping_at = 0
self.ping_interval = 60 # Seconds
# Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
for name, value in kwargs.items():
@ -129,24 +121,22 @@ class TelegramClient(TelegramBareClient):
not the same as authenticating the desired user itself, which
may require a call (or several) to 'sign_in' for the first time.
The specified timeout will be used on internal .invoke()'s.
*args will be ignored.
"""
result = super().connect()
# Checking if there are update_handlers and if true, start running updates thread.
# This situation may occur on reconnecting.
if result and self._update_handlers:
self._set_updates_thread(running=True)
return result
# The main TelegramClient is the only one that will have
# constant_read, since it's also the only one that receives
# updates, which need to be processed as soon as they occur.
#
# TODO Allow to disable this to avoid the creation of a new thread
# if the user is not going to work with updates at all? Whether to
# read constantly or not for updates needs to be known beforehand,
# and further updates won't be able to be added unless allowing to
# switch the mode on the fly.
return super().connect(constant_read=True)
def disconnect(self):
"""Disconnects from the Telegram server
and stops all the spawned threads"""
self._set_updates_thread(running=False)
super().disconnect()
# Also disconnect all the cached senders
@ -159,7 +149,7 @@ class TelegramClient(TelegramBareClient):
# region Working with different connections
def create_new_connection(self, on_dc=None):
def create_new_connection(self, on_dc=None, timeout=timedelta(seconds=5)):
"""Creates a new connection which can be used in parallel
with the original TelegramClient. A TelegramBareClient
will be returned already connected, and the caller is
@ -173,7 +163,9 @@ class TelegramClient(TelegramBareClient):
"""
if on_dc is None:
client = TelegramBareClient(
self.session, self.api_id, self.api_hash, proxy=self.proxy)
self.session, self.api_id, self.api_hash,
proxy=self.proxy, timeout=timeout
)
client.connect()
else:
client = self._get_exported_client(on_dc, bypass_cache=True)
@ -187,29 +179,13 @@ class TelegramClient(TelegramBareClient):
def invoke(self, request, *args):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
An optional timeout can be specified to cancel the operation if no
result is received within such time, or None to disable any timeout.
*args will be ignored.
"""
if self._updates_thread_receiving.is_set():
self._sender.cancel_receive()
try:
self._lock.acquire()
updates = [] if self._update_handlers else None
result = super().invoke(
request, updates=updates
)
if updates:
for update in updates:
for handler in self._update_handlers:
handler(update)
# TODO Retry if 'result' is None?
return result
return super().invoke(request)
except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e:
self._logger.debug('DC error when invoking request, '
@ -399,8 +375,8 @@ class TelegramClient(TelegramBareClient):
no_webpage=not link_preview
)
result = self(request)
for handler in self._update_handlers:
handler(result)
for callback in self._update_callbacks:
callback(result)
return request.random_id
def get_message_history(self,
@ -891,110 +867,12 @@ class TelegramClient(TelegramBareClient):
def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates"""
if not self._sender:
raise RuntimeError("You can't add update handlers until you've "
"successfully connected to the server.")
first_handler = not self._update_handlers
self._update_handlers.append(handler)
if first_handler:
self._set_updates_thread(running=True)
self._update_callbacks.append(handler)
def remove_update_handler(self, handler):
self._update_handlers.remove(handler)
if not self._update_handlers:
self._set_updates_thread(running=False)
self._update_callbacks.remove(handler)
def list_update_handlers(self):
return self._update_handlers[:]
def _set_updates_thread(self, running):
"""Sets the updates thread status (running or not)"""
if running == self._updates_thread_running.is_set():
return
# Different state, update the saved value and behave as required
self._logger.debug('Changing updates thread running status to %s', running)
if running:
self._updates_thread_running.set()
if not self._updates_thread:
self._updates_thread = Thread(
name='UpdatesThread', daemon=True,
target=self._updates_thread_method)
self._updates_thread.start()
else:
self._updates_thread_running.clear()
if self._updates_thread_receiving.is_set():
self._sender.cancel_receive()
def _updates_thread_method(self):
"""This method will run until specified and listen for incoming updates"""
# Set a reasonable timeout when checking for updates
timeout = timedelta(minutes=1)
while self._updates_thread_running.is_set():
# Always sleep a bit before each iteration to relax the CPU,
# since it's possible to early 'continue' the loop to reach
# the next iteration, but we still should to sleep.
sleep(0.1)
with self._lock:
self._logger.debug('Updates thread acquired the lock')
try:
self._updates_thread_receiving.set()
self._logger.debug(
'Trying to receive updates from the updates thread'
)
if time() > self._next_ping_at:
self._next_ping_at = time() + self.ping_interval
self(PingRequest(utils.generate_random_long()))
updates = self._sender.receive_updates(timeout=timeout)
self._updates_thread_receiving.clear()
self._logger.debug(
'Received {} update(s) from the updates thread'
.format(len(updates))
)
for update in updates:
for handler in self._update_handlers:
handler(update)
except ConnectionResetError:
self._logger.debug('Server disconnected us. Reconnecting...')
self.reconnect()
except TimeoutError:
self._logger.debug('Receiving updates timed out')
except ReadCancelledError:
self._logger.debug('Receiving updates cancelled')
except BrokenPipeError:
self._logger.debug('Tcp session is broken. Reconnecting...')
self.reconnect()
except InvalidChecksumError:
self._logger.debug('MTProto session is broken. Reconnecting...')
self.reconnect()
except OSError:
self._logger.debug('OSError on updates thread, %s logging out',
'was' if self._sender.logging_out else 'was not')
if self._sender.logging_out:
# This error is okay when logging out, means we got disconnected
# TODO Not sure why this happens because we call disconnect()...
self._set_updates_thread(running=False)
else:
raise
self._logger.debug('Updates thread released the lock')
# Thread is over, so clean unset its variable
self._updates_thread = None
return self._update_callbacks[:]
# endregion

View File

@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from threading import Event
class TLObject:
@ -10,7 +11,8 @@ class TLObject:
self.dirty = False
self.send_time = None
self.confirm_received = False
self.confirm_received = Event()
self.rpc_error = None
# These should be overridden
self.constructor_id = 0
@ -23,11 +25,11 @@ class TLObject:
self.sent = True
def on_confirm(self):
self.confirm_received = True
self.confirm_received.set()
def need_resend(self):
return self.dirty or (
self.content_related and not self.confirm_received and
self.content_related and not self.confirm_received.is_set() and
datetime.now() - self.send_time > timedelta(seconds=3))
@staticmethod