From be279ce3f51437fc824b0de223fe5697450dd4ad Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sat, 9 Jun 2018 13:48:27 +0200 Subject: [PATCH] Make TLMessage always have a valid TLObject This simplifies the flow instead of having separate request/body attributes, and also means that BinaryReader.tgread_object() can be used without so many special cases. --- telethon/errors/common.py | 9 +-- telethon/extensions/binary_reader.py | 5 +- telethon/network/mtprotosender.py | 83 ++++++++++++++-------------- telethon/network/mtprotostate.py | 39 ++++++++----- telethon/tl/core/messagecontainer.py | 34 ++++++------ telethon/tl/core/tlmessage.py | 19 ++----- 6 files changed, 98 insertions(+), 91 deletions(-) diff --git a/telethon/errors/common.py b/telethon/errors/common.py index 0c03aee6..f8b479e7 100644 --- a/telethon/errors/common.py +++ b/telethon/errors/common.py @@ -12,14 +12,15 @@ class TypeNotFoundError(Exception): Occurs when a type is not found, for example, when trying to read a TLObject with an invalid constructor code. """ - def __init__(self, invalid_constructor_id): + def __init__(self, invalid_constructor_id, remaining): super().__init__( 'Could not find a matching Constructor ID for the TLObject ' - 'that was supposed to be read with ID {}. Most likely, a TLObject ' - 'was trying to be read when it should not be read.' - .format(hex(invalid_constructor_id))) + 'that was supposed to be read with ID {:08x}. Most likely, ' + 'a TLObject was trying to be read when it should not be read. 
' + 'Remaining bytes: {!r}'.format(invalid_constructor_id, remaining)) self.invalid_constructor_id = invalid_constructor_id + self.remaining = remaining class InvalidChecksumError(Exception): diff --git a/telethon/extensions/binary_reader.py b/telethon/extensions/binary_reader.py index e7496d77..b0027084 100644 --- a/telethon/extensions/binary_reader.py +++ b/telethon/extensions/binary_reader.py @@ -141,7 +141,10 @@ class BinaryReader: if clazz is None: # If there was still no luck, give up self.seek(-4) # Go back - raise TypeNotFoundError(constructor_id) + pos = self.tell_position() + error = TypeNotFoundError(constructor_id, self.read()) + self.set_position(pos) + raise error return clazz.from_reader(self) diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 28369b00..27b95f8d 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -351,29 +351,27 @@ class MTProtoSender: __log__.warning('Security error while unpacking a ' 'received message:'.format(e)) continue + except TypeNotFoundError as e: + # The payload inside the message was not a known TLObject. + __log__.info('Server replied with an unknown type {:08x}: {!r}' + .format(e.invalid_constructor_id, e.remaining)) else: - try: - with BinaryReader(message.body) as reader: - obj = reader.tgread_object() - except TypeNotFoundError as e: - __log__.warning('Could not decode received message: {}, ' - 'raw bytes: {!r}'.format(e, message)) - else: - await self._process_message(message, obj) + await self._process_message(message) # Response Handlers - async def _process_message(self, message, obj): + async def _process_message(self, message): """ Adds the given message to the list of messages that must be acknowledged and dispatches control to different ``_handle_*`` method based on its type. 
""" self._pending_ack.add(message.msg_id) - handler = self._handlers.get(obj.CONSTRUCTOR_ID, self._handle_update) - await handler(message, obj) + handler = self._handlers.get(message.obj.CONSTRUCTOR_ID, + self._handle_update) + await handler(message) - async def _handle_rpc_result(self, message, rpc_result): + async def _handle_rpc_result(self, message): """ Handles the result for Remote Procedure Calls: @@ -381,6 +379,7 @@ class MTProtoSender: This is where the future results for sent requests are set. """ + rpc_result = message.obj message = self._pending_messages.pop(rpc_result.req_msg_id, None) __log__.debug('Handling RPC result for message {}' .format(rpc_result.req_msg_id)) @@ -397,7 +396,7 @@ class MTProtoSender: return elif message: with BinaryReader(rpc_result.body) as reader: - result = message.request.read_result(reader) + result = message.obj.read_result(reader) # TODO Process entities if not message.future.cancelled(): @@ -408,35 +407,35 @@ class MTProtoSender: __log__.info('Received response without parent request: {}' .format(rpc_result.body)) - async def _handle_container(self, message, container): + async def _handle_container(self, message): """ Processes the inner messages of a container with many of them: msg_container#73f1f8dc messages:vector<%Message> = MessageContainer; """ __log__.debug('Handling container') - for inner_message in container.messages: - with BinaryReader(inner_message.body) as reader: - inner_obj = reader.tgread_object() - await self._process_message(inner_message, inner_obj) + for inner_message in message.obj.messages: + await self._process_message(inner_message) - async def _handle_gzip_packed(self, message, gzip_packed): + async def _handle_gzip_packed(self, message): """ Unpacks the data from a gzipped object and processes it: gzip_packed#3072cfa1 packed_data:bytes = Object; """ __log__.debug('Handling gzipped data') - with BinaryReader(gzip_packed.data) as reader: - await self._process_message(message, 
reader.tgread_object()) + with BinaryReader(message.obj.data) as reader: + message.obj = reader.tgread_object() + await self._process_message(message) - async def _handle_update(self, message, update): - __log__.debug('Handling update {}'.format(update.__class__.__name__)) + async def _handle_update(self, message): + __log__.debug('Handling update {}' + .format(message.obj.__class__.__name__)) # TODO Further handling of the update # TODO Process entities - async def _handle_pong(self, message, pong): + async def _handle_pong(self, message): """ Handles pong results, which don't come inside a ``rpc_result`` but are still sent through a request: @@ -444,11 +443,12 @@ class MTProtoSender: pong#347773c5 msg_id:long ping_id:long = Pong; """ __log__.debug('Handling pong') + pong = message.obj message = self._pending_messages.pop(pong.msg_id, None) if message: - message.future.set_result(pong) + message.future.set_result(pong) - async def _handle_bad_server_salt(self, message, bad_salt): + async def _handle_bad_server_salt(self, message): """ Corrects the currently used server salt to use the right value before enqueuing the rejected message to be re-sent: @@ -457,10 +457,11 @@ class MTProtoSender: error_code:int new_server_salt:long = BadMsgNotification; """ __log__.debug('Handling bad salt') + bad_salt = message.obj self.state.salt = bad_salt.new_server_salt await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id]) - async def _handle_bad_notification(self, message, bad_msg): + async def _handle_bad_notification(self, message): """ Adjusts the current state to be correct based on the received bad message notification whenever possible: @@ -469,6 +470,7 @@ class MTProtoSender: error_code:int = BadMsgNotification; """ __log__.debug('Handling bad message') + bad_msg = message.obj if bad_msg.error_code in (16, 17): # Sent msg_id too low or too high (respectively). # Use the current msg_id to determine the right time offset. 
@@ -489,7 +491,7 @@ class MTProtoSender: # Messages are to be re-sent once we've corrected the issue await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id]) - async def _handle_detailed_info(self, message, detailed_info): + async def _handle_detailed_info(self, message): """ Updates the current status with the received detailed information: @@ -498,9 +500,9 @@ class MTProtoSender: """ # TODO https://goo.gl/VvpCC6 __log__.debug('Handling detailed info') - self._pending_ack.add(detailed_info.answer_msg_id) + self._pending_ack.add(message.obj.answer_msg_id) - async def _handle_new_detailed_info(self, message, new_detailed_info): + async def _handle_new_detailed_info(self, message): """ Updates the current status with the received detailed information: @@ -509,9 +511,9 @@ class MTProtoSender: """ # TODO https://goo.gl/G7DPsR __log__.debug('Handling new detailed info') - self._pending_ack.add(new_detailed_info.answer_msg_id) + self._pending_ack.add(message.obj.answer_msg_id) - async def _handle_new_session_created(self, message, new_session): + async def _handle_new_session_created(self, message): """ Updates the current status with the received session information: @@ -520,7 +522,7 @@ class MTProtoSender: """ # TODO https://goo.gl/LMyN7A __log__.debug('Handling new session created') - self.state.salt = new_session.server_salt + self.state.salt = message.obj.server_salt def _clean_containers(self, msg_ids): """ @@ -533,13 +535,13 @@ class MTProtoSender: """ for i in reversed(range(len(self._pending_containers))): message = self._pending_containers[i] - for msg in message.request.messages: + for msg in message.obj.messages: if msg.msg_id in msg_ids: del self._pending_containers[i] del self._pending_messages[message.msg_id] break - async def _handle_ack(self, message, ack): + async def _handle_ack(self, message): """ Handles a server acknowledge about our messages. 
Normally these can be ignored except in the case of ``auth.logOut``: @@ -555,16 +557,17 @@ class MTProtoSender: messages are acknowledged. """ __log__.debug('Handling acknowledge') + ack = message.obj if self._pending_containers: self._clean_containers(ack.msg_ids) for msg_id in ack.msg_ids: msg = self._pending_messages.get(msg_id, None) - if msg and isinstance(msg.request, LogOutRequest): + if msg and isinstance(msg.obj, LogOutRequest): del self._pending_messages[msg_id] msg.future.set_result(True) - async def _handle_future_salts(self, message, salts): + async def _handle_future_salts(self, message): """ Handles future salt results, which don't come inside a ``rpc_result`` but are still sent through a request: @@ -577,7 +580,7 @@ class MTProtoSender: __log__.debug('Handling future salts') msg = self._pending_messages.pop(message.msg_id, None) if msg: - msg.future.set_result(salts) + msg.future.set_result(message.obj) class _ContainerQueue(asyncio.Queue): @@ -593,13 +596,13 @@ class _ContainerQueue(asyncio.Queue): """ async def get(self): result = await super().get() - if self.empty() or isinstance(result.request, MessageContainer): + if self.empty() or isinstance(result.obj, MessageContainer): return result result = [result] while not self.empty(): item = self.get_nowait() - if isinstance(item.request, MessageContainer): + if isinstance(item.obj, MessageContainer): await self.put(item) break else: diff --git a/telethon/network/mtprotostate.py b/telethon/network/mtprotostate.py index 36a9dde1..220457ee 100644 --- a/telethon/network/mtprotostate.py +++ b/telethon/network/mtprotostate.py @@ -1,3 +1,4 @@ +import logging import os import struct import time @@ -8,6 +9,8 @@ from ..errors import SecurityError, BrokenAuthKeyError from ..extensions import BinaryReader from ..tl.core import TLMessage +__log__ = logging.getLogger(__name__) + class MTProtoState: """ @@ -33,15 +36,15 @@ class MTProtoState: self._sequence = 0 self._last_msg_id = 0 - def create_message(self, 
request, after=None): + def create_message(self, obj, after=None): """ Creates a new `telethon.tl.tl_message.TLMessage` from the given `telethon.tl.tlobject.TLObject` instance. """ return TLMessage( msg_id=self._get_new_msg_id(), - seq_no=self._get_seq_no(request.content_related), - request=request, + seq_no=self._get_seq_no(obj.content_related), + obj=obj, after_id=after.msg_id if after else None ) @@ -100,25 +103,31 @@ class MTProtoState: msg_key = body[8:24] aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, False) - data = BinaryReader(AES.decrypt_ige(body[24:], aes_key, aes_iv)) - - data.read_long() # remote_salt - if data.read_long() != self.id: - raise SecurityError('Server replied with a wrong session ID') - - remote_msg_id = data.read_long() - remote_sequence = data.read_int() - msg_len = data.read_int() - message = data.read(msg_len) + body = AES.decrypt_ige(body[24:], aes_key, aes_iv) # https://core.telegram.org/mtproto/security_guidelines # Sections "checking sha256 hash" and "message length" - our_key = sha256(self.auth_key.key[96:96 + 32] + data.get_bytes()) + our_key = sha256(self.auth_key.key[96:96 + 32] + body) if msg_key != our_key.digest()[8:24]: raise SecurityError( "Received msg_key doesn't match with expected one") - return TLMessage(remote_msg_id, remote_sequence, body=message) + reader = BinaryReader(body) + reader.read_long() # remote_salt + if reader.read_long() != self.id: + raise SecurityError('Server replied with a wrong session ID') + + remote_msg_id = reader.read_long() + remote_sequence = reader.read_int() + msg_len = reader.read_int() + before = reader.tell_position() + obj = reader.tgread_object() + if reader.tell_position() != before + msg_len: + reader.set_position(before) + __log__.warning('Data left after TLObject {}: {!r}' + .format(obj, reader.read(msg_len))) + + return TLMessage(remote_msg_id, remote_sequence, obj) def _get_new_msg_id(self): """ diff --git a/telethon/tl/core/messagecontainer.py 
b/telethon/tl/core/messagecontainer.py index ef1eab1e..0d56de33 100644 --- a/telethon/tl/core/messagecontainer.py +++ b/telethon/tl/core/messagecontainer.py @@ -1,7 +1,10 @@ +import logging import struct -from ..tlobject import TLObject from .tlmessage import TLMessage +from ..tlobject import TLObject + +__log__ = logging.getLogger(__name__) class MessageContainer(TLObject): @@ -26,17 +29,6 @@ class MessageContainer(TLObject): '