Create and use UpdateState to .process() unhandled TLObjects

This commit is contained in:
Lonami Exo 2017-09-07 18:49:08 +02:00
parent 49e884b005
commit d4f36162cd
4 changed files with 79 additions and 41 deletions

View File

@ -35,16 +35,6 @@ class MtProtoSender:
# TODO There might be a better way to handle msgs_ack requests # TODO There might be a better way to handle msgs_ack requests
self.logging_out = False self.logging_out = False
# Every unhandled result gets passed to these callbacks, which
# should be functions accepting a single parameter: a TLObject.
# This should only be Update(s), although it can actually be any type.
#
# The thread from which these callbacks are called can be any.
#
# The creator of the MtProtoSender is responsible for setting this
# to point to the list wherever their callbacks reside.
self.unhandled_callbacks = None
def connect(self): def connect(self):
"""Connects to the server""" """Connects to the server"""
self.connection.connect() self.connection.connect()
@ -90,12 +80,15 @@ class MtProtoSender:
del self._need_confirmation[:] del self._need_confirmation[:]
def receive(self): def receive(self, update_state):
"""Receives a single message from the connected endpoint. """Receives a single message from the connected endpoint.
This method returns nothing, and will only affect other parts This method returns nothing, and will only affect other parts
of the MtProtoSender such as the updates callback being fired of the MtProtoSender such as the updates callback being fired
or a pending request being confirmed. or a pending request being confirmed.
Any unhandled object (likely updates) will be passed to
update_state.process(TLObject).
""" """
# TODO Don't ignore updates # TODO Don't ignore updates
self._logger.debug('Receiving a message...') self._logger.debug('Receiving a message...')
@ -103,8 +96,7 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self._process_msg( self._process_msg(remote_msg_id, remote_seq, reader, update_state)
remote_msg_id, remote_seq, reader, updates=None)
self._logger.debug('Received message.') self._logger.debug('Received message.')
@ -172,7 +164,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, updates): def _process_msg(self, msg_id, sequence, reader, state):
"""Processes and handles a Telegram message. """Processes and handles a Telegram message.
Returns True if the message was handled correctly and doesn't Returns True if the message was handled correctly and doesn't
@ -193,10 +185,10 @@ class MtProtoSender:
return self._handle_pong(msg_id, sequence, reader) return self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container if code == 0x73f1f8dc: # msg_container
return self._handle_container(msg_id, sequence, reader, updates) return self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed(msg_id, sequence, reader, updates) return self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt(msg_id, sequence, reader) return self._handle_bad_server_salt(msg_id, sequence, reader)
@ -221,16 +213,15 @@ class MtProtoSender:
# If the code is not parsed manually then it should be a TLObject. # If the code is not parsed manually then it should be a TLObject.
if code in tlobjects: if code in tlobjects:
result = reader.tgread_object() result = reader.tgread_object()
if self.unhandled_callbacks: if state is None:
self._logger.debug(
'Passing TLObject to callbacks %s', repr(result)
)
for callback in self.unhandled_callbacks:
callback(result)
else:
self._logger.debug( self._logger.debug(
'Ignoring unhandled TLObject %s', repr(result) 'Ignoring unhandled TLObject %s', repr(result)
) )
else:
self._logger.debug(
'Processing TLObject %s', repr(result)
)
state.process(result)
return True return True
@ -261,7 +252,7 @@ class MtProtoSender:
return True return True
def _handle_container(self, msg_id, sequence, reader, updates): def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container') self._logger.debug('Handling container')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
size = reader.read_int() size = reader.read_int()
@ -274,8 +265,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of # Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session) # lost requests (i.e., ones from the previous connection session)
try: try:
if not self._process_msg( if not self._process_msg(inner_msg_id, sequence, reader, state):
inner_msg_id, sequence, reader, updates):
reader.set_position(begin_position + inner_length) reader.set_position(begin_position + inner_length)
except: except:
# If any error is raised, something went wrong; skip the packet # If any error is raised, something went wrong; skip the packet
@ -366,14 +356,13 @@ class MtProtoSender:
self._logger.debug('Lost request will be skipped.') self._logger.debug('Lost request will be skipped.')
return False return False
def _handle_gzip_packed(self, msg_id, sequence, reader, updates): def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data') self._logger.debug('Handling gzip packed data')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
packed_data = reader.tgread_bytes() packed_data = reader.tgread_bytes()
unpacked_data = gzip.decompress(packed_data) unpacked_data = gzip.decompress(packed_data)
with BinaryReader(unpacked_data) as compressed_reader: with BinaryReader(unpacked_data) as compressed_reader:
return self._process_msg( return self._process_msg(msg_id, sequence, compressed_reader, state)
msg_id, sequence, compressed_reader, updates)
# endregion # endregion

View File

@ -26,6 +26,7 @@ from .tl.functions.upload import (
) )
from .tl.types import InputFile, InputFileBig from .tl.types import InputFile, InputFileBig
from .tl.types.upload import FileCdnRedirect from .tl.types.upload import FileCdnRedirect
from .update_state import UpdateState
from .utils import get_appropriated_part_size from .utils import get_appropriated_part_size
@ -56,7 +57,9 @@ class TelegramBareClient:
def __init__(self, session, api_id, api_hash, def __init__(self, session, api_id, api_hash,
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)): proxy=None,
enable_updates=False,
timeout=timedelta(seconds=5)):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
Session must always be a Session instance, and an optional proxy Session must always be a Session instance, and an optional proxy
can also be specified to be used on the connection. can also be specified to be used on the connection.
@ -74,11 +77,9 @@ class TelegramBareClient:
# the time since it's a (somewhat expensive) process. # the time since it's a (somewhat expensive) process.
self._cached_clients = {} self._cached_clients = {}
# Update callbacks (functions accepting a single TLObject) go here # This member will process updates if enabled.
# # One may change self.updates.enabled at any later point.
# Note that changing the list to which this variable points to self.updates = UpdateState(enabled=enable_updates)
# will not reflect the changes on the existing senders.
self._update_callbacks = []
# These will be set later # These will be set later
self.dc_options = None self.dc_options = None
@ -127,7 +128,6 @@ class TelegramBareClient:
self.session.save() self.session.save()
self._sender = MtProtoSender(connection, self.session) self._sender = MtProtoSender(connection, self.session)
self._sender.unhandled_callbacks = self._update_callbacks
self._sender.connect() self._sender.connect()
# Now it's time to send an InitConnectionRequest # Now it's time to send an InitConnectionRequest
@ -312,7 +312,7 @@ class TelegramBareClient:
request.confirm_received.wait() # TODO Socket's timeout here? request.confirm_received.wait() # TODO Socket's timeout here?
else: else:
while not request.confirm_received.is_set(): while not request.confirm_received.is_set():
self._sender.receive() self._sender.receive(update_state=self.updates)
if request.rpc_error: if request.rpc_error:
raise request.rpc_error raise request.rpc_error

View File

@ -56,6 +56,7 @@ class TelegramClient(TelegramBareClient):
def __init__(self, session, api_id, api_hash, def __init__(self, session, api_id, api_hash,
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, proxy=None,
enable_updates=False,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
**kwargs): **kwargs):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
@ -69,6 +70,12 @@ class TelegramClient(TelegramBareClient):
This will only affect how messages are sent over the network This will only affect how messages are sent over the network
and how much processing is required before sending them. and how much processing is required before sending them.
If 'enable_updates' is set to True, it will by default put
all updates on self.updates. NOTE that you must manually query
this from another thread or it will eventually fill up all your
memory. If you want to ignore updates, leave this set to False.
You may change self.updates.enabled at any later point.
If more named arguments are provided as **kwargs, they will be If more named arguments are provided as **kwargs, they will be
used to update the Session instance. Most common settings are: used to update the Session instance. Most common settings are:
device_model = platform.node() device_model = platform.node()
@ -92,7 +99,10 @@ class TelegramClient(TelegramBareClient):
super().__init__( super().__init__(
session, api_id, api_hash, session, api_id, api_hash,
connection_mode=connection_mode, proxy=proxy, timeout=timeout connection_mode=connection_mode,
proxy=proxy,
enable_updates=enable_updates,
timeout=timeout
) )
# Safety across multiple threads (for the updates thread) # Safety across multiple threads (for the updates thread)
@ -407,8 +417,6 @@ class TelegramClient(TelegramBareClient):
no_webpage=not link_preview no_webpage=not link_preview
) )
result = self(request) result = self(request)
for callback in self._update_callbacks:
callback(result)
return request.random_id return request.random_id
def get_message_history(self, def get_message_history(self,
@ -899,12 +907,15 @@ class TelegramClient(TelegramBareClient):
def add_update_handler(self, handler): def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject, """Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates""" an update, as its parameter) and listens for updates"""
return # TODO Implement
self._update_callbacks.append(handler) self._update_callbacks.append(handler)
def remove_update_handler(self, handler): def remove_update_handler(self, handler):
return # TODO Implement
self._update_callbacks.remove(handler) self._update_callbacks.remove(handler)
def list_update_handlers(self): def list_update_handlers(self):
return # TODO Implement
return self._update_callbacks[:] return self._update_callbacks[:]
# endregion # endregion
@ -921,7 +932,7 @@ class TelegramClient(TelegramBareClient):
def _recv_thread_impl(self): def _recv_thread_impl(self):
while self._sender and self._sender.is_connected(): while self._sender and self._sender.is_connected():
try: try:
self._sender.receive() self._sender.receive(update_state=self.updates)
except TimeoutError: except TimeoutError:
# No problem. # No problem.
pass pass

38
telethon/update_state.py Normal file
View File

@ -0,0 +1,38 @@
from threading import Lock, Event
from collections import deque
class UpdateState:
"""Used to hold the current state of processed updates.
To retrieve an update, .pop_update() should be called.
"""
def __init__(self, enabled):
self.enabled = enabled
self._updates_lock = Lock()
self._updates_available = Event()
self._updates = deque()
def has_any(self):
"""Returns True if a call to .pop_update() won't lock"""
return self._updates_available.is_set()
def pop(self):
"""Pops an update or blocks until an update object is available"""
self._updates_available.wait()
with self._updates_lock:
update = self._updates.popleft()
if not self._updates:
self._updates_available.clear()
return update
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if not self.enabled:
return
with self._updates_lock:
self._updates.append(update)
self._updates_available.set()