Use proper error types in mtp

This commit is contained in:
Lonami Exo 2023-08-31 13:23:30 +02:00
parent 7166059132
commit 60ed7a32fe
3 changed files with 35 additions and 21 deletions

View File

@ -34,7 +34,7 @@ from ...tl.mtproto.types import (
RpcAnswerUnknown,
)
from ...tl.mtproto.types import RpcError as GeneratedRpcError
from ...tl.mtproto.types import RpcResult
from ...tl.mtproto.types import RpcResult as GeneratedRpcResult
from ...tl.types import (
Updates,
UpdatesCombined,
@ -54,7 +54,7 @@ from ..utils import (
gzip_decompress,
message_requires_ack,
)
from .types import Deserialization, MsgId, Mtp, RpcError
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult
NUM_FUTURE_SALTS = 64
@ -95,13 +95,13 @@ class Encrypted(Mtp):
self._last_msg_id: int = 0
self._pending_ack: List[int] = []
self._compression_threshold = compression_threshold
self._rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]] = []
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
self._updates: List[bytes] = []
self._buffer = bytearray()
self._msg_count: int = 0
self._handlers = {
RpcResult.constructor_id(): self._handle_rpc_result,
GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
MsgsAck.constructor_id(): self._handle_ack,
BadMsgNotification.constructor_id(): self._handle_bad_notification,
BadServerSalt.constructor_id(): self._handle_bad_notification,
@ -193,7 +193,7 @@ class Encrypted(Mtp):
self._handlers.get(constructor_id, self._handle_update)(message)
def _handle_rpc_result(self, message: Message) -> None:
rpc_result = RpcResult.from_bytes(message.body)
rpc_result = GeneratedRpcResult.from_bytes(message.body)
req_msg_id = rpc_result.req_msg_id
result = rpc_result.result
@ -231,13 +231,12 @@ class Encrypted(Mtp):
MsgsAck.from_bytes(message.body)
def _handle_bad_notification(self, message: Message) -> None:
# TODO notify about this somehow
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
if isinstance(bad_msg, BadServerSalt):
self._rpc_results.append(
(
MsgId(bad_msg.bad_msg_id),
ValueError(f"bad msg: {bad_msg.error_code}"),
BadMessage(code=bad_msg.error_code),
)
)
@ -253,7 +252,7 @@ class Encrypted(Mtp):
assert isinstance(bad_msg, BadMsgNotification)
self._rpc_results.append(
(MsgId(bad_msg.bad_msg_id), ValueError(f"bad msg: {bad_msg.error_code}"))
(MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code))
)
if bad_msg.error_code in (16, 17):

View File

@ -8,12 +8,6 @@ from ...tl.mtproto.types import RpcError as GeneratedRpcError
MsgId = NewType("MsgId", int)
@dataclass
class Deserialization:
rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]]
updates: List[bytes]
class RpcError(ValueError):
def __init__(
self,
@ -61,6 +55,33 @@ class RpcError(ValueError):
)
class BadMessage(ValueError):
def __init__(
self,
*,
code: int,
caused_by: Optional[int] = None,
) -> None:
super().__init__(f"bad msg: {code}")
self.code = code
self.caused_by = caused_by
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self.code == other.code
RpcResult = bytes | RpcError | BadMessage
@dataclass
class Deserialization:
rpc_results: List[Tuple[MsgId, RpcResult]]
updates: List[bytes]
# https://core.telegram.org/mtproto/description
class Mtp(ABC):
@abstractmethod

View File

@ -10,7 +10,7 @@ from ..crypto.auth_key import AuthKey
from ..mtproto import authentication
from ..mtproto.mtp.encrypted import Encrypted
from ..mtproto.mtp.plain import Plain
from ..mtproto.mtp.types import MsgId, Mtp
from ..mtproto.mtp.types import BadMessage, MsgId, Mtp, RpcError
from ..mtproto.transport.abcs import MissingBytes, Transport
from ..tl.abcs import Updates
from ..tl.core.request import Request as RemoteCall
@ -253,15 +253,9 @@ class Sender:
found = True
if isinstance(ret, bytes):
assert len(ret) >= 4
elif isinstance(ret, Exception):
raise NotImplementedError
elif isinstance(ret, RpcError):
ret.caused_by = req.body[:4]
raise ret
elif isinstance(ret, Dropped):
raise ret
elif isinstance(ret, Deserialize):
raise ret
elif isinstance(ret, BadMessage):
# TODO test that we resend the request
req.state = NotSerialized()