From d4f36162cd0c29d4f97f818ac5c5ce04622fa4eb Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 7 Sep 2017 18:49:08 +0200 Subject: [PATCH] Create and use UpdateState to .process() unhandled TLObjects --- telethon/network/mtproto_sender.py | 47 ++++++++++++------------------ telethon/telegram_bare_client.py | 16 +++++----- telethon/telegram_client.py | 19 +++++++++--- telethon/update_state.py | 38 ++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 41 deletions(-) create mode 100644 telethon/update_state.py diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index a8817a06..ef8b4794 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -35,16 +35,6 @@ class MtProtoSender: # TODO There might be a better way to handle msgs_ack requests self.logging_out = False - # Every unhandled result gets passed to these callbacks, which - # should be functions accepting a single parameter: a TLObject. - # This should only be Update(s), although it can actually be any type. - # - # The thread from which these callbacks are called can be any. - # - # The creator of the MtProtoSender is responsible for setting this - # to point to the list wherever their callbacks reside. - self.unhandled_callbacks = None - def connect(self): """Connects to the server""" self.connection.connect() @@ -90,12 +80,15 @@ class MtProtoSender: del self._need_confirmation[:] - def receive(self): + def receive(self, update_state): """Receives a single message from the connected endpoint. This method returns nothing, and will only affect other parts of the MtProtoSender such as the updates callback being fired or a pending request being confirmed. + + Any unhandled object (likely updates) will be passed to + update_state.process(TLObject). """ # TODO Don't ignore updates self._logger.debug('Receiving a message...') @@ -103,8 +96,7 @@ class MtProtoSender: message, remote_msg_id, remote_seq = self._decode_msg(body) with BinaryReader(message) as reader: - self._process_msg( - remote_msg_id, remote_seq, reader, updates=None) + self._process_msg(remote_msg_id, remote_seq, reader, update_state) self._logger.debug('Received message.') @@ -172,7 +164,7 @@ class MtProtoSender: return message, remote_msg_id, remote_sequence - def _process_msg(self, msg_id, sequence, reader, updates): + def _process_msg(self, msg_id, sequence, reader, state): """Processes and handles a Telegram message. Returns True if the message was handled correctly and doesn't @@ -193,10 +185,10 @@ class MtProtoSender: return self._handle_pong(msg_id, sequence, reader) if code == 0x73f1f8dc: # msg_container - return self._handle_container(msg_id, sequence, reader, updates) + return self._handle_container(msg_id, sequence, reader, state) if code == 0x3072cfa1: # gzip_packed - return self._handle_gzip_packed(msg_id, sequence, reader, updates) + return self._handle_gzip_packed(msg_id, sequence, reader, state) if code == 0xedab447b: # bad_server_salt return self._handle_bad_server_salt(msg_id, sequence, reader) @@ -221,16 +213,15 @@ class MtProtoSender: # If the code is not parsed manually then it should be a TLObject. if code in tlobjects: result = reader.tgread_object() - if self.unhandled_callbacks: - self._logger.debug( - 'Passing TLObject to callbacks %s', repr(result) - ) - for callback in self.unhandled_callbacks: - callback(result) - else: + if state is None: self._logger.debug( 'Ignoring unhandled TLObject %s', repr(result) ) + else: + self._logger.debug( + 'Processing TLObject %s', repr(result) + ) + state.process(result) return True @@ -261,7 +252,7 @@ class MtProtoSender: return True - def _handle_container(self, msg_id, sequence, reader, updates): + def _handle_container(self, msg_id, sequence, reader, state): self._logger.debug('Handling container') reader.read_int(signed=False) # code size = reader.read_int() @@ -274,8 +265,7 @@ class MtProtoSender: # Note that this code is IMPORTANT for skipping RPC results of # lost requests (i.e., ones from the previous connection session) try: - if not self._process_msg( - inner_msg_id, sequence, reader, updates): + if not self._process_msg(inner_msg_id, sequence, reader, state): reader.set_position(begin_position + inner_length) except: # If any error is raised, something went wrong; skip the packet @@ -366,14 +356,13 @@ class MtProtoSender: self._logger.debug('Lost request will be skipped.') return False - def _handle_gzip_packed(self, msg_id, sequence, reader, updates): + def _handle_gzip_packed(self, msg_id, sequence, reader, state): self._logger.debug('Handling gzip packed data') reader.read_int(signed=False) # code packed_data = reader.tgread_bytes() unpacked_data = gzip.decompress(packed_data) with BinaryReader(unpacked_data) as compressed_reader: - return self._process_msg( - msg_id, sequence, compressed_reader, updates) + return self._process_msg(msg_id, sequence, compressed_reader, state) # endregion diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 8fd94a9d..1685d849 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -26,6 +26,7 @@ from .tl.functions.upload import ( ) from .tl.types import InputFile, InputFileBig from .tl.types.upload import FileCdnRedirect +from .update_state import UpdateState from .utils import get_appropriated_part_size @@ -56,7 +57,9 @@ class TelegramBareClient: def __init__(self, session, api_id, api_hash, connection_mode=ConnectionMode.TCP_FULL, - proxy=None, timeout=timedelta(seconds=5)): + proxy=None, + enable_updates=False, + timeout=timedelta(seconds=5)): """Initializes the Telegram client with the specified API ID and Hash. Session must always be a Session instance, and an optional proxy can also be specified to be used on the connection. @@ -74,11 +77,9 @@ class TelegramBareClient: # the time since it's a (somewhat expensive) process. self._cached_clients = {} - # Update callbacks (functions accepting a single TLObject) go here - # - # Note that changing the list to which this variable points to - # will not reflect the changes on the existing senders. - self._update_callbacks = [] + # This member will process updates if enabled. + # One may change self.updates.enabled at any later point. + self.updates = UpdateState(enabled=enable_updates) # These will be set later self.dc_options = None @@ -127,7 +128,6 @@ class TelegramBareClient: self.session.save() self._sender = MtProtoSender(connection, self.session) - self._sender.unhandled_callbacks = self._update_callbacks self._sender.connect() # Now it's time to send an InitConnectionRequest @@ -312,7 +312,7 @@ class TelegramBareClient: request.confirm_received.wait() # TODO Socket's timeout here? else: while not request.confirm_received.is_set(): - self._sender.receive() + self._sender.receive(update_state=self.updates) if request.rpc_error: raise request.rpc_error diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index f35ac39a..ef83cdaa 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -56,6 +56,7 @@ class TelegramClient(TelegramBareClient): def __init__(self, session, api_id, api_hash, connection_mode=ConnectionMode.TCP_FULL, proxy=None, + enable_updates=False, timeout=timedelta(seconds=5), **kwargs): """Initializes the Telegram client with the specified API ID and Hash. @@ -69,6 +70,12 @@ class TelegramClient(TelegramBareClient): This will only affect how messages are sent over the network and how much processing is required before sending them. + If 'enable_updates' is set to True, it will by default put + all updates on self.updates. NOTE that you must manually query + this from another thread or it will eventually fill up all your + memory. If you want to ignore updates, leave this set to False. + You may change self.updates.enabled at any later point. + If more named arguments are provided as **kwargs, they will be used to update the Session instance. Most common settings are: device_model = platform.node() @@ -92,7 +99,10 @@ class TelegramClient(TelegramBareClient): super().__init__( session, api_id, api_hash, - connection_mode=connection_mode, proxy=proxy, timeout=timeout + connection_mode=connection_mode, + proxy=proxy, + enable_updates=enable_updates, + timeout=timeout ) # Safety across multiple threads (for the updates thread) @@ -407,8 +417,6 @@ class TelegramClient(TelegramBareClient): no_webpage=not link_preview ) result = self(request) - for callback in self._update_callbacks: - callback(result) return request.random_id def get_message_history(self, @@ -899,12 +907,15 @@ class TelegramClient(TelegramBareClient): def add_update_handler(self, handler): """Adds an update handler (a function which takes a TLObject, an update, as its parameter) and listens for updates""" + return # TODO Implement self._update_callbacks.append(handler) def remove_update_handler(self, handler): + return # TODO Implement self._update_callbacks.remove(handler) def list_update_handlers(self): + return # TODO Implement return self._update_callbacks[:] # endregion @@ -921,7 +932,7 @@ class TelegramClient(TelegramBareClient): def _recv_thread_impl(self): while self._sender and self._sender.is_connected(): try: - self._sender.receive() + self._sender.receive(update_state=self.updates) except TimeoutError: # No problem. pass diff --git a/telethon/update_state.py b/telethon/update_state.py new file mode 100644 index 00000000..14d362a9 --- /dev/null +++ b/telethon/update_state.py @@ -0,0 +1,38 @@ +from threading import Lock, Event +from collections import deque + + +class UpdateState: + """Used to hold the current state of processed updates. + To retrieve an update, .pop_update() should be called. + """ + def __init__(self, enabled): + self.enabled = enabled + self._updates_lock = Lock() + self._updates_available = Event() + self._updates = deque() + + def has_any(self): + """Returns True if a call to .pop_update() won't lock""" + return self._updates_available.is_set() + + def pop(self): + """Pops an update or blocks until an update object is available""" + self._updates_available.wait() + with self._updates_lock: + update = self._updates.popleft() + if not self._updates: + self._updates_available.clear() + + return update + + def process(self, update): + """Processes an update object. This method is normally called by + the library itself. + """ + if not self.enabled: + return + + with self._updates_lock: + self._updates.append(update) + self._updates_available.set()