Improve error handling in transports

This commit is contained in:
Lonami Exo 2023-10-13 22:59:26 +02:00
parent 42633882b5
commit b4f9d3d720
8 changed files with 64 additions and 13 deletions

View File

@ -3,7 +3,7 @@ from .authentication import step1 as auth_step1
from .authentication import step2 as auth_step2
from .authentication import step3 as auth_step3
from .mtp import BadMessage, Deserialization, Encrypted, MsgId, Mtp, Plain, RpcError
from .transport import Abridged, Full, Intermediate, MissingBytes, Transport
from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport
from .utils import DEFAULT_COMPRESSION_THRESHOLD
__all__ = [
@ -23,6 +23,7 @@ __all__ = [
"Plain",
"RpcError",
"Abridged",
"BadStatus",
"Full",
"Intermediate",
"MissingBytes",

View File

@ -1,3 +1,4 @@
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
@ -80,6 +81,25 @@ class RpcError(ValueError):
)
# https://core.telegram.org/mtproto/service_messages_about_messages
BAD_MSG_DESCRIPTIONS = {
16: "msg_id too low",
17: "msg_id too high",
18: "incorrect two lower order msg_id bits",
19: "container msg_id is the same as msg_id of a previously received message",
20: "message too old, and it cannot be verified whether the server has received a message with this msg_id or not",
32: "msg_seqno too low",
33: "msg_seqno too high",
34: "an even msg_seqno expected, but odd received",
35: "odd msg_seqno expected, but even received",
48: "incorrect server salt",
64: "invalid container",
}
RETRYABLE_MSG_IDS = {16, 17, 48}
NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33}
class BadMessage(ValueError):
def __init__(
self,
@ -87,15 +107,27 @@ class BadMessage(ValueError):
code: int,
caused_by: Optional[int] = None,
) -> None:
super().__init__(f"bad msg: {code}")
description = BAD_MSG_DESCRIPTIONS.get(code) or "no description available"
super().__init__(f"bad msg={code}: {description}")
self._code = code
self._caused_by = caused_by
self.severity = (
logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
)
@property
def code(self) -> int:
return self._code
@property
def retryable(self) -> bool:
return self._code in RETRYABLE_MSG_IDS
@property
def fatal(self) -> bool:
return self._code not in NON_FATAL_MSG_IDS
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented

View File

@ -1,6 +1,6 @@
from .abcs import MissingBytes, Transport
from .abcs import BadStatus, MissingBytes, Transport
from .abridged import Abridged
from .full import Full
from .intermediate import Intermediate
__all__ = ["MissingBytes", "Transport", "Abridged", "Full", "Intermediate"]
__all__ = ["BadStatus", "MissingBytes", "Transport", "Abridged", "Full", "Intermediate"]

View File

@ -19,3 +19,9 @@ class Transport(ABC):
class MissingBytes(ValueError):
def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"missing bytes, expected: {expected}, got: {got}")
class BadStatus(ValueError):
def __init__(self, *, status: int) -> None:
super().__init__(f"transport reported bad status: {status}")
self.status = status

View File

@ -1,6 +1,6 @@
import struct
from .abcs import MissingBytes, OutFn, Transport
from .abcs import BadStatus, MissingBytes, OutFn, Transport
class Abridged(Transport):
@ -41,7 +41,7 @@ class Abridged(Transport):
raise MissingBytes(expected=1, got=0)
length = input[0]
if length < 127:
if 1 < length < 127:
header_len = 1
elif len(input) < 4:
raise MissingBytes(expected=4, got=len(input))
@ -49,6 +49,11 @@ class Abridged(Transport):
header_len = 4
length = struct.unpack_from("<i", input)[0] >> 8
if length <= 0:
if length < 0:
raise BadStatus(status=-length)
raise ValueError(f"bad length, expected > 0, got: {length}")
length *= 4
if len(input) < header_len + length:
raise MissingBytes(expected=header_len + length, got=len(input))

View File

@ -1,7 +1,7 @@
import struct
from zlib import crc32
from .abcs import MissingBytes, OutFn, Transport
from .abcs import BadStatus, MissingBytes, OutFn, Transport
class Full(Transport):
@ -42,6 +42,8 @@ class Full(Transport):
length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int)
if length < 12:
if length < 0:
raise BadStatus(status=-length)
raise ValueError(f"bad length, expected > 12, got: {length}")
if len(input) < length:

View File

@ -1,6 +1,6 @@
import struct
from .abcs import MissingBytes, OutFn, Transport
from .abcs import BadStatus, MissingBytes, OutFn, Transport
class Intermediate(Transport):
@ -41,5 +41,14 @@ class Intermediate(Transport):
if len(input) < length:
raise MissingBytes(expected=length, got=len(input))
if length <= 4:
if (
length >= 4
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
):
raise BadStatus(status=-status)
raise ValueError(f"bad length, expected > 0, got: {length}")
output += memoryview(input)[4 : 4 + length]
return length + 4

View File

@ -1,5 +1,4 @@
import gzip
import struct
from typing import Optional
from ..tl.mtproto.types import GzipPacked, Message
@ -12,10 +11,7 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
def check_message_buffer(message: bytes) -> None:
if len(message) == 4:
neg_http_code = struct.unpack("<i", message)[0]
raise ValueError(f"transport error: {neg_http_code}")
elif len(message) < 20:
if len(message) < 20:
raise ValueError(
f"server payload is too small to be a valid message: {message.hex()}"
)