From b4f9d3d720775a7c378a59f5f98138b47481f592 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Fri, 13 Oct 2023 22:59:26 +0200 Subject: [PATCH] Improve error handling in transports --- client/src/telethon/_impl/mtproto/__init__.py | 3 +- .../src/telethon/_impl/mtproto/mtp/types.py | 34 ++++++++++++++++++- .../_impl/mtproto/transport/__init__.py | 4 +-- .../telethon/_impl/mtproto/transport/abcs.py | 6 ++++ .../_impl/mtproto/transport/abridged.py | 9 +++-- .../telethon/_impl/mtproto/transport/full.py | 4 ++- .../_impl/mtproto/transport/intermediate.py | 11 +++++- client/src/telethon/_impl/mtproto/utils.py | 6 +--- 8 files changed, 64 insertions(+), 13 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/__init__.py b/client/src/telethon/_impl/mtproto/__init__.py index 94a723ce..08e53086 100644 --- a/client/src/telethon/_impl/mtproto/__init__.py +++ b/client/src/telethon/_impl/mtproto/__init__.py @@ -3,7 +3,7 @@ from .authentication import step1 as auth_step1 from .authentication import step2 as auth_step2 from .authentication import step3 as auth_step3 from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError -from .transport import Abridged, Full, Intermediate, MissingBytes, Transport +from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport from .utils import DEFAULT_COMPRESSION_THRESHOLD __all__ = [ @@ -23,6 +23,7 @@ __all__ = [ "Plain", "RpcError", "Abridged", + "BadStatus", "Full", "Intermediate", "MissingBytes", diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index bf13638b..0d557618 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -1,3 +1,4 @@ +import logging import re from abc import ABC, abstractmethod from dataclasses import dataclass @@ -80,6 +81,25 @@ class RpcError(ValueError): ) +# https://core.telegram.org/mtproto/service_messages_about_messages +BAD_MSG_DESCRIPTIONS = { + 16: "msg_id too low", + 17: "msg_id too high", + 18: "incorrect two lower order msg_id bits", + 19: "container msg_id is the same as msg_id of a previously received message", + 20: "message too old, and it cannot be verified whether the server has received a message with this msg_id or not", + 32: "msg_seqno too low", + 33: "msg_seqno too high", + 34: "an even msg_seqno expected, but odd received", + 35: "odd msg_seqno expected, but even received", + 48: "incorrect server salt", + 64: "invalid container", +} + +RETRYABLE_MSG_IDS = {16, 17, 48} +NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33} + + class BadMessage(ValueError): def __init__( self, @@ -87,15 +107,27 @@ class BadMessage(ValueError): code: int, caused_by: Optional[int] = None, ) -> None: - super().__init__(f"bad msg: {code}") + description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available" + super().__init__(f"bad msg={code}: {description}") self._code = code self._caused_by = caused_by + self.severity = ( + logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR + ) @property def code(self) -> int: return self._code + @property + def retryable(self) -> bool: + return self._code in RETRYABLE_MSG_IDS + + @property + def fatal(self) -> bool: + return self._code not in NON_FATAL_MSG_IDS + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented diff --git a/client/src/telethon/_impl/mtproto/transport/__init__.py b/client/src/telethon/_impl/mtproto/transport/__init__.py index e3120625..50979449 100644 --- a/client/src/telethon/_impl/mtproto/transport/__init__.py +++ b/client/src/telethon/_impl/mtproto/transport/__init__.py @@ -1,6 +1,6 @@ -from .abcs import MissingBytes, Transport +from .abcs import BadStatus, MissingBytes, Transport from .abridged import Abridged from .full import Full from .intermediate import Intermediate -__all__ = ["MissingBytes", "Transport", "Abridged", "Full", "Intermediate"] +__all__ = ["BadStatus", "MissingBytes", "Transport", "Abridged", "Full", "Intermediate"] diff --git a/client/src/telethon/_impl/mtproto/transport/abcs.py b/client/src/telethon/_impl/mtproto/transport/abcs.py index 36d66823..c373d0f3 100644 --- a/client/src/telethon/_impl/mtproto/transport/abcs.py +++ b/client/src/telethon/_impl/mtproto/transport/abcs.py @@ -19,3 +19,9 @@ class Transport(ABC): class MissingBytes(ValueError): def __init__(self, *, expected: int, got: int) -> None: super().__init__(f"missing bytes, expected: {expected}, got: {got}") + + +class BadStatus(ValueError): + def __init__(self, *, status: int) -> None: + super().__init__(f"transport reported bad status: {status}") + self.status = status diff --git a/client/src/telethon/_impl/mtproto/transport/abridged.py b/client/src/telethon/_impl/mtproto/transport/abridged.py index 9d43fe32..7dc32d67 100644 --- a/client/src/telethon/_impl/mtproto/transport/abridged.py +++ b/client/src/telethon/_impl/mtproto/transport/abridged.py @@ -1,6 +1,6 @@ import struct -from .abcs import MissingBytes, OutFn, Transport +from .abcs import BadStatus, MissingBytes, OutFn, Transport class Abridged(Transport): @@ -41,7 +41,7 @@ class Abridged(Transport): raise MissingBytes(expected=1, got=0) length = input[0] - if length < 127: + if 1 < length < 127: header_len = 1 elif len(input) < 4: raise MissingBytes(expected=4, got=len(input)) @@ -49,6 +49,11 @@ class Abridged(Transport): header_len = 4 length = struct.unpack_from("> 8 + if length <= 0: + if length < 0: + raise BadStatus(status=-length) + raise ValueError(f"bad length, expected > 0, got: {length}") + length *= 4 if len(input) < header_len + length: raise MissingBytes(expected=header_len + length, got=len(input)) diff --git a/client/src/telethon/_impl/mtproto/transport/full.py b/client/src/telethon/_impl/mtproto/transport/full.py index 4ded6f37..ce1f5a14 100644 --- a/client/src/telethon/_impl/mtproto/transport/full.py +++ b/client/src/telethon/_impl/mtproto/transport/full.py @@ -1,7 +1,7 @@ import struct from zlib import crc32 -from .abcs import MissingBytes, OutFn, Transport +from .abcs import BadStatus, MissingBytes, OutFn, Transport class Full(Transport): @@ -42,6 +42,8 @@ class Full(Transport): length = struct.unpack_from(" 12, got: {length}") if len(input) < length: diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index 241d094f..73288350 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -1,6 +1,6 @@ import struct -from .abcs import MissingBytes, OutFn, Transport +from .abcs import BadStatus, MissingBytes, OutFn, Transport class Intermediate(Transport): @@ -41,5 +41,14 @@ class Intermediate(Transport): if len(input) < length: raise MissingBytes(expected=length, got=len(input)) + if length <= 4: + if ( + length >= 4 + and (status := struct.unpack(" 0, got: {length}") + output += memoryview(input)[4 : 4 + length] return length + 4 diff --git a/client/src/telethon/_impl/mtproto/utils.py b/client/src/telethon/_impl/mtproto/utils.py index 25a78b4b..691852e7 100644 --- a/client/src/telethon/_impl/mtproto/utils.py +++ b/client/src/telethon/_impl/mtproto/utils.py @@ -1,5 +1,4 @@ import gzip -import struct from typing import Optional from ..tl.mtproto.types import GzipPacked, Message @@ -12,10 +11,7 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes def check_message_buffer(message: bytes) -> None: - if len(message) == 4: - neg_http_code = struct.unpack("