mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-07-15 18:42:23 +03:00
Use proper error types in mtp
This commit is contained in:
parent
7166059132
commit
60ed7a32fe
|
@ -34,7 +34,7 @@ from ...tl.mtproto.types import (
|
||||||
RpcAnswerUnknown,
|
RpcAnswerUnknown,
|
||||||
)
|
)
|
||||||
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
||||||
from ...tl.mtproto.types import RpcResult
|
from ...tl.mtproto.types import RpcResult as GeneratedRpcResult
|
||||||
from ...tl.types import (
|
from ...tl.types import (
|
||||||
Updates,
|
Updates,
|
||||||
UpdatesCombined,
|
UpdatesCombined,
|
||||||
|
@ -54,7 +54,7 @@ from ..utils import (
|
||||||
gzip_decompress,
|
gzip_decompress,
|
||||||
message_requires_ack,
|
message_requires_ack,
|
||||||
)
|
)
|
||||||
from .types import Deserialization, MsgId, Mtp, RpcError
|
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult
|
||||||
|
|
||||||
NUM_FUTURE_SALTS = 64
|
NUM_FUTURE_SALTS = 64
|
||||||
|
|
||||||
|
@ -95,13 +95,13 @@ class Encrypted(Mtp):
|
||||||
self._last_msg_id: int = 0
|
self._last_msg_id: int = 0
|
||||||
self._pending_ack: List[int] = []
|
self._pending_ack: List[int] = []
|
||||||
self._compression_threshold = compression_threshold
|
self._compression_threshold = compression_threshold
|
||||||
self._rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]] = []
|
self._rpc_results: List[Tuple[MsgId, RpcResult]] = []
|
||||||
self._updates: List[bytes] = []
|
self._updates: List[bytes] = []
|
||||||
self._buffer = bytearray()
|
self._buffer = bytearray()
|
||||||
self._msg_count: int = 0
|
self._msg_count: int = 0
|
||||||
|
|
||||||
self._handlers = {
|
self._handlers = {
|
||||||
RpcResult.constructor_id(): self._handle_rpc_result,
|
GeneratedRpcResult.constructor_id(): self._handle_rpc_result,
|
||||||
MsgsAck.constructor_id(): self._handle_ack,
|
MsgsAck.constructor_id(): self._handle_ack,
|
||||||
BadMsgNotification.constructor_id(): self._handle_bad_notification,
|
BadMsgNotification.constructor_id(): self._handle_bad_notification,
|
||||||
BadServerSalt.constructor_id(): self._handle_bad_notification,
|
BadServerSalt.constructor_id(): self._handle_bad_notification,
|
||||||
|
@ -193,7 +193,7 @@ class Encrypted(Mtp):
|
||||||
self._handlers.get(constructor_id, self._handle_update)(message)
|
self._handlers.get(constructor_id, self._handle_update)(message)
|
||||||
|
|
||||||
def _handle_rpc_result(self, message: Message) -> None:
|
def _handle_rpc_result(self, message: Message) -> None:
|
||||||
rpc_result = RpcResult.from_bytes(message.body)
|
rpc_result = GeneratedRpcResult.from_bytes(message.body)
|
||||||
req_msg_id = rpc_result.req_msg_id
|
req_msg_id = rpc_result.req_msg_id
|
||||||
result = rpc_result.result
|
result = rpc_result.result
|
||||||
|
|
||||||
|
@ -231,13 +231,12 @@ class Encrypted(Mtp):
|
||||||
MsgsAck.from_bytes(message.body)
|
MsgsAck.from_bytes(message.body)
|
||||||
|
|
||||||
def _handle_bad_notification(self, message: Message) -> None:
|
def _handle_bad_notification(self, message: Message) -> None:
|
||||||
# TODO notify about this somehow
|
|
||||||
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
bad_msg = AbcBadMsgNotification.from_bytes(message.body)
|
||||||
if isinstance(bad_msg, BadServerSalt):
|
if isinstance(bad_msg, BadServerSalt):
|
||||||
self._rpc_results.append(
|
self._rpc_results.append(
|
||||||
(
|
(
|
||||||
MsgId(bad_msg.bad_msg_id),
|
MsgId(bad_msg.bad_msg_id),
|
||||||
ValueError(f"bad msg: {bad_msg.error_code}"),
|
BadMessage(code=bad_msg.error_code),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -253,7 +252,7 @@ class Encrypted(Mtp):
|
||||||
|
|
||||||
assert isinstance(bad_msg, BadMsgNotification)
|
assert isinstance(bad_msg, BadMsgNotification)
|
||||||
self._rpc_results.append(
|
self._rpc_results.append(
|
||||||
(MsgId(bad_msg.bad_msg_id), ValueError(f"bad msg: {bad_msg.error_code}"))
|
(MsgId(bad_msg.bad_msg_id), BadMessage(code=bad_msg.error_code))
|
||||||
)
|
)
|
||||||
|
|
||||||
if bad_msg.error_code in (16, 17):
|
if bad_msg.error_code in (16, 17):
|
||||||
|
|
|
@ -8,12 +8,6 @@ from ...tl.mtproto.types import RpcError as GeneratedRpcError
|
||||||
MsgId = NewType("MsgId", int)
|
MsgId = NewType("MsgId", int)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Deserialization:
|
|
||||||
rpc_results: List[Tuple[MsgId, Union[bytes, ValueError]]]
|
|
||||||
updates: List[bytes]
|
|
||||||
|
|
||||||
|
|
||||||
class RpcError(ValueError):
|
class RpcError(ValueError):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -61,6 +55,33 @@ class RpcError(ValueError):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BadMessage(ValueError):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
code: int,
|
||||||
|
caused_by: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(f"bad msg: {code}")
|
||||||
|
|
||||||
|
self.code = code
|
||||||
|
self.caused_by = caused_by
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, self.__class__):
|
||||||
|
return NotImplemented
|
||||||
|
return self.code == other.code
|
||||||
|
|
||||||
|
|
||||||
|
RpcResult = bytes | RpcError | BadMessage
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Deserialization:
|
||||||
|
rpc_results: List[Tuple[MsgId, RpcResult]]
|
||||||
|
updates: List[bytes]
|
||||||
|
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/description
|
# https://core.telegram.org/mtproto/description
|
||||||
class Mtp(ABC):
|
class Mtp(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..crypto.auth_key import AuthKey
|
||||||
from ..mtproto import authentication
|
from ..mtproto import authentication
|
||||||
from ..mtproto.mtp.encrypted import Encrypted
|
from ..mtproto.mtp.encrypted import Encrypted
|
||||||
from ..mtproto.mtp.plain import Plain
|
from ..mtproto.mtp.plain import Plain
|
||||||
from ..mtproto.mtp.types import MsgId, Mtp
|
from ..mtproto.mtp.types import BadMessage, MsgId, Mtp, RpcError
|
||||||
from ..mtproto.transport.abcs import MissingBytes, Transport
|
from ..mtproto.transport.abcs import MissingBytes, Transport
|
||||||
from ..tl.abcs import Updates
|
from ..tl.abcs import Updates
|
||||||
from ..tl.core.request import Request as RemoteCall
|
from ..tl.core.request import Request as RemoteCall
|
||||||
|
@ -253,15 +253,9 @@ class Sender:
|
||||||
found = True
|
found = True
|
||||||
if isinstance(ret, bytes):
|
if isinstance(ret, bytes):
|
||||||
assert len(ret) >= 4
|
assert len(ret) >= 4
|
||||||
elif isinstance(ret, Exception):
|
|
||||||
raise NotImplementedError
|
|
||||||
elif isinstance(ret, RpcError):
|
elif isinstance(ret, RpcError):
|
||||||
ret.caused_by = req.body[:4]
|
ret.caused_by = req.body[:4]
|
||||||
raise ret
|
raise ret
|
||||||
elif isinstance(ret, Dropped):
|
|
||||||
raise ret
|
|
||||||
elif isinstance(ret, Deserialize):
|
|
||||||
raise ret
|
|
||||||
elif isinstance(ret, BadMessage):
|
elif isinstance(ret, BadMessage):
|
||||||
# TODO test that we resend the request
|
# TODO test that we resend the request
|
||||||
req.state = NotSerialized()
|
req.state = NotSerialized()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user