diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 02c26996..bd969084 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -10,14 +10,15 @@ from ..errors import ReadCancelledError class TcpClient: - def __init__(self, proxy=None): + def __init__(self, proxy=None, timeout=timedelta(seconds=5)): self._proxy = proxy self._socket = None - - # Support for multi-threading advantages and safety - self.cancelled = Event() # Has the read operation been cancelled? - self.delay = 0.1 # Read delay when there was no data available - self._lock = Lock() + if isinstance(timeout, timedelta): + self._timeout = timeout.seconds + elif isinstance(timeout, int) or isinstance(timeout, float): + self._timeout = float(timeout) + else: + raise ValueError('Invalid timeout type', type(timeout)) def _recreate_socket(self, mode): if self._proxy is None: @@ -30,20 +31,19 @@ class TcpClient: else: # tuple, list, etc. self._socket.set_proxy(*self._proxy) - def connect(self, ip, port, timeout): + def connect(self, ip, port): """Connects to the specified IP and port number. 'timeout' must be given in seconds """ if not self.connected: if ':' in ip: # IPv6 - self._recreate_socket(socket.AF_INET6) - self._socket.settimeout(timeout) - self._socket.connect((ip, port, 0, 0)) + mode, address = socket.AF_INET6, (ip, port, 0, 0) else: - self._recreate_socket(socket.AF_INET) - self._socket.settimeout(timeout) - self._socket.connect((ip, port)) - self._socket.setblocking(False) + mode, address = socket.AF_INET, (ip, port) + + self._recreate_socket(mode) + self._socket.settimeout(self._timeout) + self._socket.connect(address) def _get_connected(self): return self._socket is not None @@ -65,27 +65,15 @@ class TcpClient: def write(self, data): """Writes (sends) the specified bytes to the connected peer""" - # Ensure that only one thread can send data at once - with self._lock: - try: - view = memoryview(data) - total_sent, total = 0, len(data) - while total_sent < total: - try: - sent = self._socket.send(view[total_sent:]) - if sent == 0: - self.close() - raise ConnectionResetError( - 'The server has closed the connection.') - total_sent += sent + # TODO Timeout may be an issue when sending the data, Changed in v3.5: + # The socket timeout is now the maximum total duration to send all data. + try: + self._socket.sendall(data) + except BrokenPipeError: + self.close() + raise - except BlockingIOError: - time.sleep(self.delay) - except BrokenPipeError: - self.close() - raise - - def read(self, size, timeout=timedelta(seconds=5)): + def read(self, size): """Reads (receives) a whole block of 'size bytes from the connected peer. @@ -94,50 +82,19 @@ class TcpClient: and it's waiting for more, the timeout will NOT cancel the operation. Set to None for no timeout """ + # TODO Remove the timeout from this method, always use previous one + with BufferedWriter(BytesIO(), buffer_size=size) as buffer: + bytes_left = size + while bytes_left != 0: + partial = self._socket.recv(bytes_left) + if len(partial) == 0: + self.close() + raise ConnectionResetError( + 'The server has closed the connection.') - # Ensure that only one thread can receive data at once - with self._lock: - # Ensure it is not cancelled at first, so we can enter the loop - self.cancelled.clear() + buffer.write(partial) + bytes_left -= len(partial) - # Set the starting time so we can - # calculate whether the timeout should fire - start_time = datetime.now() if timeout is not None else None - - with BufferedWriter(BytesIO(), buffer_size=size) as buffer: - bytes_left = size - while bytes_left != 0: - # Only do cancel if no data was read yet - # Otherwise, carry on reading and finish - if self.cancelled.is_set() and bytes_left == size: - raise ReadCancelledError() - - try: - partial = self._socket.recv(bytes_left) - if len(partial) == 0: - self.close() - raise ConnectionResetError( - 'The server has closed the connection.') - - buffer.write(partial) - bytes_left -= len(partial) - - except BlockingIOError as error: - # No data available yet, sleep a bit - time.sleep(self.delay) - - # Check if the timeout finished - if timeout is not None: - time_passed = datetime.now() - start_time - if time_passed > timeout: - raise TimeoutError( - 'The read operation exceeded the timeout.') from error - - # If everything went fine, return the read bytes - buffer.flush() - return buffer.raw.getvalue() - - def cancel_read(self): - """Cancels the read operation IF it hasn't yet - started, raising a ReadCancelledError""" - self.cancelled.set() + # If everything went fine, return the read bytes + buffer.flush() + return buffer.raw.getvalue() diff --git a/telethon/network/connection.py b/telethon/network/connection.py index af3ab817..2116464a 100644 --- a/telethon/network/connection.py +++ b/telethon/network/connection.py @@ -22,13 +22,12 @@ class Connection: self.ip = ip self.port = port self._mode = mode - self.timeout = timeout self._send_counter = 0 self._aes_encrypt, self._aes_decrypt = None, None # TODO Rename "TcpClient" as some sort of generic socket? - self.conn = TcpClient(proxy=proxy) + self.conn = TcpClient(proxy=proxy, timeout=timeout) # Sending messages if mode == 'tcp_full': @@ -53,8 +52,7 @@ class Connection: def connect(self): self._send_counter = 0 - self.conn.connect(self.ip, self.port, - timeout=round(self.timeout.seconds)) + self.conn.connect(self.ip, self.port) if self._mode == 'tcp_abridged': self.conn.write(b'\xef') @@ -96,24 +94,18 @@ class Connection: def close(self): self.conn.close() - def cancel_receive(self): - """Cancels (stops) trying to receive from the - remote peer and raises a ReadCancelledError""" - self.conn.cancel_read() - def get_client_delay(self): """Gets the client read delay""" return self.conn.delay # region Receive message implementations - def recv(self, **kwargs): + def recv(self): """Receives and unpacks a message""" - # TODO Don't ignore kwargs['timeout']? # Default implementation is just an error raise ValueError('Invalid connection mode specified: ' + self._mode) - def _recv_tcp_full(self, **kwargs): + def _recv_tcp_full(self): packet_length_bytes = self.read(4) packet_length = int.from_bytes(packet_length_bytes, 'little') @@ -129,10 +121,10 @@ class Connection: return body - def _recv_intermediate(self, **kwargs): + def _recv_intermediate(self): return self.read(int.from_bytes(self.read(4), 'little')) - def _recv_abridged(self, **kwargs): + def _recv_abridged(self): length = int.from_bytes(self.read(1), 'little') if length >= 127: length = int.from_bytes(self.read(3) + b'\0', 'little') @@ -185,11 +177,11 @@ class Connection: raise ValueError('Invalid connection mode specified: ' + self._mode) def _read_plain(self, length): - return self.conn.read(length, timeout=self.timeout) + return self.conn.read(length) def _read_obfuscated(self, length): return self._aes_decrypt.encrypt( - self.conn.read(length, timeout=self.timeout) + self.conn.read(length) ) # endregion diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index 4345908e..81b448e1 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -1,6 +1,5 @@ import gzip -from datetime import timedelta -from threading import RLock +from threading import RLock, Thread from .. import helpers as utils from ..crypto import AES @@ -14,9 +13,22 @@ logging.getLogger(__name__).addHandler(logging.NullHandler()) class MtProtoSender: - """MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)""" + """MTProto Mobile Protocol sender + (https://core.telegram.org/mtproto/description) + """ - def __init__(self, connection, session): + def __init__(self, connection, session, constant_read): + """Creates a new MtProtoSender configured to send messages through + 'connection' and using the parameters from 'session'. + + If 'constant_read' is set to True, another thread will be + created and started upon connection to constantly read + from the other end. Otherwise, manual calls to .receive() + must be performed. The MtProtoSender cannot be connected, + or an error will be thrown. + + This way, sending and receiving will be completely independent. + """ self.connection = connection self.session = session self._logger = logging.getLogger(__name__) @@ -31,16 +43,45 @@ class MtProtoSender: # TODO There might be a better way to handle msgs_ack requests self.logging_out = False + # Will create a new _recv_thread when connecting if set + self._constant_read = constant_read + self._recv_thread = None + + # Every unhandled result gets passed to these callbacks, which + # should be functions accepting a single parameter: a TLObject. + # This should only be Update(s), although it can actually be any type. + # + # The thread from which these callbacks are called can be any. + # + # The creator of the MtProtoSender is responsible for setting this + # to point to the list wherever their callbacks reside. + self.unhandled_callbacks = None + def connect(self): """Connects to the server""" - self.connection.connect() + if not self.is_connected(): + self.connection.connect() + if self._constant_read: + self._recv_thread = Thread( + name='ReadThread', daemon=True, + target=self._recv_thread_impl + ) + self._recv_thread.start() def is_connected(self): return self.connection.is_connected() def disconnect(self): """Disconnects from the server""" - self.connection.close() + if self.is_connected(): + self.connection.close() + if self._constant_read: + # The existing thread will close eventually, since it's + # only running while the MtProtoSender.is_connected() + self._recv_thread = None + + def is_constant_read(self): + return self._constant_read # region Send and receive @@ -76,57 +117,31 @@ class MtProtoSender: del self._need_confirmation[:] - def receive(self, request=None, updates=None, **kwargs): - """Receives the specified MTProtoRequest ("fills in it" - the received data). This also restores the updates thread. + def _recv_thread_impl(self): + while self.is_connected(): + try: + self.receive() + except TimeoutError: + # No problem. + pass - An optional named parameter 'timeout' can be specified if - one desires to override 'self.connection.timeout'. + def receive(self): + """Receives a single message from the connected endpoint. - If 'request' is None, a single item will be read into - the 'updates' list (which cannot be None). - - If 'request' is not None, any update received before - reading the request's result will be put there unless - it's None, in which case updates will be ignored. + This method returns nothing, and will only affect other parts + of the MtProtoSender such as the updates callback being fired + or a pending request being confirmed. """ - if request is None and updates is None: - raise ValueError('Both the "request" and "updates"' - 'parameters cannot be None at the same time.') + # TODO Don't ignore updates + self._logger.debug('Receiving a message...') + body = self.connection.recv() + message, remote_msg_id, remote_seq = self._decode_msg(body) - with self._lock: - self._logger.debug('receive() acquired the lock') - # Don't stop trying to receive until we get the request we wanted - # or, if there is no request, until we read an update - while (request and not request.confirm_received) or \ - (not request and not updates): - self._logger.debug('Trying to .receive() the request result...') - body = self.connection.recv(**kwargs) - message, remote_msg_id, remote_seq = self._decode_msg(body) + with BinaryReader(message) as reader: + self._process_msg( + remote_msg_id, remote_seq, reader, updates=None) - with BinaryReader(message) as reader: - self._process_msg( - remote_msg_id, remote_seq, reader, updates) - - # We're done receiving, remove the request from pending, if any - if request: - try: - self._pending_receive.remove(request) - except ValueError: pass - - self._logger.debug('Request result received') - self._logger.debug('receive() released the lock') - - def receive_updates(self, **kwargs): - """Wrapper for .receive(request=None, updates=[])""" - updates = [] - self.receive(updates=updates, **kwargs) - return updates - - def cancel_receive(self): - """Cancels any pending receive operation - by raising a ReadCancelledError""" - self.connection.cancel_receive() + self._logger.debug('Received message.') # endregion @@ -230,20 +245,19 @@ class MtProtoSender: if self.logging_out: self._logger.debug('Message ack confirmed a request') - r.confirm_received = True + r.confirm_received.set() return True - # If the code is not parsed manually, then it was parsed by the code generator! - # In this case, we will simply treat the incoming TLObject as an Update, - # if we can first find a matching TLObject + # If the code is not parsed manually then it should be a TLObject. if code in tlobjects: result = reader.tgread_object() - if updates is None: - self._logger.debug('Ignored update for %s', repr(result)) + if self.unhandled_callbacks: + self._logger.debug('Passing TLObject to callbacks %s', repr(result)) + for callback in self.unhandled_callbacks: + callback(result) else: - self._logger.debug('Read update for %s', repr(result)) - updates.append(result) + self._logger.debug('Ignoring unhandled TLObject %s', repr(result)) return True @@ -264,7 +278,7 @@ class MtProtoSender: if r.request_msg_id == received_msg_id) self._logger.debug('Pong confirmed a request') - request.confirm_received = True + request.confirm_received.set() except StopIteration: pass return True @@ -338,8 +352,6 @@ class MtProtoSender: try: request = next(r for r in self._pending_receive if r.request_msg_id == request_id) - - request.confirm_received = True except StopIteration: request = None @@ -358,13 +370,12 @@ class MtProtoSender: self._need_confirmation.append(request_id) self._send_acknowledges() + if request: + request.error = error + request.confirm_received.set() + # else TODO Where should this error be reported? + # Read may be async. Can an error not-belong to a request? self._logger.debug('Read RPC error: %s', str(error)) - if isinstance(error, InvalidDCError): - # Must resend this request, if any - if request: - request.confirm_received = False - - raise error else: if request: self._logger.debug('Reading request response') @@ -376,6 +387,7 @@ class MtProtoSender: reader.seek(-4) request.on_response(reader) + request.confirm_received.set() return True else: # If it's really a result for RPC from previous connection diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 22075dc5..20b5efec 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -1,5 +1,5 @@ import logging -import pyaes +from time import sleep from datetime import timedelta from hashlib import md5 from os import path @@ -83,6 +83,12 @@ class TelegramBareClient: # the time since it's a (somewhat expensive) process. self._cached_clients = {} + # Update callbacks (functions accepting a single TLObject) go here + # + # Note that changing the list to which this variable points to + # will not reflect the changes on the existing senders. + self._update_callbacks = [] + # These will be set later self.dc_options = None self._sender = None @@ -91,7 +97,8 @@ class TelegramBareClient: # region Connecting - def connect(self, exported_auth=None, initial_query=None): + def connect(self, exported_auth=None, initial_query=None, + constant_read=False): """Connects to the Telegram servers, executing authentication if required. Note that authenticating to the Telegram servers is not the same as authenticating the desired user itself, which @@ -103,6 +110,9 @@ class TelegramBareClient: If 'initial_query' is not None, it will override the default 'GetConfigRequest()', and its result will be returned ONLY if the client wasn't connected already. + + The 'constant_read' parameter will be used when creating + the MtProtoSender. Refer to it for more information. """ if self._sender and self._sender.is_connected(): # Try sending a ping to make sure we're connected already @@ -129,7 +139,10 @@ class TelegramBareClient: self.session.save() - self._sender = MtProtoSender(connection, self.session) + self._sender = MtProtoSender( + connection, self.session, constant_read=constant_read + ) + self._sender.unhandled_callbacks = self._update_callbacks self._sender.connect() # Now it's time to send an InitConnectionRequest @@ -204,30 +217,6 @@ class TelegramBareClient: # endregion - # region Properties - - def set_timeout(self, timeout): - if timeout is None: - self._timeout = None - elif isinstance(timeout, int) or isinstance(timeout, float): - self._timeout = timedelta(seconds=timeout) - elif isinstance(timeout, timedelta): - self._timeout = timeout - else: - raise ValueError( - '{} is not a valid type for a timeout'.format(type(timeout)) - ) - - if self._sender: - self._sender.transport.timeout = self._timeout - - def get_timeout(self): - return self._timeout - - timeout = property(get_timeout, set_timeout) - - # endregion - # region Working with different Data Centers def _get_dc(self, dc_id, ipv6=False, cdn=False): @@ -318,7 +307,18 @@ class TelegramBareClient: try: self._sender.send(request) - self._sender.receive(request, updates=updates) + if self._sender.is_constant_read(): + # TODO This will be slightly troublesome if we allow + # switching between constant read or not on the fly. + # Must also watch out for calling .read() from two places, + # in which case a Lock would be required for .receive(). + request.confirm_received.wait() # TODO Optional timeout here? + else: + while not request.confirm_received.is_set(): + self._sender.receive() + + if request.rpc_error: + raise request.rpc_error return request.result except ConnectionResetError: diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index ba5f77f6..c85fd0f7 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -98,14 +98,6 @@ class TelegramClient(TelegramBareClient): # Safety across multiple threads (for the updates thread) self._lock = RLock() - # Updates-related members - self._update_handlers = [] - self._updates_thread_running = Event() - self._updates_thread_receiving = Event() - - self._next_ping_at = 0 - self.ping_interval = 60 # Seconds - # Used on connection - the user may modify these and reconnect kwargs['app_version'] = kwargs.get('app_version', self.__version__) for name, value in kwargs.items(): @@ -129,24 +121,22 @@ class TelegramClient(TelegramBareClient): not the same as authenticating the desired user itself, which may require a call (or several) to 'sign_in' for the first time. - The specified timeout will be used on internal .invoke()'s. - *args will be ignored. """ - result = super().connect() - - # Checking if there are update_handlers and if true, start running updates thread. - # This situation may occur on reconnecting. - if result and self._update_handlers: - self._set_updates_thread(running=True) - - return result - + # The main TelegramClient is the only one that will have + # constant_read, since it's also the only one who receives + # updates and need to be processed as soon as they occur. + # + # TODO Allow to disable this to avoid the creation of a new thread + # if the user is not going to work with updates at all? Whether to + # read constantly or not for updates needs to be known before hand, + # and further updates won't be able to be added unless allowing to + # switch the mode on the fly. + return super().connect(constant_read=True) def disconnect(self): """Disconnects from the Telegram server and stops all the spawned threads""" - self._set_updates_thread(running=False) super().disconnect() # Also disconnect all the cached senders @@ -159,7 +149,7 @@ class TelegramClient(TelegramBareClient): # region Working with different connections - def create_new_connection(self, on_dc=None): + def create_new_connection(self, on_dc=None, timeout=timedelta(seconds=5)): """Creates a new connection which can be used in parallel with the original TelegramClient. A TelegramBareClient will be returned already connected, and the caller is @@ -173,7 +163,9 @@ class TelegramClient(TelegramBareClient): """ if on_dc is None: client = TelegramBareClient( - self.session, self.api_id, self.api_hash, proxy=self.proxy) + self.session, self.api_id, self.api_hash, + proxy=self.proxy, timeout=timeout + ) client.connect() else: client = self._get_exported_client(on_dc, bypass_cache=True) @@ -187,29 +179,13 @@ class TelegramClient(TelegramBareClient): def invoke(self, request, *args): """Invokes (sends) a MTProtoRequest and returns (receives) its result. - An optional timeout can be specified to cancel the operation if no - result is received within such time, or None to disable any timeout. - *args will be ignored. """ - if self._updates_thread_receiving.is_set(): - self._sender.cancel_receive() - try: self._lock.acquire() - updates = [] if self._update_handlers else None - result = super().invoke( - request, updates=updates - ) - - if updates: - for update in updates: - for handler in self._update_handlers: - handler(update) - # TODO Retry if 'result' is None? - return result + return super().invoke(request) except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: self._logger.debug('DC error when invoking request, ' @@ -399,8 +375,8 @@ class TelegramClient(TelegramBareClient): no_webpage=not link_preview ) result = self(request) - for handler in self._update_handlers: - handler(result) + for callback in self._update_callbacks: + callback(result) return request.random_id def get_message_history(self, @@ -891,110 +867,12 @@ class TelegramClient(TelegramBareClient): def add_update_handler(self, handler): """Adds an update handler (a function which takes a TLObject, an update, as its parameter) and listens for updates""" - if not self._sender: - raise RuntimeError("You can't add update handlers until you've " - "successfully connected to the server.") - - first_handler = not self._update_handlers - self._update_handlers.append(handler) - if first_handler: - self._set_updates_thread(running=True) + self._update_callbacks.append(handler) def remove_update_handler(self, handler): - self._update_handlers.remove(handler) - if not self._update_handlers: - self._set_updates_thread(running=False) + self._update_callbacks.remove(handler) def list_update_handlers(self): - return self._update_handlers[:] - - def _set_updates_thread(self, running): - """Sets the updates thread status (running or not)""" - if running == self._updates_thread_running.is_set(): - return - - # Different state, update the saved value and behave as required - self._logger.debug('Changing updates thread running status to %s', running) - if running: - self._updates_thread_running.set() - if not self._updates_thread: - self._updates_thread = Thread( - name='UpdatesThread', daemon=True, - target=self._updates_thread_method) - - self._updates_thread.start() - else: - self._updates_thread_running.clear() - if self._updates_thread_receiving.is_set(): - self._sender.cancel_receive() - - def _updates_thread_method(self): - """This method will run until specified and listen for incoming updates""" - - # Set a reasonable timeout when checking for updates - timeout = timedelta(minutes=1) - - while self._updates_thread_running.is_set(): - # Always sleep a bit before each iteration to relax the CPU, - # since it's possible to early 'continue' the loop to reach - # the next iteration, but we still should to sleep. - sleep(0.1) - - with self._lock: - self._logger.debug('Updates thread acquired the lock') - try: - self._updates_thread_receiving.set() - self._logger.debug( - 'Trying to receive updates from the updates thread' - ) - - if time() > self._next_ping_at: - self._next_ping_at = time() + self.ping_interval - self(PingRequest(utils.generate_random_long())) - - updates = self._sender.receive_updates(timeout=timeout) - - self._updates_thread_receiving.clear() - self._logger.debug( - 'Received {} update(s) from the updates thread' - .format(len(updates)) - ) - for update in updates: - for handler in self._update_handlers: - handler(update) - - except ConnectionResetError: - self._logger.debug('Server disconnected us. Reconnecting...') - self.reconnect() - - except TimeoutError: - self._logger.debug('Receiving updates timed out') - - except ReadCancelledError: - self._logger.debug('Receiving updates cancelled') - - except BrokenPipeError: - self._logger.debug('Tcp session is broken. Reconnecting...') - self.reconnect() - - except InvalidChecksumError: - self._logger.debug('MTProto session is broken. Reconnecting...') - self.reconnect() - - except OSError: - self._logger.debug('OSError on updates thread, %s logging out', - 'was' if self._sender.logging_out else 'was not') - - if self._sender.logging_out: - # This error is okay when logging out, means we got disconnected - # TODO Not sure why this happens because we call disconnect()... - self._set_updates_thread(running=False) - else: - raise - - self._logger.debug('Updates thread released the lock') - - # Thread is over, so clean unset its variable - self._updates_thread = None + return self._update_callbacks[:] # endregion diff --git a/telethon/tl/tlobject.py b/telethon/tl/tlobject.py index 4c201125..66ed825f 100644 --- a/telethon/tl/tlobject.py +++ b/telethon/tl/tlobject.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from threading import Event class TLObject: @@ -10,7 +11,8 @@ class TLObject: self.dirty = False self.send_time = None - self.confirm_received = False + self.confirm_received = Event() + self.rpc_error = None # These should be overrode self.constructor_id = 0 @@ -23,11 +25,11 @@ class TLObject: self.sent = True def on_confirm(self): - self.confirm_received = True + self.confirm_received.set() def need_resend(self): return self.dirty or ( - self.content_related and not self.confirm_received and + self.content_related and not self.confirm_received.is_set() and datetime.now() - self.send_time > timedelta(seconds=3)) @staticmethod