mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-03 13:14:31 +03:00
Use proper error types in mtp
This commit is contained in:
parent
7166059132
commit
60ed7a32fe
|
@ -34,7 +34,7 @@ from ...tl.mtproto.types import (
|
|||
RpcAnswerUnknown,
|
||||
)
|
||||
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
||||
from ...tl.mtproto.types import RpcResult
|
||||
from ...tl.mtproto.types import RpcResult as GeneratedRpcResult
|
||||
from ...tl.types import (
|
||||
Updates,
|
||||
UpdatesCombined,
|
||||
|
@ -54,7 +54,7 @@ from ..utils import (
|
|||
gzip_decompress,
|
||||
message_requires_ack,
|
||||
)
|
||||
from .types import Deserialization, MsgId, Mtp, RpcError
|
||||
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult
|
||||
|
||||
NUM_FUTURE_SALTS = 64
|
||||
|
||||
|
@ -95,13 +95,13 @@ class Encrypted(Mtp):
|
|||
self._last_msg_id: int = 0
|
||||
self._pending_ack: List[int] = []
|
||||
self._compression_threshold = compression_threshold
|
||||
self._rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]] = []
|
||||
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
|
||||
self._updates: List[bytes] = []
|
||||
self._buffer = bytearray()
|
||||
self._msg_count: int = 0
|
||||
|
||||
self._handlers = {
|
||||
RpcResult.constructor_id(): self._handle_rpc_result,
|
||||
GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
|
||||
MsgsAck.constructor_id(): self._handle_ack,
|
||||
BadMsgNotification.constructor_id(): self._handle_bad_notification,
|
||||
BadServerSalt.constructor_id(): self._handle_bad_notification,
|
||||
|
@ -193,7 +193,7 @@ class Encrypted(Mtp):
|
|||
self._handlers.get(constructor_id, self._handle_update)(message)
|
||||
|
||||
def _handle_rpc_result(self, message: Message) -> None:
|
||||
rpc_result = RpcResult.from_bytes(message.body)
|
||||
rpc_result = GeneratedRpcResult.from_bytes(message.body)
|
||||
req_msg_id = rpc_result.req_msg_id
|
||||
result = rpc_result.result
|
||||
|
||||
|
@ -231,13 +231,12 @@ class Encrypted(Mtp):
|
|||
MsgsAck.from_bytes(message.body)
|
||||
|
||||
def _handle_bad_notification(self, message: Message) -> None:
|
||||
# TODO notify about this somehow
|
||||
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
||||
if isinstance(bad_msg, BadServerSalt):
|
||||
self._rpc_results.append(
|
||||
(
|
||||
MsgId(bad_msg.bad_msg_id),
|
||||
ValueError(f"bad msg: {bad_msg.error_code}"),
|
||||
BadMessage(code=bad_msg.error_code),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -253,7 +252,7 @@ class Encrypted(Mtp):
|
|||
|
||||
assert isinstance(bad_msg, BadMsgNotification)
|
||||
self._rpc_results.append(
|
||||
(MsgId(bad_msg.bad_msg_id), ValueError(f"bad msg: {bad_msg.error_code}"))
|
||||
(MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code))
|
||||
)
|
||||
|
||||
if bad_msg.error_code in (16, 17):
|
||||
|
|
|
@ -8,12 +8,6 @@ from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
|||
MsgId = NewType("MsgId", int)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Deserialization:
|
||||
rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]]
|
||||
updates: List[bytes]
|
||||
|
||||
|
||||
class RpcError(ValueError):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -61,6 +55,33 @@ class RpcError(ValueError):
|
|||
)
|
||||
|
||||
|
||||
class BadMessage(ValueError):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
code: int,
|
||||
caused_by: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(f"bad msg: {code}")
|
||||
|
||||
self.code = code
|
||||
self.caused_by = caused_by
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return self.code == other.code
|
||||
|
||||
|
||||
RpcResult = bytes | RpcError | BadMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Deserialization:
|
||||
rpc_results: List[Tuple[MsgId, RpcResult]]
|
||||
updates: List[bytes]
|
||||
|
||||
|
||||
# https://core.telegram.org/mtproto/description
|
||||
class Mtp(ABC):
|
||||
@abstractmethod
|
||||
|
|
|
@ -10,7 +10,7 @@ from ..crypto.auth_key import AuthKey
|
|||
from ..mtproto import authentication
|
||||
from ..mtproto.mtp.encrypted import Encrypted
|
||||
from ..mtproto.mtp.plain import Plain
|
||||
from ..mtproto.mtp.types import MsgId, Mtp
|
||||
from ..mtproto.mtp.types import BadMessage, MsgId, Mtp, RpcError
|
||||
from ..mtproto.transport.abcs import MissingBytes, Transport
|
||||
from ..tl.abcs import Updates
|
||||
from ..tl.core.request import Request as RemoteCall
|
||||
|
@ -253,15 +253,9 @@ class Sender:
|
|||
found = True
|
||||
if isinstance(ret, bytes):
|
||||
assert len(ret) >= 4
|
||||
elif isinstance(ret, Exception):
|
||||
raise NotImplementedError
|
||||
elif isinstance(ret, RpcError):
|
||||
ret.caused_by = req.body[:4]
|
||||
raise ret
|
||||
elif isinstance(ret, Dropped):
|
||||
raise ret
|
||||
elif isinstance(ret, Deserialize):
|
||||
raise ret
|
||||
elif isinstance(ret, BadMessage):
|
||||
# TODO test that we resend the request
|
||||
req.state = NotSerialized()
|
||||
|
|
Loading…
Reference in New Issue
Block a user