Made MtProtoSender multi-thread safe

This will allow you to, for example, send requests
when you get an update (if someone tells you something,
you can automatically reply something else). Beware,
this may lead to an infinite loop so add conditions!
This commit is contained in:
Lonami 2016-09-10 11:01:03 +02:00
parent e47344c0f0
commit 72bc9df683
2 changed files with 44 additions and 37 deletions

View File

@ -74,8 +74,8 @@ if __name__ == '__main__':
date = datetime.fromtimestamp(msg.date) date = datetime.fromtimestamp(msg.date)
print('[{}:{}] {}: {}'.format(date.hour, date.minute, name, msg.message)) print('[{}:{}] {}: {}'.format(date.hour, date.minute, name, msg.message))
# Send chat message # Send chat message (if any)
else: elif msg:
client.send_message(input_peer, msg, markdown=True, no_web_page=True) client.send_message(input_peer, msg, markdown=True, no_web_page=True)
print('Thanks for trying the interactive example! Exiting...') print('Thanks for trying the interactive example! Exiting...')

View File

@ -3,7 +3,7 @@
import gzip import gzip
from errors import * from errors import *
from time import sleep from time import sleep
from threading import Thread from threading import Thread, Lock
import utils import utils
from crypto import AES from crypto import AES
@ -14,22 +14,20 @@ from tl.all_tlobjects import tlobjects
class MtProtoSender: class MtProtoSender:
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)""" """MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
def __init__(self, transport, session, check_updates_delay=0.1): def __init__(self, transport, session, check_updates=True):
"""If check_updates_delay is None, no updates will be checked.
Otherwise, specifies every how often updates should be checked"""
self.transport = transport self.transport = transport
self.session = session self.session = session
self.need_confirmation = [] # Message IDs that need confirmation self.need_confirmation = [] # Message IDs that need confirmation
self.on_update_handlers = [] self.on_update_handlers = []
# Set up updates thread, if the delay is not None # Store a Lock instance to make this class safely multi-threaded
self.check_updates_delay = check_updates_delay self.lock = Lock()
if check_updates_delay:
if check_updates:
self.updates_thread = Thread(target=self.updates_thread_method, name='Updates thread') self.updates_thread = Thread(target=self.updates_thread_method, name='Updates thread')
self.updates_thread_running = True self.updates_thread_running = True
self.updates_thread_paused = True self.updates_thread_receiving = False
self.updates_thread.start() self.updates_thread.start()
@ -60,14 +58,23 @@ class MtProtoSender:
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation. This also pauses the updates thread""" which needed confirmation. This also pauses the updates thread"""
# Pause the updates thread: we cannot use self.transport at the same time! # Only cancel the receive *if* it was the
self.pause_updates_thread() # updates thread who was receiving. We do
# not want to cancel other pending requests!
if self.updates_thread_receiving:
self.transport.cancel_receive()
# Now only us can be using this method
self.lock.acquire()
# If any message needs confirmation send an AckRequest first # If any message needs confirmation send an AckRequest first
if self.need_confirmation: if self.need_confirmation:
msgs_ack = MsgsAck(self.need_confirmation[:]) msgs_ack = MsgsAck(self.need_confirmation)
with BinaryWriter() as writer:
msgs_ack.on_send(writer)
self.send_packet(writer.get_bytes(), msgs_ack)
del self.need_confirmation[:] del self.need_confirmation[:]
self.send(msgs_ack)
# Finally send our packed request # Finally send our packed request
with BinaryWriter() as writer: with BinaryWriter() as writer:
@ -83,23 +90,27 @@ class MtProtoSender:
"""Receives the specified MTProtoRequest ("fills in it" """Receives the specified MTProtoRequest ("fills in it"
the received data). This also restores the updates thread""" the received data). This also restores the updates thread"""
# Don't stop receiving until we get the request we wanted try:
while not request.confirm_received: # Don't stop trying to receive until we get the request we wanted
seq, body = self.transport.receive() while not request.confirm_received:
message, remote_msg_id, remote_sequence = self.decode_msg(body) seq, body = self.transport.receive()
message, remote_msg_id, remote_sequence = self.decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self.process_msg(remote_msg_id, remote_sequence, reader, request) self.process_msg(remote_msg_id, remote_sequence, reader, request)
# Once we have our request, restore the updates thread finally:
self.restore_updates_thread() # Once we are done trying to get our request,
# restore the updates thread and release the lock
self.lock.release()
# endregion # endregion
# region Low level processing # region Low level processing
def send_packet(self, packet, request): def send_packet(self, packet, request):
"""Sends the given packet bytes with the additional information of the original request""" """Sends the given packet bytes with the additional
information of the original request. This does NOT lock the threads!"""
request.msg_id = self.session.get_new_msg_id() request.msg_id = self.session.get_new_msg_id()
# First calculate plain_text to encrypt it # First calculate plain_text to encrypt it
@ -276,22 +287,12 @@ class MtProtoSender:
# endregion # endregion
def pause_updates_thread(self):
"""Pauses the updates thread and sleeps until it's safe to continue"""
if not self.updates_thread_paused:
self.updates_thread_paused = True
self.transport.cancel_receive()
def restore_updates_thread(self):
"""Restores the updates thread"""
self.updates_thread_paused = False
# TODO avoid, if possible using sleeps; Use thread locks instead
def updates_thread_method(self): def updates_thread_method(self):
"""This method will run until specified and listen for incoming updates""" """This method will run until specified and listen for incoming updates"""
while self.updates_thread_running: while self.updates_thread_running:
if not self.updates_thread_paused: with self.lock:
try: try:
self.updates_thread_receiving = True
seq, body = self.transport.receive() seq, body = self.transport.receive()
message, remote_msg_id, remote_sequence = self.decode_msg(body) message, remote_msg_id, remote_sequence = self.decode_msg(body)
@ -301,4 +302,10 @@ class MtProtoSender:
except ReadCancelledError: except ReadCancelledError:
pass pass
sleep(self.transport.get_client_delay()) self.updates_thread_receiving = False
# If we are here, it is because the read was cancelled
# Sleep a bit just to give enough time for the other thread
# to acquire the lock. No need to sleep if we're not running anymore
if self.updates_thread_running:
sleep(0.1)