Allow to send more than one request before receiving them (#105)

This commit is contained in:
Lonami Exo 2017-06-11 14:58:16 +02:00
parent c6acd6adc5
commit 65912f926b

View File

@ -22,6 +22,7 @@ class MtProtoSender:
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
self._need_confirmation = [] # Message IDs that need confirmation self._need_confirmation = [] # Message IDs that need confirmation
self._pending_receive = [] # Requests sent waiting to be received
# Store an RLock instance to make this class safely multi-threaded # Store an RLock instance to make this class safely multi-threaded
self._lock = RLock() self._lock = RLock()
@ -55,6 +56,7 @@ class MtProtoSender:
with BinaryWriter() as writer: with BinaryWriter() as writer:
request.on_send(writer) request.on_send(writer)
self._send_packet(writer.get_bytes(), request) self._send_packet(writer.get_bytes(), request)
self._pending_receive.append(request)
# And update the saved session # And update the saved session
self.session.save() self.session.save()
@ -92,20 +94,19 @@ class MtProtoSender:
self._logger.debug('receive() acquired the lock') self._logger.debug('receive() acquired the lock')
# Don't stop trying to receive until we get the request we wanted # Don't stop trying to receive until we get the request we wanted
# or, if there is no request, until we read an update # or, if there is no request, until we read an update
while True: while (request and not request.confirm_received) or \
(not request and not updates):
self._logger.info('Trying to .receive() the request result...') self._logger.info('Trying to .receive() the request result...')
seq, body = self._transport.receive(timeout) seq, body = self._transport.receive(timeout)
message, remote_msg_id, remote_sequence = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self._process_msg(remote_msg_id, remote_sequence, reader, self._process_msg(
request, updates) remote_msg_id, remote_seq, reader, updates)
if request is None: # We're done receiving, remove the request from pending, if any
if updates: if request:
break # No request but one update read, exit self._pending_receive.remove(request)
elif request.confirm_received:
break # Request, and result read, exit
self._logger.info('Request result received') self._logger.info('Request result received')
self._logger.debug('receive() released the lock') self._logger.debug('receive() released the lock')
@ -182,7 +183,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, request, updates): def _process_msg(self, msg_id, sequence, reader, updates):
"""Processes and handles a Telegram message""" """Processes and handles a Telegram message"""
# TODO Check salt, session_id and sequence_number # TODO Check salt, session_id and sequence_number
@ -193,38 +194,33 @@ class MtProtoSender:
# The following codes are "parsed manually" # The following codes are "parsed manually"
if code == 0xf35c6d01: # rpc_result, (response of an RPC call, i.e., we sent a request) if code == 0xf35c6d01: # rpc_result, (response of an RPC call, i.e., we sent a request)
return self._handle_rpc_result( return self._handle_rpc_result(msg_id, sequence, reader)
msg_id, sequence, reader, request)
if code == 0x347773c5: # pong if code == 0x347773c5: # pong
return self._handle_pong( return self._handle_pong(msg_id, sequence, reader)
msg_id, sequence, reader, request)
if code == 0x73f1f8dc: # msg_container if code == 0x73f1f8dc: # msg_container
return self._handle_container( return self._handle_container(msg_id, sequence, reader, updates)
msg_id, sequence, reader, request, updates)
if code == 0x3072cfa1: # gzip_packed if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed( return self._handle_gzip_packed(msg_id, sequence, reader, updates)
msg_id, sequence, reader, request, updates)
if code == 0xedab447b: # bad_server_salt if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt( return self._handle_bad_server_salt(msg_id, sequence, reader)
msg_id, sequence, reader, request)
if code == 0xa7eff811: # bad_msg_notification if code == 0xa7eff811: # bad_msg_notification
return self._handle_bad_msg_notification( return self._handle_bad_msg_notification(msg_id, sequence, reader)
msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted # msgs_ack, it may handle the request we wanted
if code == 0x62d6b459: if code == 0x62d6b459:
ack = reader.tgread_object() ack = reader.tgread_object()
if request and request.msg_id in ack.msg_ids: for r in self._pending_receive:
self._logger.warning('Ack found for the current request ID') if r.msg_id in ack.msg_ids:
self._logger.warning('Ack found for the a request')
if self.logging_out: if self.logging_out:
self._logger.info('Message ack confirmed the logout request') self._logger.info('Message ack confirmed a request')
request.confirm_received = True r.confirm_received = True
return False return False
@ -248,18 +244,22 @@ class MtProtoSender:
# region Message handling # region Message handling
def _handle_pong(self, msg_id, sequence, reader, request): def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong') self._logger.debug('Handling pong')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
received_msg_id = reader.read_long(signed=False) received_msg_id = reader.read_long(signed=False)
if received_msg_id == request.msg_id: try:
request = next(r for r in self._pending_receive
if r.msg_id == received_msg_id)
self._logger.warning('Pong confirmed a request') self._logger.warning('Pong confirmed a request')
request.confirm_received = True request.confirm_received = True
except StopIteration: pass
return False return False
def _handle_container(self, msg_id, sequence, reader, request, updates): def _handle_container(self, msg_id, sequence, reader, updates):
self._logger.debug('Handling container') self._logger.debug('Handling container')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
size = reader.read_int() size = reader.read_int()
@ -272,27 +272,26 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of # Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session) # lost requests (i.e., ones from the previous connection session)
if not self._process_msg( if not self._process_msg(
inner_msg_id, sequence, reader, request, updates): inner_msg_id, sequence, reader, updates):
reader.set_position(begin_position + inner_length) reader.set_position(begin_position + inner_length)
return False return False
def _handle_bad_server_salt(self, msg_id, sequence, reader, request): def _handle_bad_server_salt(self, msg_id, sequence, reader):
self._logger.debug('Handling bad server salt') self._logger.debug('Handling bad server salt')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
reader.read_long(signed=False) # bad_msg_id bad_msg_id = reader.read_long(signed=False)
reader.read_int() # bad_msg_seq_no reader.read_int() # bad_msg_seq_no
reader.read_int() # error_code reader.read_int() # error_code
new_salt = reader.read_long(signed=False) new_salt = reader.read_long(signed=False)
self.session.salt = new_salt self.session.salt = new_salt
if request is None: try:
raise ValueError( request = next(r for r in self._pending_receive
'Tried to handle a bad server salt with no request specified') if r.msg_id == bad_msg_id)
# Resend
self.send(request) self.send(request)
except StopIteration: pass
return True return True
@ -314,14 +313,19 @@ class MtProtoSender:
else: else:
raise error raise error
def _handle_rpc_result(self, msg_id, sequence, reader, request): def _handle_rpc_result(self, msg_id, sequence, reader):
self._logger.debug('Handling RPC result, request is%s None', ' not' if request else '') self._logger.debug('Handling RPC result')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
request_id = reader.read_long(signed=False) request_id = reader.read_long(signed=False)
inner_code = reader.read_int(signed=False) inner_code = reader.read_int(signed=False)
if request and request_id == request.msg_id: try:
request = next(r for r in self._pending_receive
if r.msg_id == request_id)
request.confirm_received = True request.confirm_received = True
except StopIteration:
request = None
if inner_code == 0x2144ca19: # RPC Error if inner_code == 0x2144ca19: # RPC Error
error = rpc_message_to_error( error = rpc_message_to_error(
@ -333,12 +337,8 @@ class MtProtoSender:
self._logger.warning('Read RPC error: %s', str(error)) self._logger.warning('Read RPC error: %s', str(error))
if isinstance(error, InvalidDCError): if isinstance(error, InvalidDCError):
# Must resend this request # Must resend this request, if any
if not request: if request:
raise ValueError(
'The previously sent request must be resent. '
'However, no request was previously sent '
'(possibly called from a different thread).')
request.confirm_received = False request.confirm_received = False
raise error raise error
@ -358,18 +358,20 @@ class MtProtoSender:
if request_id == request.msg_id: if request_id == request.msg_id:
request.on_response(reader) request.on_response(reader)
else: else:
# note: if it's really a result for RPC from previous connection # If it's really a result for RPC from previous connection
# session, it will be skipped by the handle_container() # session, it will be skipped by the handle_container()
self._logger.warning('RPC result found for unknown request (maybe from previous connection session)') self._logger.warning(
'RPC result found for unknown request '
'(maybe from previous connection session)')
def _handle_gzip_packed(self, msg_id, sequence, reader, request, updates): def _handle_gzip_packed(self, msg_id, sequence, reader, updates):
self._logger.debug('Handling gzip packed data') self._logger.debug('Handling gzip packed data')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
packed_data = reader.tgread_bytes() packed_data = reader.tgread_bytes()
unpacked_data = gzip.decompress(packed_data) unpacked_data = gzip.decompress(packed_data)
with BinaryReader(unpacked_data) as compressed_reader: with BinaryReader(unpacked_data) as compressed_reader:
return self._process_msg(msg_id, sequence, compressed_reader, return self._process_msg(
request, updates) msg_id, sequence, compressed_reader, updates)
# endregion # endregion