diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index a8817a06..674a4e03 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -28,23 +28,15 @@ class MtProtoSender: self._need_confirmation = [] # Message IDs that need confirmation self._pending_receive = [] # Requests sent waiting to be received - # Store an RLock instance to make this class safely multi-threaded - self._lock = RLock() + # Sending and receiving are independent, but two threads cannot + # send or receive at the same time no matter what. + self._send_lock = RLock() + self._recv_lock = RLock() # Used when logging out, the only request that seems to use 'ack' # TODO There might be a better way to handle msgs_ack requests self.logging_out = False - # Every unhandled result gets passed to these callbacks, which - # should be functions accepting a single parameter: a TLObject. - # This should only be Update(s), although it can actually be any type. - # - # The thread from which these callbacks are called can be any. - # - # The creator of the MtProtoSender is responsible for setting this - # to point to the list wherever their callbacks reside. - self.unhandled_callbacks = None - def connect(self): """Connects to the server""" self.connection.connect() @@ -62,23 +54,17 @@ class MtProtoSender: """Sends the specified MTProtoRequest, previously sending any message which needed confirmation.""" - # Now only us can be using this method - with self._lock: - self._logger.debug('send() acquired the lock') + # If any message needs confirmation send an AckRequest first + self._send_acknowledges() - # If any message needs confirmation send an AckRequest first - self._send_acknowledges() + # Finally send our packed request + with BinaryWriter() as writer: + request.on_send(writer) + self._send_packet(writer.get_bytes(), request) + self._pending_receive.append(request) - # Finally send our packed request - with BinaryWriter() as writer: - request.on_send(writer) - self._send_packet(writer.get_bytes(), request) - self._pending_receive.append(request) - - # And update the saved session - self.session.save() - - self._logger.debug('send() released the lock') + # And update the saved session + self.session.save() def _send_acknowledges(self): """Sends a messages acknowledge for all those who _need_confirmation""" @@ -90,23 +76,22 @@ class MtProtoSender: del self._need_confirmation[:] - def receive(self): + def receive(self, update_state): """Receives a single message from the connected endpoint. This method returns nothing, and will only affect other parts of the MtProtoSender such as the updates callback being fired or a pending request being confirmed. + + Any unhandled object (likely updates) will be passed to + update_state.process(TLObject). """ - # TODO Don't ignore updates - self._logger.debug('Receiving a message...') - body = self.connection.recv() + with self._recv_lock: + body = self.connection.recv() + message, remote_msg_id, remote_seq = self._decode_msg(body) - with BinaryReader(message) as reader: - self._process_msg( - remote_msg_id, remote_seq, reader, updates=None) - - self._logger.debug('Received message.') + self._process_msg(remote_msg_id, remote_seq, reader, update_state) # endregion @@ -115,8 +100,6 @@ class MtProtoSender: def _send_packet(self, packet, request): """Sends the given packet bytes with the additional information of the original request. - - This does NOT lock the threads! """ request.request_msg_id = self.session.get_new_msg_id() @@ -142,7 +125,8 @@ class MtProtoSender: self.session.auth_key.key_id, signed=False) cipher_writer.write(msg_key) cipher_writer.write(cipher_text) - self.connection.send(cipher_writer.get_bytes()) + with self._send_lock: + self.connection.send(cipher_writer.get_bytes()) def _decode_msg(self, body): """Decodes an received encrypted message body bytes""" @@ -172,7 +156,7 @@ class MtProtoSender: return message, remote_msg_id, remote_sequence - def _process_msg(self, msg_id, sequence, reader, updates): + def _process_msg(self, msg_id, sequence, reader, state): """Processes and handles a Telegram message. Returns True if the message was handled correctly and doesn't @@ -193,10 +177,10 @@ class MtProtoSender: return self._handle_pong(msg_id, sequence, reader) if code == 0x73f1f8dc: # msg_container - return self._handle_container(msg_id, sequence, reader, updates) + return self._handle_container(msg_id, sequence, reader, state) if code == 0x3072cfa1: # gzip_packed - return self._handle_gzip_packed(msg_id, sequence, reader, updates) + return self._handle_gzip_packed(msg_id, sequence, reader, state) if code == 0xedab447b: # bad_server_salt return self._handle_bad_server_salt(msg_id, sequence, reader) @@ -221,16 +205,15 @@ class MtProtoSender: # If the code is not parsed manually then it should be a TLObject. if code in tlobjects: result = reader.tgread_object() - if self.unhandled_callbacks: - self._logger.debug( - 'Passing TLObject to callbacks %s', repr(result) - ) - for callback in self.unhandled_callbacks: - callback(result) - else: + if state is None: self._logger.debug( 'Ignoring unhandled TLObject %s', repr(result) ) + else: + self._logger.debug( + 'Processing TLObject %s', repr(result) + ) + state.process(result) return True @@ -261,7 +244,7 @@ class MtProtoSender: return True - def _handle_container(self, msg_id, sequence, reader, updates): + def _handle_container(self, msg_id, sequence, reader, state): self._logger.debug('Handling container') reader.read_int(signed=False) # code size = reader.read_int() @@ -274,8 +257,7 @@ class MtProtoSender: # Note that this code is IMPORTANT for skipping RPC results of # lost requests (i.e., ones from the previous connection session) try: - if not self._process_msg( - inner_msg_id, sequence, reader, updates): + if not self._process_msg(inner_msg_id, sequence, reader, state): reader.set_position(begin_position + inner_length) except: # If any error is raised, something went wrong; skip the packet @@ -366,14 +348,13 @@ class MtProtoSender: self._logger.debug('Lost request will be skipped.') return False - def _handle_gzip_packed(self, msg_id, sequence, reader, updates): + def _handle_gzip_packed(self, msg_id, sequence, reader, state): self._logger.debug('Handling gzip packed data') reader.read_int(signed=False) # code packed_data = reader.tgread_bytes() unpacked_data = gzip.decompress(packed_data) with BinaryReader(unpacked_data) as compressed_reader: - return self._process_msg( - msg_id, sequence, compressed_reader, updates) + return self._process_msg(msg_id, sequence, compressed_reader, state) # endregion diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 8fd94a9d..e86d15f4 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -26,6 +26,7 @@ from .tl.functions.upload import ( ) from .tl.types import InputFile, InputFileBig from .tl.types.upload import FileCdnRedirect +from .update_state import UpdateState from .utils import get_appropriated_part_size @@ -56,7 +57,9 @@ class TelegramBareClient: def __init__(self, session, api_id, api_hash, connection_mode=ConnectionMode.TCP_FULL, - proxy=None, timeout=timedelta(seconds=5)): + proxy=None, + process_updates=False, + timeout=timedelta(seconds=5)): """Initializes the Telegram client with the specified API ID and Hash. Session must always be a Session instance, and an optional proxy can also be specified to be used on the connection. @@ -74,11 +77,9 @@ class TelegramBareClient: # the time since it's a (somewhat expensive) process. self._cached_clients = {} - # Update callbacks (functions accepting a single TLObject) go here - # - # Note that changing the list to which this variable points to - # will not reflect the changes on the existing senders. - self._update_callbacks = [] + # This member will process updates if enabled. + # One may change self.updates.enabled at any later point. + self.updates = UpdateState(process_updates) # These will be set later self.dc_options = None @@ -127,7 +128,6 @@ class TelegramBareClient: self.session.save() self._sender = MtProtoSender(connection, self.session) - self._sender.unhandled_callbacks = self._update_callbacks self._sender.connect() # Now it's time to send an InitConnectionRequest @@ -312,7 +312,7 @@ class TelegramBareClient: request.confirm_received.wait() # TODO Socket's timeout here? else: while not request.confirm_received.is_set(): - self._sender.receive() + self._sender.receive(update_state=self.updates) if request.rpc_error: raise request.rpc_error diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index a7109ced..babc5499 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -1,8 +1,8 @@ import os +import threading from datetime import datetime, timedelta from mimetypes import guess_type -from threading import RLock, Thread -import threading +from threading import Thread from . import TelegramBareClient from . import helpers as utils @@ -13,6 +13,7 @@ from .errors import ( ) from .network import ConnectionMode from .tl import Session, TLObject +from .tl.functions import PingRequest from .tl.functions.account import ( GetPasswordRequest ) @@ -27,6 +28,9 @@ from .tl.functions.messages import ( GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest, SendMessageRequest ) +from .tl.functions.updates import ( + GetStateRequest +) from .tl.functions.users import ( GetUsersRequest ) @@ -46,9 +50,6 @@ class TelegramClient(TelegramBareClient): As opposed to the TelegramBareClient, this one features downloading media from different data centers, starting a second thread to handle updates, and some very common functionality. - - This should be used when the (slight) overhead of having locks, - threads, and possibly multiple connections is not an issue. """ # region Initialization @@ -56,6 +57,7 @@ class TelegramClient(TelegramBareClient): def __init__(self, session, api_id, api_hash, connection_mode=ConnectionMode.TCP_FULL, proxy=None, + process_updates=False, timeout=timedelta(seconds=5), **kwargs): """Initializes the Telegram client with the specified API ID and Hash. @@ -69,6 +71,16 @@ class TelegramClient(TelegramBareClient): This will only affect how messages are sent over the network and how much processing is required before sending them. + If 'process_updates' is set to True, incoming updates will be + processed and you must manually call 'self.updates.poll()' from + another thread to retrieve the saved update objects, or your + memory will fill with these. You may modify the value of + 'self.updates.polling' at any later point. + + Despite the value of 'process_updates', if you later call + '.add_update_handler(...)', updates will also be processed + and the update objects will be passed to the handlers you added. + If more named arguments are provided as **kwargs, they will be used to update the Session instance. Most common settings are: device_model = platform.node() @@ -92,12 +104,12 @@ class TelegramClient(TelegramBareClient): super().__init__( session, api_id, api_hash, - connection_mode=connection_mode, proxy=proxy, timeout=timeout + connection_mode=connection_mode, + proxy=proxy, + process_updates=process_updates, + timeout=timeout ) - # Safety across multiple threads (for the updates thread) - self._lock = RLock() - # Used on connection - the user may modify these and reconnect kwargs['app_version'] = kwargs.get('app_version', self.__version__) for name, value in kwargs.items(): @@ -114,6 +126,10 @@ class TelegramClient(TelegramBareClient): # Constantly read for results and updates from within the main client self._recv_thread = None + # Default PingRequest delay + self._last_ping = datetime.now() + self._ping_delay = timedelta(minutes=1) + # endregion # region Connecting @@ -145,6 +161,8 @@ class TelegramClient(TelegramBareClient): target=self._recv_thread_impl ) self._recv_thread.start() + if self.updates.enabled: + self.sync_updates() return ok @@ -205,19 +223,20 @@ class TelegramClient(TelegramBareClient): *args will be ignored. """ - try: - self._lock.acquire() + if self._recv_thread is not None and \ + threading.get_ident() == self._recv_thread.ident: + raise AssertionError('Cannot invoke requests from the ReadThread') + try: # Users may call this method from within some update handler. # If this is the case, then the thread invoking the request # will be the one which should be reading (but is invoking the # request) thus not being available to read it "in the background" # and it's needed to call receive. - call_receive = self._recv_thread is None or \ - threading.get_ident() == self._recv_thread.ident - # TODO Retry if 'result' is None? - return super().invoke(request, call_receive=call_receive) + return super().invoke( + request, call_receive=self._recv_thread is None + ) except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: self._logger.debug('DC error when invoking request, ' @@ -227,9 +246,6 @@ class TelegramClient(TelegramBareClient): self.reconnect(new_dc=e.new_dc) return self.invoke(request) - finally: - self._lock.release() - # Let people use client(SomeRequest()) instead client.invoke(...) __call__ = invoke @@ -404,8 +420,6 @@ class TelegramClient(TelegramBareClient): no_webpage=not link_preview ) result = self(request) - for callback in self._update_callbacks: - callback(result) return request.random_id def get_message_history(self, @@ -893,16 +907,26 @@ class TelegramClient(TelegramBareClient): # region Updates handling + def sync_updates(self): + """Synchronizes self.updates to their initial state. Will be + called automatically on connection if self.updates.enabled = True, + otherwise it should be called manually after enabling updates. + """ + self.updates.process(self(GetStateRequest())) + def add_update_handler(self, handler): """Adds an update handler (a function which takes a TLObject, an update, as its parameter) and listens for updates""" - self._update_callbacks.append(handler) + sync = not self.updates.handlers + self.updates.handlers.append(handler) + if sync: + self.sync_updates() def remove_update_handler(self, handler): - self._update_callbacks.remove(handler) + self.updates.handlers.remove(handler) def list_update_handlers(self): - return self._update_callbacks[:] + return self.updates.handlers[:] # endregion @@ -918,7 +942,13 @@ class TelegramClient(TelegramBareClient): def _recv_thread_impl(self): while self._sender and self._sender.is_connected(): try: - self._sender.receive() + if datetime.now() > self._last_ping + self._ping_delay: + self._sender.send(PingRequest( + int.from_bytes(os.urandom(8), 'big', signed=True) + )) + self._last_ping = datetime.now() + + self._sender.receive(update_state=self.updates) except TimeoutError: # No problem. pass diff --git a/telethon/update_state.py b/telethon/update_state.py new file mode 100644 index 00000000..3ceb87cb --- /dev/null +++ b/telethon/update_state.py @@ -0,0 +1,67 @@ +from collections import deque +from datetime import datetime +from threading import RLock, Event + +from .tl import types as tl + + +class UpdateState: + """Used to hold the current state of processed updates. + To retrieve an update, .poll() should be called. + """ + def __init__(self, polling): + self._polling = polling + self.handlers = [] + self._updates_lock = RLock() + self._updates_available = Event() + self._updates = deque() + + # https://core.telegram.org/api/updates + self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) + + def can_poll(self): + """Returns True if a call to .poll() won't lock""" + return self._updates_available.is_set() + + def poll(self): + """Polls an update or blocks until an update object is available""" + if not self._polling: + raise ValueError('Updates are not being polled hence not saved.') + + self._updates_available.wait() + with self._updates_lock: + update = self._updates.popleft() + if not self._updates: + self._updates_available.clear() + + return update + + def get_polling(self): + return self._polling + + def set_polling(self, polling): + self._polling = polling + if not polling: + with self._updates_lock: + self._updates.clear() + + polling = property(fget=get_polling, fset=set_polling) + + def process(self, update): + """Processes an update object. This method is normally called by + the library itself. + """ + if not self._polling or not self.handlers: + return + + with self._updates_lock: + if isinstance(update, tl.updates.State): + self._state = update + elif not hasattr(update, 'pts') or update.pts > self._state.pts: + self._state.pts = getattr(update, 'pts', self._state.pts) + for handler in self.handlers: + handler(update) + + if self._polling: + self._updates.append(update) + self._updates_available.set() diff --git a/telethon_examples/interactive_telegram_client.py b/telethon_examples/interactive_telegram_client.py index ae858cb7..b319dc56 100644 --- a/telethon_examples/interactive_telegram_client.py +++ b/telethon_examples/interactive_telegram_client.py @@ -51,7 +51,8 @@ class InteractiveTelegramClient(TelegramClient): print('Initializing interactive example...') super().__init__( session_user_id, api_id, api_hash, - connection_mode=ConnectionMode.TCP_ABRIDGED, proxy=proxy + connection_mode=ConnectionMode.TCP_ABRIDGED, + proxy=proxy ) # Store all the found media in memory here,