Replace hardcoded reads with TLObject's .read()

This commit is contained in:
Lonami Exo 2017-10-12 16:40:59 +02:00
parent 3a4662c3bf
commit 0c1170ee61
2 changed files with 29 additions and 25 deletions

View File

@ -11,7 +11,7 @@ from ..errors import (
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects
from ..tl.types import MsgsAck, Pong
from ..tl.types import MsgsAck, Pong, BadServerSalt, BadMsgNotification
from ..tl.functions.auth import LogOutRequest
logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -180,24 +180,24 @@ class MtProtoSender:
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
return self._handle_rpc_result(msg_id, sequence, reader)
if code == 0x347773c5: # pong
if code == Pong.CONSTRUCTOR_ID:
return self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container
if code == MessageContainer.CONSTRUCTOR_ID:
return self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed
if code == GzipPacked.CONSTRUCTOR_ID:
return self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt
if code == BadServerSalt.CONSTRUCTOR_ID:
return self._handle_bad_server_salt(msg_id, sequence, reader)
if code == 0xa7eff811: # bad_msg_notification
if code == BadMsgNotification.CONSTRUCTOR_ID:
return self._handle_bad_msg_notification(msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted
if code == 0x62d6b459:
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object()
assert isinstance(ack, MsgsAck)
# Ignore every ack request *unless* when logging out, when it's
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
@ -219,7 +219,12 @@ class MtProtoSender:
return True
self._logger.debug('Unknown message: {}'.format(hex(code)))
self._logger.debug(
'[WARN] Unknown message: {}, data left in the buffer: {}'
.format(
hex(code), repr(reader.get_bytes()[reader.tell_position():])
)
)
return False
# endregion
@ -279,14 +284,15 @@ class MtProtoSender:
def _handle_bad_server_salt(self, msg_id, sequence, reader):
self._logger.debug('Handling bad server salt')
reader.read_int(signed=False) # code
bad_msg_id = reader.read_long()
reader.read_int() # bad_msg_seq_no
reader.read_int() # error_code
new_salt = reader.read_long(signed=False)
self.session.salt = new_salt
bad_salt = reader.tgread_object()
assert isinstance(bad_salt, BadServerSalt)
request = self._pop_request(bad_msg_id)
# Our salt is unsigned, but the objects work with signed salts
self.session.salt = struct.unpack(
'<Q', struct.pack('<q', bad_salt.new_server_salt)
)[0]
request = self._pop_request(bad_salt.bad_msg_id)
if request:
self.send(request)
@ -294,25 +300,23 @@ class MtProtoSender:
def _handle_bad_msg_notification(self, msg_id, sequence, reader):
self._logger.debug('Handling bad message notification')
reader.read_int(signed=False) # code
reader.read_long() # request_id
reader.read_int() # request_sequence
bad_msg = reader.tgread_object()
assert isinstance(bad_msg, BadMsgNotification)
error_code = reader.read_int()
error = BadMessageError(error_code)
if error_code in (16, 17):
error = BadMessageError(bad_msg.error_code)
if bad_msg.error_code in (16, 17):
# sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id)
self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.')
return True
elif error_code == 32:
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64
return True
elif error_code == 33:
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16
return True

View File

@ -34,5 +34,5 @@ class GzipPacked(TLObject):
@staticmethod
def read(reader):
reader.read_int(signed=False) # code
assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID
return gzip.decompress(reader.tgread_bytes())