From 60ed7a32fed65f8c15412223fca93fb90915b52b Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 31 Aug 2023 13:23:30 +0200 Subject: [PATCH] Use proper error types in mtp --- .../telethon/_impl/mtproto/mtp/encrypted.py | 15 ++++----- .../src/telethon/_impl/mtproto/mtp/types.py | 33 +++++++++++++++---- client/src/telethon/_impl/mtsender/sender.py | 8 +---- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index f8e31124..ec23a277 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -34,7 +34,7 @@ from ...tl.mtproto.types import ( RpcAnswerUnknown, ) from ...tl.mtproto.types import RpcError as GeneratedRpcError -from ...tl.mtproto.types import RpcResult +from ...tl.mtproto.types import RpcResult as GeneratedRpcResult from ...tl.types import ( Updates, UpdatesCombined, @@ -54,7 +54,7 @@ from ..utils import ( gzip_decompress, message_requires_ack, ) -from .types import Deserialization, MsgId, Mtp, RpcError +from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult NUM_FUTURE_SALTS = 64 @@ -95,13 +95,13 @@ class Encrypted(Mtp): self._last_msg_id: int = 0 self._pending_ack: List[int] = [] self._compression_threshold = compression_threshold - self._rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]] = [] + self._rpc_results: List[Tuple[MsgId, RpcResult]] = [] self._updates: List[bytes] = [] self._buffer = bytearray() self._msg_count: int = 0 self._handlers = { - RpcResult.constructor_id(): self._handle_rpc_result, + GeneratedRpcResult.constructor_id(): self._handle_rpc_result, MsgsAck.constructor_id(): self._handle_ack, BadMsgNotification.constructor_id(): self._handle_bad_notification, BadServerSalt.constructor_id(): self._handle_bad_notification, @@ -193,7 +193,7 @@ class Encrypted(Mtp): self._handlers.get(constructor_id, self._handle_update)(message) def _handle_rpc_result(self, message: Message) -> None: - rpc_result = RpcResult.from_bytes(message.body) + rpc_result = GeneratedRpcResult.from_bytes(message.body) req_msg_id = rpc_result.req_msg_id result = rpc_result.result @@ -231,13 +231,12 @@ class Encrypted(Mtp): MsgsAck.from_bytes(message.body) def _handle_bad_notification(self, message: Message) -> None: - # TODO notify about this somehow bad_msg = AbcBadMsgNotification.from_bytes(message.body) if isinstance(bad_msg, BadServerSalt): self._rpc_results.append( ( MsgId(bad_msg.bad_msg_id), - ValueError(f"bad msg: {bad_msg.error_code}"), + BadMessage(code=bad_msg.error_code), ) ) @@ -253,7 +252,7 @@ class Encrypted(Mtp): assert isinstance(bad_msg, BadMsgNotification) self._rpc_results.append( - (MsgId(bad_msg.bad_msg_id), ValueError(f"bad msg: {bad_msg.error_code}")) + (MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code)) ) if bad_msg.error_code in (16, 17): diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 19e79573..6c76e8f5 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -8,12 +8,6 @@ from ...tl.mtproto.types import RpcError as GeneratedRpcError MsgId = NewType("MsgId", int) -@dataclass -class Deserialization: - rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]] - updates: List[bytes] - - class RpcError(ValueError): def __init__( self, @@ -61,6 +55,33 @@ class RpcError(ValueError): ) +class BadMessage(ValueError): + def __init__( + self, + *, + code: int, + caused_by: Optional[int] = None, + ) -> None: + super().__init__(f"bad msg: {code}") + + self.code = code + self.caused_by = caused_by + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + return self.code == other.code + + +RpcResult = bytes | RpcError | BadMessage + + +@dataclass +class Deserialization: + rpc_results: List[Tuple[MsgId, RpcResult]] + updates: List[bytes] + + # https://core.telegram.org/mtproto/description class Mtp(ABC): @abstractmethod diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 98817b7a..88056289 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -10,7 +10,7 @@ from ..crypto.auth_key import AuthKey from ..mtproto import authentication from ..mtproto.mtp.encrypted import Encrypted from ..mtproto.mtp.plain import Plain -from ..mtproto.mtp.types import MsgId, Mtp +from ..mtproto.mtp.types import BadMessage, MsgId, Mtp, RpcError from ..mtproto.transport.abcs import MissingBytes, Transport from ..tl.abcs import Updates from ..tl.core.request import Request as RemoteCall @@ -253,15 +253,9 @@ class Sender: found = True if isinstance(ret, bytes): assert len(ret) >= 4 - elif isinstance(ret, Exception): - raise NotImplementedError elif isinstance(ret, RpcError): ret.caused_by = req.body[:4] raise ret - elif isinstance(ret, Dropped): - raise ret - elif isinstance(ret, Deserialize): - raise ret elif isinstance(ret, BadMessage): # TODO test that we resend the request req.state = NotSerialized()