Use proper error types in mtp

This commit is contained in:
Lonami Exo 2023-08-31 13:23:30 +02:00
parent 7166059132
commit 60ed7a32fe
3 changed files with 35 additions and 21 deletions

View File

@ -34,7 +34,7 @@ from ...tl.mtproto.types import (
RpcAnswerUnknown, RpcAnswerUnknown,
) )
from ...tl.mtproto.types import RpcError as GeneratedRpcError from ...tl.mtproto.types import RpcError as GeneratedRpcError
from ...tl.mtproto.types import RpcResult from ...tl.mtproto.types import RpcResult as GeneratedRpcResult
from ...tl.types import ( from ...tl.types import (
Updates, Updates,
UpdatesCombined, UpdatesCombined,
@ -54,7 +54,7 @@ from ..utils import (
gzip_decompress, gzip_decompress,
message_requires_ack, message_requires_ack,
) )
from .types import Deserialization, MsgId, Mtp, RpcError from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult
NUM_FUTURE_SALTS = 64 NUM_FUTURE_SALTS = 64
@ -95,13 +95,13 @@ class Encrypted(Mtp):
self._last_msg_id: int = 0 self._last_msg_id: int = 0
self._pending_ack: List[int] = [] self._pending_ack: List[int] = []
self._compression_threshold = compression_threshold self._compression_threshold = compression_threshold
self._rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]] = [] self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
self._updates: List[bytes] = [] self._updates: List[bytes] = []
self._buffer = bytearray() self._buffer = bytearray()
self._msg_count: int = 0 self._msg_count: int = 0
self._handlers = { self._handlers = {
RpcResult.constructor_id(): self._handle_rpc_result, GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
MsgsAck.constructor_id(): self._handle_ack, MsgsAck.constructor_id(): self._handle_ack,
BadMsgNotification.constructor_id(): self._handle_bad_notification, BadMsgNotification.constructor_id(): self._handle_bad_notification,
BadServerSalt.constructor_id(): self._handle_bad_notification, BadServerSalt.constructor_id(): self._handle_bad_notification,
@ -193,7 +193,7 @@ class Encrypted(Mtp):
self._handlers.get(constructor_id, self._handle_update)(message) self._handlers.get(constructor_id, self._handle_update)(message)
def _handle_rpc_result(self, message: Message) -> None: def _handle_rpc_result(self, message: Message) -> None:
rpc_result = RpcResult.from_bytes(message.body) rpc_result = GeneratedRpcResult.from_bytes(message.body)
req_msg_id = rpc_result.req_msg_id req_msg_id = rpc_result.req_msg_id
result = rpc_result.result result = rpc_result.result
@ -231,13 +231,12 @@ class Encrypted(Mtp):
MsgsAck.from_bytes(message.body) MsgsAck.from_bytes(message.body)
def _handle_bad_notification(self, message: Message) -> None: def _handle_bad_notification(self, message: Message) -> None:
# TODO notify about this somehow
bad_msg = AbcBadMsgNotification.from_bytes(message.body) bad_msg = AbcBadMsgNotification.from_bytes(message.body)
if isinstance(bad_msg, BadServerSalt): if isinstance(bad_msg, BadServerSalt):
self._rpc_results.append( self._rpc_results.append(
( (
MsgId(bad_msg.bad_msg_id), MsgId(bad_msg.bad_msg_id),
ValueError(f"bad msg: {bad_msg.error_code}"), BadMessage(code=bad_msg.error_code),
) )
) )
@ -253,7 +252,7 @@ class Encrypted(Mtp):
assert isinstance(bad_msg, BadMsgNotification) assert isinstance(bad_msg, BadMsgNotification)
self._rpc_results.append( self._rpc_results.append(
(MsgId(bad_msg.bad_msg_id), ValueError(f"bad msg: {bad_msg.error_code}")) (MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code))
) )
if bad_msg.error_code in (16, 17): if bad_msg.error_code in (16, 17):

View File

@ -8,12 +8,6 @@ from ...tl.mtproto.types import RpcError as GeneratedRpcError
MsgId = NewType("MsgId", int) MsgId = NewType("MsgId", int)
@dataclass
class Deserialization:
rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]]
updates: List[bytes]
class RpcError(ValueError): class RpcError(ValueError):
def __init__( def __init__(
self, self,
@ -61,6 +55,33 @@ class RpcError(ValueError):
) )
class BadMessage(ValueError):
def __init__(
self,
*,
code: int,
caused_by: Optional[int] = None,
) -> None:
super().__init__(f"bad msg: {code}")
self.code = code
self.caused_by = caused_by
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self.code == other.code
RpcResult = bytes | RpcError | BadMessage
@dataclass
class Deserialization:
rpc_results: List[Tuple[MsgId, RpcResult]]
updates: List[bytes]
# https://core.telegram.org/mtproto/description # https://core.telegram.org/mtproto/description
class Mtp(ABC): class Mtp(ABC):
@abstractmethod @abstractmethod

View File

@ -10,7 +10,7 @@ from ..crypto.auth_key import AuthKey
from ..mtproto import authentication from ..mtproto import authentication
from ..mtproto.mtp.encrypted import Encrypted from ..mtproto.mtp.encrypted import Encrypted
from ..mtproto.mtp.plain import Plain from ..mtproto.mtp.plain import Plain
from ..mtproto.mtp.types import MsgId, Mtp from ..mtproto.mtp.types import BadMessage, MsgId, Mtp, RpcError
from ..mtproto.transport.abcs import MissingBytes, Transport from ..mtproto.transport.abcs import MissingBytes, Transport
from ..tl.abcs import Updates from ..tl.abcs import Updates
from ..tl.core.request import Request as RemoteCall from ..tl.core.request import Request as RemoteCall
@ -253,15 +253,9 @@ class Sender:
found = True found = True
if isinstance(ret, bytes): if isinstance(ret, bytes):
assert len(ret) >= 4 assert len(ret) >= 4
elif isinstance(ret, Exception):
raise NotImplementedError
elif isinstance(ret, RpcError): elif isinstance(ret, RpcError):
ret.caused_by = req.body[:4] ret.caused_by = req.body[:4]
raise ret raise ret
elif isinstance(ret, Dropped):
raise ret
elif isinstance(ret, Deserialize):
raise ret
elif isinstance(ret, BadMessage): elif isinstance(ret, BadMessage):
# TODO test that we resend the request # TODO test that we resend the request
req.state = NotSerialized() req.state = NotSerialized()