From 65912f926bb91d995c0e9cc7b6faccc0d1e7e4ad Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sun, 11 Jun 2017 14:58:16 +0200 Subject: [PATCH] Allow to send more than one request before receiving them (#105) --- telethon/network/mtproto_sender.py | 110 +++++++++++++++-------------- 1 file changed, 56 insertions(+), 54 deletions(-) diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index 87d536c2..95a64624 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -22,6 +22,7 @@ class MtProtoSender: self._logger = logging.getLogger(__name__) self._need_confirmation = [] # Message IDs that need confirmation + self._pending_receive = [] # Requests sent waiting to be received # Store an RLock instance to make this class safely multi-threaded self._lock = RLock() @@ -55,6 +56,7 @@ class MtProtoSender: with BinaryWriter() as writer: request.on_send(writer) self._send_packet(writer.get_bytes(), request) + self._pending_receive.append(request) # And update the saved session self.session.save() @@ -92,20 +94,19 @@ class MtProtoSender: self._logger.debug('receive() acquired the lock') # Don't stop trying to receive until we get the request we wanted # or, if there is no request, until we read an update - while True: + while (request and not request.confirm_received) or \ + (not request and not updates): self._logger.info('Trying to .receive() the request result...') seq, body = self._transport.receive(timeout) - message, remote_msg_id, remote_sequence = self._decode_msg(body) + message, remote_msg_id, remote_seq = self._decode_msg(body) with BinaryReader(message) as reader: - self._process_msg(remote_msg_id, remote_sequence, reader, - request, updates) + self._process_msg( + remote_msg_id, remote_seq, reader, updates) - if request is None: - if updates: - break # No request but one update read, exit - elif request.confirm_received: - break # Request, and result read, exit + # We're done receiving, remove the request from pending, if any + if request: + self._pending_receive.remove(request) self._logger.info('Request result received') self._logger.debug('receive() released the lock') @@ -182,7 +183,7 @@ class MtProtoSender: return message, remote_msg_id, remote_sequence - def _process_msg(self, msg_id, sequence, reader, request, updates): + def _process_msg(self, msg_id, sequence, reader, updates): """Processes and handles a Telegram message""" # TODO Check salt, session_id and sequence_number @@ -193,38 +194,33 @@ class MtProtoSender: # The following codes are "parsed manually" if code == 0xf35c6d01: # rpc_result, (response of an RPC call, i.e., we sent a request) - return self._handle_rpc_result( - msg_id, sequence, reader, request) + return self._handle_rpc_result(msg_id, sequence, reader) if code == 0x347773c5: # pong - return self._handle_pong( - msg_id, sequence, reader, request) + return self._handle_pong(msg_id, sequence, reader) if code == 0x73f1f8dc: # msg_container - return self._handle_container( - msg_id, sequence, reader, request, updates) + return self._handle_container(msg_id, sequence, reader, updates) if code == 0x3072cfa1: # gzip_packed - return self._handle_gzip_packed( - msg_id, sequence, reader, request, updates) + return self._handle_gzip_packed(msg_id, sequence, reader, updates) if code == 0xedab447b: # bad_server_salt - return self._handle_bad_server_salt( - msg_id, sequence, reader, request) + return self._handle_bad_server_salt(msg_id, sequence, reader) if code == 0xa7eff811: # bad_msg_notification - return self._handle_bad_msg_notification( - msg_id, sequence, reader) + return self._handle_bad_msg_notification(msg_id, sequence, reader) # msgs_ack, it may handle the request we wanted if code == 0x62d6b459: ack = reader.tgread_object() - if request and request.msg_id in ack.msg_ids: - self._logger.warning('Ack found for the current request ID') + for r in self._pending_receive: + if r.msg_id in ack.msg_ids: + self._logger.warning('Ack found for the a request') - if self.logging_out: - self._logger.info('Message ack confirmed the logout request') - request.confirm_received = True + if self.logging_out: + self._logger.info('Message ack confirmed a request') + r.confirm_received = True return False @@ -248,18 +244,22 @@ class MtProtoSender: # region Message handling - def _handle_pong(self, msg_id, sequence, reader, request): + def _handle_pong(self, msg_id, sequence, reader): self._logger.debug('Handling pong') reader.read_int(signed=False) # code received_msg_id = reader.read_long(signed=False) - if received_msg_id == request.msg_id: + try: + request = next(r for r in self._pending_receive + if r.msg_id == received_msg_id) + self._logger.warning('Pong confirmed a request') request.confirm_received = True + except StopIteration: pass return False - def _handle_container(self, msg_id, sequence, reader, request, updates): + def _handle_container(self, msg_id, sequence, reader, updates): self._logger.debug('Handling container') reader.read_int(signed=False) # code size = reader.read_int() @@ -272,27 +272,26 @@ class MtProtoSender: # Note that this code is IMPORTANT for skipping RPC results of # lost requests (i.e., ones from the previous connection session) if not self._process_msg( - inner_msg_id, sequence, reader, request, updates): + inner_msg_id, sequence, reader, updates): reader.set_position(begin_position + inner_length) return False - def _handle_bad_server_salt(self, msg_id, sequence, reader, request): + def _handle_bad_server_salt(self, msg_id, sequence, reader): self._logger.debug('Handling bad server salt') reader.read_int(signed=False) # code - reader.read_long(signed=False) # bad_msg_id + bad_msg_id = reader.read_long(signed=False) reader.read_int() # bad_msg_seq_no reader.read_int() # error_code new_salt = reader.read_long(signed=False) - self.session.salt = new_salt - if request is None: - raise ValueError( - 'Tried to handle a bad server salt with no request specified') + try: + request = next(r for r in self._pending_receive + if r.msg_id == bad_msg_id) - # Resend - self.send(request) + self.send(request) + except StopIteration: pass return True @@ -314,14 +313,19 @@ class MtProtoSender: else: raise error - def _handle_rpc_result(self, msg_id, sequence, reader, request): - self._logger.debug('Handling RPC result, request is%s None', ' not' if request else '') + def _handle_rpc_result(self, msg_id, sequence, reader): + self._logger.debug('Handling RPC result') reader.read_int(signed=False) # code request_id = reader.read_long(signed=False) inner_code = reader.read_int(signed=False) - if request and request_id == request.msg_id: + try: + request = next(r for r in self._pending_receive + if r.msg_id == request_id) + request.confirm_received = True + except StopIteration: + request = None if inner_code == 0x2144ca19: # RPC Error error = rpc_message_to_error( @@ -333,13 +337,9 @@ class MtProtoSender: self._logger.warning('Read RPC error: %s', str(error)) if isinstance(error, InvalidDCError): - # Must resend this request - if not request: - raise ValueError( - 'The previously sent request must be resent. ' - 'However, no request was previously sent ' - '(possibly called from a different thread).') - request.confirm_received = False + # Must resend this request, if any + if request: + request.confirm_received = False raise error else: @@ -358,18 +358,20 @@ class MtProtoSender: if request_id == request.msg_id: request.on_response(reader) else: - # note: if it's really a result for RPC from previous connection + # If it's really a result for RPC from previous connection # session, it will be skipped by the handle_container() - self._logger.warning('RPC result found for unknown request (maybe from previous connection session)') + self._logger.warning( + 'RPC result found for unknown request ' + '(maybe from previous connection session)') - def _handle_gzip_packed(self, msg_id, sequence, reader, request, updates): + def _handle_gzip_packed(self, msg_id, sequence, reader, updates): self._logger.debug('Handling gzip packed data') reader.read_int(signed=False) # code packed_data = reader.tgread_bytes() unpacked_data = gzip.decompress(packed_data) with BinaryReader(unpacked_data) as compressed_reader: - return self._process_msg(msg_id, sequence, compressed_reader, - request, updates) + return self._process_msg( + msg_id, sequence, compressed_reader, updates) # endregion