From 72bc9df683b415ab9aef39732ceb3399dda96714 Mon Sep 17 00:00:00 2001 From: Lonami Date: Sat, 10 Sep 2016 11:01:03 +0200 Subject: [PATCH] Made MtProtoSender multi-thread safe This will allow you to, for example, send requests when you get an update (if someone tells you something, you can automatically reply something else). Beware, this may lead to an infinite loop so add conditions! --- main.py | 4 +- network/mtproto_sender.py | 77 +++++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/main.py b/main.py index 4cc390ec..0a18496b 100755 --- a/main.py +++ b/main.py @@ -74,8 +74,8 @@ if __name__ == '__main__': date = datetime.fromtimestamp(msg.date) print('[{}:{}] {}: {}'.format(date.hour, date.minute, name, msg.message)) - # Send chat message - else: + # Send chat message (if any) + elif msg: client.send_message(input_peer, msg, markdown=True, no_web_page=True) print('Thanks for trying the interactive example! Exiting...') diff --git a/network/mtproto_sender.py b/network/mtproto_sender.py index 5d42a537..cf9f663d 100755 --- a/network/mtproto_sender.py +++ b/network/mtproto_sender.py @@ -3,7 +3,7 @@ import gzip from errors import * from time import sleep -from threading import Thread +from threading import Thread, Lock import utils from crypto import AES @@ -14,22 +14,20 @@ from tl.all_tlobjects import tlobjects class MtProtoSender: """MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)""" - def __init__(self, transport, session, check_updates_delay=0.1): - """If check_updates_delay is None, no updates will be checked. - Otherwise, specifies every how often updates should be checked""" - + def __init__(self, transport, session, check_updates=True): self.transport = transport self.session = session self.need_confirmation = [] # Message IDs that need confirmation self.on_update_handlers = [] - # Set up updates thread, if the delay is not None - self.check_updates_delay = check_updates_delay - if check_updates_delay: + # Store a Lock instance to make this class safely multi-threaded + self.lock = Lock() + + if check_updates: self.updates_thread = Thread(target=self.updates_thread_method, name='Updates thread') self.updates_thread_running = True - self.updates_thread_paused = True + self.updates_thread_receiving = False self.updates_thread.start() @@ -60,14 +58,23 @@ class MtProtoSender: """Sends the specified MTProtoRequest, previously sending any message which needed confirmation. This also pauses the updates thread""" - # Pause the updates thread: we cannot use self.transport at the same time! - self.pause_updates_thread() + # Only cancel the receive *if* it was the + # updates thread who was receiving. We do + # not want to cancel other pending requests! + if self.updates_thread_receiving: + self.transport.cancel_receive() + + # Now only us can be using this method + self.lock.acquire() # If any message needs confirmation send an AckRequest first if self.need_confirmation: - msgs_ack = MsgsAck(self.need_confirmation[:]) + msgs_ack = MsgsAck(self.need_confirmation) + with BinaryWriter() as writer: + msgs_ack.on_send(writer) + self.send_packet(writer.get_bytes(), msgs_ack) + del self.need_confirmation[:] - self.send(msgs_ack) # Finally send our packed request with BinaryWriter() as writer: @@ -83,23 +90,27 @@ class MtProtoSender: """Receives the specified MTProtoRequest ("fills in it" the received data). This also restores the updates thread""" - # Don't stop receiving until we get the request we wanted - while not request.confirm_received: - seq, body = self.transport.receive() - message, remote_msg_id, remote_sequence = self.decode_msg(body) + try: + # Don't stop trying to receive until we get the request we wanted + while not request.confirm_received: + seq, body = self.transport.receive() + message, remote_msg_id, remote_sequence = self.decode_msg(body) - with BinaryReader(message) as reader: - self.process_msg(remote_msg_id, remote_sequence, reader, request) + with BinaryReader(message) as reader: + self.process_msg(remote_msg_id, remote_sequence, reader, request) - # Once we have our request, restore the updates thread - self.restore_updates_thread() + finally: + # Once we are done trying to get our request, + # restore the updates thread and release the lock + self.lock.release() # endregion # region Low level processing def send_packet(self, packet, request): - """Sends the given packet bytes with the additional information of the original request""" + """Sends the given packet bytes with the additional + information of the original request. This does NOT lock the threads!""" request.msg_id = self.session.get_new_msg_id() # First calculate plain_text to encrypt it @@ -276,22 +287,12 @@ class MtProtoSender: # endregion - def pause_updates_thread(self): - """Pauses the updates thread and sleeps until it's safe to continue""" - if not self.updates_thread_paused: - self.updates_thread_paused = True - self.transport.cancel_receive() - - def restore_updates_thread(self): - """Restores the updates thread""" - self.updates_thread_paused = False - - # TODO avoid, if possible using sleeps; Use thread locks instead def updates_thread_method(self): """This method will run until specified and listen for incoming updates""" while self.updates_thread_running: - if not self.updates_thread_paused: + with self.lock: try: + self.updates_thread_receiving = True seq, body = self.transport.receive() message, remote_msg_id, remote_sequence = self.decode_msg(body) @@ -301,4 +302,10 @@ class MtProtoSender: except ReadCancelledError: pass - sleep(self.transport.get_client_delay()) + self.updates_thread_receiving = False + + # If we are here, it is because the read was cancelled + # Sleep a bit just to give enough time for the other thread + # to acquire the lock. No need to sleep if we're not running anymore + if self.updates_thread_running: + sleep(0.1)