From 1c2aafce2a0227a565514ee365a014efff1c17ba Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 1 Jun 2025 19:55:25 +0500 Subject: [PATCH] Refactor error handling in crypto and mtproto modules; introduce custom error classes and improve deserialization error processing in Sender class --- client/src/telethon/_impl/crypto/crypto.py | 24 +++- .../telethon/_impl/mtproto/mtp/encrypted.py | 73 +++++++++--- .../src/telethon/_impl/mtproto/mtp/plain.py | 9 +- .../src/telethon/_impl/mtproto/mtp/types.py | 107 +++++++++++++++++- client/src/telethon/_impl/mtsender/errors.py | 0 client/src/telethon/_impl/mtsender/sender.py | 18 ++- 6 files changed, 206 insertions(+), 25 deletions(-) create mode 100644 client/src/telethon/_impl/mtsender/errors.py diff --git a/client/src/telethon/_impl/crypto/crypto.py b/client/src/telethon/_impl/crypto/crypto.py index 17fd788e..1c47962e 100644 --- a/client/src/telethon/_impl/crypto/crypto.py +++ b/client/src/telethon/_impl/crypto/crypto.py @@ -7,6 +7,24 @@ from .aes import ige_decrypt, ige_encrypt from .auth_key import AuthKey +class InvalidBufferError(ValueError): + def __init__(self) -> None: + super().__init__("Invalid ciphertext buffer length") + + +class AuthKeyMismatchError(ValueError): + def __init__(self) -> None: + super().__init__("Server authkey mismatches with ours") + + +class MsgKeyMismatchError(ValueError): + def __init__(self) -> None: + super().__init__("Server msgkey mismatches with ours") + + +CryptoError = InvalidBufferError | AuthKeyMismatchError | MsgKeyMismatchError + + # "where x = 0 for messages from client to server and x = 8 for those from server to client" class Side(IntEnum): CLIENT = 0 @@ -77,14 +95,14 @@ def decrypt_data_v2( x = int(side) if len(ciphertext) < 24 or (len(ciphertext) - 24) % 16 != 0: - raise ValueError("invalid ciphertext buffer length") + raise InvalidBufferError() # salt, session_id and sequence_number should also be checked. # However, not doing so has worked fine for years. key_id = ciphertext[:8] if auth_key.key_id != key_id: - raise ValueError("server authkey mismatches with ours") + raise AuthKeyMismatchError() msg_key = ciphertext[8:24] key, iv = calc_key(auth_key, msg_key, side) @@ -93,7 +111,7 @@ def decrypt_data_v2( # https://core.telegram.org/mtproto/security_guidelines#mtproto-encrypted-messages our_key = sha256(auth_key.data[x + 88 : x + 120] + plaintext).digest() if msg_key != our_key[8:24]: - raise ValueError("server msgkey mismatches with ours") + raise MsgKeyMismatchError() return plaintext diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index abcb7015..1f94bb3b 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -62,11 +62,15 @@ from ..utils import ( ) from .types import ( BadMessageError, + DecompressionFailed, Deserialization, + DeserializationFailure, + MsgBufferTooSmall, MsgId, Mtp, RpcError, RpcResult, + UnexpectedConstructor, Update, ) @@ -85,6 +89,7 @@ UPDATE_IDS = { AffectedFoundMessages.constructor_id(), AffectedHistory.constructor_id(), AffectedMessages.constructor_id(), + # TODO InvitedUsers } HEADER_LEN = 8 + 8 # salt, client_id @@ -151,7 +156,7 @@ class Encrypted(Mtp): self._last_msg_id: int self._in_pending_ack: list[int] = [] self._msg_count: int - self._reset_session() + self.reset() @property def auth_key(self) -> bytes: @@ -166,13 +171,6 @@ class Encrypted(Mtp): def _adjusted_now(self) -> float: return time.time() + self._time_offset - def _reset_session(self) -> None: - self._client_id = struct.unpack(" int: new_msg_id = int(self._adjusted_now() * 0x100000000) if self._last_msg_id >= new_msg_id: @@ -245,12 +243,38 @@ class Encrypted(Mtp): result = rpc_result.result msg_id = MsgId(req_msg_id) - inner_constructor = struct.unpack_from(" None: + raise RuntimeError("msg_copy should not be used") + def _handle_gzip_packed(self, message: Message) -> None: container = GzipPacked.from_bytes(message.body) inner_body = gzip_decompress(container) @@ -459,3 +492,11 @@ class Encrypted(Mtp): result = self._deserialization[:] self._deserialization.clear() return result + + def reset(self) -> None: + self._client_id = struct.unpack("= 0, got: {length}") - if 20 + length > len(payload): - raise ValueError( - f"message too short, expected: {20 + length}, got {len(payload)}" - ) + if 20 + length > (lp := len(payload)): + raise ValueError(f"message too short, expected: {20 + length}, got {lp}") return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] + + def reset(self) -> None: + self._buffer.clear() diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 0f624c7c..7fe4b209 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -5,6 +5,7 @@ from typing import NewType, Optional from typing_extensions import Self +from ...crypto.crypto import CryptoError from ...tl.mtproto.types import RpcError as GeneratedRpcError MsgId = NewType("MsgId", int) @@ -180,7 +181,105 @@ class BadMessageError(ValueError): return self._code == other._code -Deserialization = Update | RpcResult | RpcError | BadMessageError +DeserializationError = ValueError + + +class DeserializationFailure: + __slots__ = ("msg_id", "error") + + def __init__(self, msg_id: MsgId, error: DeserializationError) -> None: + self.msg_id = msg_id + self.error = error + + +Deserialization = ( + Update | RpcResult | RpcError | BadMessageError | DeserializationFailure +) + + +# Deserialization errors are not fatal, so we don't subclass RpcError. +class BadAuthKeyError(DeserializationError): + def __init__(self, *args: object, got: int, expected: int) -> None: + super().__init__(f"Bad server auth key (got {got}, expected {expected})", *args) + self._got = got + self._expected = expected + + @property + def got(self): + return self._got + + @property + def expected(self): + return self._expected + + +class BadMsgIdError(DeserializationError): + def __init__(self, *args: object, got: int) -> None: + super().__init__(f"Bad server message id (got {got})", *args) + self._got = got + + @property + def got(self): + return self._got + + +class NegativeLengthError(DeserializationError): + def __init__(self, *args: object, got: int) -> None: + super().__init__(f"Bad server message length (got {got})", *args) + self._got = got + + @property + def got(self): + return self._got + + +class TooLongMsgError(DeserializationError): + __slots__ = ("expected", "got") + + def __init__(self, *args: object, got: int, max_length: int) -> None: + super().__init__( + f"Bad server message length (got {got}, when at most it should be {max_length})", + *args, + ) + self._got = got + self._expected = max_length + + @property + def got(self): + return self._got + + @property + def expected(self): + return self._expected + + +class MsgBufferTooSmall(DeserializationError): + def __init__(self, *args: object) -> None: + super().__init__( + "Server responded with a payload that's too small to fit a valid message", + *args, + ) + + +class DecompressionFailed(DeserializationError): + def __init__(self, *args: object) -> None: + super().__init__("Failed to decompress server's data", *args) + + +class UnexpectedConstructor(DeserializationError): + def __init__(self, *args: object, id: int) -> None: + super().__init__(f"Unexpected constructor: {id:08x}", *args) + + +class DecryptionError(DeserializationError): + def __init__(self, *args: object, error: CryptoError) -> None: + super().__init__(f"failed to decrypt message: {error}", *args) + + self._error = error + + @property + def error(self): + return self._error # https://core.telegram.org/mtproto/description @@ -209,3 +308,9 @@ class Mtp(ABC): """ Deserialize incoming buffer payload. """ + + @abstractmethod + def reset(self) -> None: + """ + Reset the internal buffer. + """ diff --git a/client/src/telethon/_impl/mtsender/errors.py b/client/src/telethon/_impl/mtsender/errors.py new file mode 100644 index 00000000..e69de29b diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 4757ba22..0361c635 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -24,6 +24,7 @@ from ..mtproto import ( Update, authentication, ) +from ..mtproto.mtp.types import DeserializationFailure from ..tl import Request as RemoteCall from ..tl.abcs import Updates from ..tl.core import Serializable @@ -334,8 +335,12 @@ class Sender: self._process_result(result) elif isinstance(result, RpcError): self._process_error(result) - else: + elif isinstance(result, BadMessageError): self._process_bad_message(result) + elif isinstance(result, DeserializationFailure): + self._process_deserialize_error(result) + else: + raise RuntimeError(f"Unexpected result: {result}") def _process_update(self, update: bytes | bytearray | memoryview) -> None: try: @@ -424,6 +429,17 @@ class Sender: result._caused_by = struct.unpack_from(" Optional[Request[object]]: for req in self._requests: if isinstance(req.state, Serialized) and req.state.msg_id == msg_id: