mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-01-24 08:14:14 +03:00
Create and use UpdateState to .process() unhandled TLObjects
This commit is contained in:
parent
49e884b005
commit
d4f36162cd
|
@ -35,16 +35,6 @@ class MtProtoSender:
|
|||
# TODO There might be a better way to handle msgs_ack requests
|
||||
self.logging_out = False
|
||||
|
||||
# Every unhandled result gets passed to these callbacks, which
|
||||
# should be functions accepting a single parameter: a TLObject.
|
||||
# This should only be Update(s), although it can actually be any type.
|
||||
#
|
||||
# The thread from which these callbacks are called can be any.
|
||||
#
|
||||
# The creator of the MtProtoSender is responsible for setting this
|
||||
# to point to the list wherever their callbacks reside.
|
||||
self.unhandled_callbacks = None
|
||||
|
||||
def connect(self):
|
||||
"""Connects to the server"""
|
||||
self.connection.connect()
|
||||
|
@ -90,12 +80,15 @@ class MtProtoSender:
|
|||
|
||||
del self._need_confirmation[:]
|
||||
|
||||
def receive(self):
|
||||
def receive(self, update_state):
|
||||
"""Receives a single message from the connected endpoint.
|
||||
|
||||
This method returns nothing, and will only affect other parts
|
||||
of the MtProtoSender such as the updates callback being fired
|
||||
or a pending request being confirmed.
|
||||
|
||||
Any unhandled object (likely updates) will be passed to
|
||||
update_state.process(TLObject).
|
||||
"""
|
||||
# TODO Don't ignore updates
|
||||
self._logger.debug('Receiving a message...')
|
||||
|
@ -103,8 +96,7 @@ class MtProtoSender:
|
|||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||
|
||||
with BinaryReader(message) as reader:
|
||||
self._process_msg(
|
||||
remote_msg_id, remote_seq, reader, updates=None)
|
||||
self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||
|
||||
self._logger.debug('Received message.')
|
||||
|
||||
|
@ -172,7 +164,7 @@ class MtProtoSender:
|
|||
|
||||
return message, remote_msg_id, remote_sequence
|
||||
|
||||
def _process_msg(self, msg_id, sequence, reader, updates):
|
||||
def _process_msg(self, msg_id, sequence, reader, state):
|
||||
"""Processes and handles a Telegram message.
|
||||
|
||||
Returns True if the message was handled correctly and doesn't
|
||||
|
@ -193,10 +185,10 @@ class MtProtoSender:
|
|||
return self._handle_pong(msg_id, sequence, reader)
|
||||
|
||||
if code == 0x73f1f8dc: # msg_container
|
||||
return self._handle_container(msg_id, sequence, reader, updates)
|
||||
return self._handle_container(msg_id, sequence, reader, state)
|
||||
|
||||
if code == 0x3072cfa1: # gzip_packed
|
||||
return self._handle_gzip_packed(msg_id, sequence, reader, updates)
|
||||
return self._handle_gzip_packed(msg_id, sequence, reader, state)
|
||||
|
||||
if code == 0xedab447b: # bad_server_salt
|
||||
return self._handle_bad_server_salt(msg_id, sequence, reader)
|
||||
|
@ -221,16 +213,15 @@ class MtProtoSender:
|
|||
# If the code is not parsed manually then it should be a TLObject.
|
||||
if code in tlobjects:
|
||||
result = reader.tgread_object()
|
||||
if self.unhandled_callbacks:
|
||||
self._logger.debug(
|
||||
'Passing TLObject to callbacks %s', repr(result)
|
||||
)
|
||||
for callback in self.unhandled_callbacks:
|
||||
callback(result)
|
||||
else:
|
||||
if state is None:
|
||||
self._logger.debug(
|
||||
'Ignoring unhandled TLObject %s', repr(result)
|
||||
)
|
||||
else:
|
||||
self._logger.debug(
|
||||
'Processing TLObject %s', repr(result)
|
||||
)
|
||||
state.process(result)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -261,7 +252,7 @@ class MtProtoSender:
|
|||
|
||||
return True
|
||||
|
||||
def _handle_container(self, msg_id, sequence, reader, updates):
|
||||
def _handle_container(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling container')
|
||||
reader.read_int(signed=False) # code
|
||||
size = reader.read_int()
|
||||
|
@ -274,8 +265,7 @@ class MtProtoSender:
|
|||
# Note that this code is IMPORTANT for skipping RPC results of
|
||||
# lost requests (i.e., ones from the previous connection session)
|
||||
try:
|
||||
if not self._process_msg(
|
||||
inner_msg_id, sequence, reader, updates):
|
||||
if not self._process_msg(inner_msg_id, sequence, reader, state):
|
||||
reader.set_position(begin_position + inner_length)
|
||||
except:
|
||||
# If any error is raised, something went wrong; skip the packet
|
||||
|
@ -366,14 +356,13 @@ class MtProtoSender:
|
|||
self._logger.debug('Lost request will be skipped.')
|
||||
return False
|
||||
|
||||
def _handle_gzip_packed(self, msg_id, sequence, reader, updates):
|
||||
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling gzip packed data')
|
||||
reader.read_int(signed=False) # code
|
||||
packed_data = reader.tgread_bytes()
|
||||
unpacked_data = gzip.decompress(packed_data)
|
||||
|
||||
with BinaryReader(unpacked_data) as compressed_reader:
|
||||
return self._process_msg(
|
||||
msg_id, sequence, compressed_reader, updates)
|
||||
return self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -26,6 +26,7 @@ from .tl.functions.upload import (
|
|||
)
|
||||
from .tl.types import InputFile, InputFileBig
|
||||
from .tl.types.upload import FileCdnRedirect
|
||||
from .update_state import UpdateState
|
||||
from .utils import get_appropriated_part_size
|
||||
|
||||
|
||||
|
@ -56,7 +57,9 @@ class TelegramBareClient:
|
|||
|
||||
def __init__(self, session, api_id, api_hash,
|
||||
connection_mode=ConnectionMode.TCP_FULL,
|
||||
proxy=None, timeout=timedelta(seconds=5)):
|
||||
proxy=None,
|
||||
enable_updates=False,
|
||||
timeout=timedelta(seconds=5)):
|
||||
"""Initializes the Telegram client with the specified API ID and Hash.
|
||||
Session must always be a Session instance, and an optional proxy
|
||||
can also be specified to be used on the connection.
|
||||
|
@ -74,11 +77,9 @@ class TelegramBareClient:
|
|||
# the time since it's a (somewhat expensive) process.
|
||||
self._cached_clients = {}
|
||||
|
||||
# Update callbacks (functions accepting a single TLObject) go here
|
||||
#
|
||||
# Note that changing the list to which this variable points to
|
||||
# will not reflect the changes on the existing senders.
|
||||
self._update_callbacks = []
|
||||
# This member will process updates if enabled.
|
||||
# One may change self.updates.enabled at any later point.
|
||||
self.updates = UpdateState(enabled=enable_updates)
|
||||
|
||||
# These will be set later
|
||||
self.dc_options = None
|
||||
|
@ -127,7 +128,6 @@ class TelegramBareClient:
|
|||
self.session.save()
|
||||
|
||||
self._sender = MtProtoSender(connection, self.session)
|
||||
self._sender.unhandled_callbacks = self._update_callbacks
|
||||
self._sender.connect()
|
||||
|
||||
# Now it's time to send an InitConnectionRequest
|
||||
|
@ -312,7 +312,7 @@ class TelegramBareClient:
|
|||
request.confirm_received.wait() # TODO Socket's timeout here?
|
||||
else:
|
||||
while not request.confirm_received.is_set():
|
||||
self._sender.receive()
|
||||
self._sender.receive(update_state=self.updates)
|
||||
|
||||
if request.rpc_error:
|
||||
raise request.rpc_error
|
||||
|
|
|
@ -56,6 +56,7 @@ class TelegramClient(TelegramBareClient):
|
|||
def __init__(self, session, api_id, api_hash,
|
||||
connection_mode=ConnectionMode.TCP_FULL,
|
||||
proxy=None,
|
||||
enable_updates=False,
|
||||
timeout=timedelta(seconds=5),
|
||||
**kwargs):
|
||||
"""Initializes the Telegram client with the specified API ID and Hash.
|
||||
|
@ -69,6 +70,12 @@ class TelegramClient(TelegramBareClient):
|
|||
This will only affect how messages are sent over the network
|
||||
and how much processing is required before sending them.
|
||||
|
||||
If 'enable_updates' is set to True, it will by default put
|
||||
all updates on self.updates. NOTE that you must manually query
|
||||
this from another thread or it will eventually fill up all your
|
||||
memory. If you want to ignore updates, leave this set to False.
|
||||
You may change self.updates.enabled at any later point.
|
||||
|
||||
If more named arguments are provided as **kwargs, they will be
|
||||
used to update the Session instance. Most common settings are:
|
||||
device_model = platform.node()
|
||||
|
@ -92,7 +99,10 @@ class TelegramClient(TelegramBareClient):
|
|||
|
||||
super().__init__(
|
||||
session, api_id, api_hash,
|
||||
connection_mode=connection_mode, proxy=proxy, timeout=timeout
|
||||
connection_mode=connection_mode,
|
||||
proxy=proxy,
|
||||
enable_updates=enable_updates,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# Safety across multiple threads (for the updates thread)
|
||||
|
@ -407,8 +417,6 @@ class TelegramClient(TelegramBareClient):
|
|||
no_webpage=not link_preview
|
||||
)
|
||||
result = self(request)
|
||||
for callback in self._update_callbacks:
|
||||
callback(result)
|
||||
return request.random_id
|
||||
|
||||
def get_message_history(self,
|
||||
|
@ -899,12 +907,15 @@ class TelegramClient(TelegramBareClient):
|
|||
def add_update_handler(self, handler):
|
||||
"""Adds an update handler (a function which takes a TLObject,
|
||||
an update, as its parameter) and listens for updates"""
|
||||
return # TODO Implement
|
||||
self._update_callbacks.append(handler)
|
||||
|
||||
def remove_update_handler(self, handler):
|
||||
return # TODO Implement
|
||||
self._update_callbacks.remove(handler)
|
||||
|
||||
def list_update_handlers(self):
|
||||
return # TODO Implement
|
||||
return self._update_callbacks[:]
|
||||
|
||||
# endregion
|
||||
|
@ -921,7 +932,7 @@ class TelegramClient(TelegramBareClient):
|
|||
def _recv_thread_impl(self):
|
||||
while self._sender and self._sender.is_connected():
|
||||
try:
|
||||
self._sender.receive()
|
||||
self._sender.receive(update_state=self.updates)
|
||||
except TimeoutError:
|
||||
# No problem.
|
||||
pass
|
||||
|
|
38
telethon/update_state.py
Normal file
38
telethon/update_state.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
from threading import Lock, Event
|
||||
from collections import deque
|
||||
|
||||
|
||||
class UpdateState:
|
||||
"""Used to hold the current state of processed updates.
|
||||
To retrieve an update, .pop_update() should be called.
|
||||
"""
|
||||
def __init__(self, enabled):
|
||||
self.enabled = enabled
|
||||
self._updates_lock = Lock()
|
||||
self._updates_available = Event()
|
||||
self._updates = deque()
|
||||
|
||||
def has_any(self):
|
||||
"""Returns True if a call to .pop_update() won't lock"""
|
||||
return self._updates_available.is_set()
|
||||
|
||||
def pop(self):
|
||||
"""Pops an update or blocks until an update object is available"""
|
||||
self._updates_available.wait()
|
||||
with self._updates_lock:
|
||||
update = self._updates.popleft()
|
||||
if not self._updates:
|
||||
self._updates_available.clear()
|
||||
|
||||
return update
|
||||
|
||||
def process(self, update):
|
||||
"""Processes an update object. This method is normally called by
|
||||
the library itself.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
with self._updates_lock:
|
||||
self._updates.append(update)
|
||||
self._updates_available.set()
|
Loading…
Reference in New Issue
Block a user