Add 'caused_by' property to RpcError and BadMessageError; refactor Sender to use it

This commit is contained in:
Jahongir Qurbonov 2025-09-15 13:06:01 +05:00
parent 5fe17a17e2
commit ecfb263e41
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
2 changed files with 48 additions and 17 deletions

View File

@ -96,6 +96,20 @@ class RpcError(ValueError):
""" """
return self._value return self._value
@property
def caused_by(self) -> Optional[int]:
"""
Constructor identifier of the request that caused the error, if known.
"""
return self._caused_by
@caused_by.setter
def caused_by(self, value: int) -> None:
"""
Constructor identifier of the request that caused the error, if known.
"""
self._caused_by = value
@classmethod @classmethod
def _from_mtproto_error(cls, error: GeneratedRpcError) -> Self: def _from_mtproto_error(cls, error: GeneratedRpcError) -> Self:
if m := re.search(r"-?\d+", error.error_message): if m := re.search(r"-?\d+", error.error_message):
@ -175,6 +189,14 @@ class BadMessageError(ValueError):
def fatal(self) -> bool: def fatal(self) -> bool:
return self._code not in NON_FATAL_MSG_IDS return self._code not in NON_FATAL_MSG_IDS
@property
def caused_by(self) -> Optional[int]:
return self._caused_by
@caused_by.setter
def caused_by(self, value: int) -> None:
self._caused_by = value
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented

View File

@ -181,6 +181,14 @@ class Sender:
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool _write_drain_pending: bool
@property
def mtp(self) -> Mtp:
return self._mtp
@mtp.setter
def mtp(self, value: Mtp) -> None:
self._mtp = value
@classmethod @classmethod
async def connect( async def connect(
cls, cls,
@ -412,20 +420,21 @@ class Sender:
results = self._mtp.deserialize(self._mtp_buffer) results = self._mtp.deserialize(self._mtp_buffer)
for result in results: for result in results:
if isinstance(result, Update): match result:
self._process_update(result.body) case Update(body=body):
elif isinstance(result, RpcResult): self._process_update(body)
self._process_result(result) case RpcResult():
elif isinstance(result, RpcError): self._process_result(result)
self._process_error(result) case RpcError():
elif isinstance(result, BadMessageError): self._process_error(result)
self._process_bad_message(result) case BadMessageError():
elif isinstance(result, DeserializationFailure): self._process_bad_message(result)
self._process_deserialize_error(result) case DeserializationFailure():
else: self._process_deserialize_error(result)
raise RuntimeError( case _:
f"unexpected result type {type(result).__name__}: {result}" raise RuntimeError(
) f"unexpected result type {type(result).__name__}: {result}"
)
def _process_update(self, update: bytes | bytearray | memoryview) -> None: def _process_update(self, update: bytes | bytearray | memoryview) -> None:
try: try:
@ -484,7 +493,7 @@ class Sender:
req = self._pop_request(result.msg_id) req = self._pop_request(result.msg_id)
if req: if req:
result._caused_by = struct.unpack_from("<I", req.body)[0] result.caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(result) req.result.set_exception(result)
else: else:
self._logger.warning( self._logger.warning(
@ -511,7 +520,7 @@ class Sender:
result.msg_id, result.msg_id,
result, result,
) )
result._caused_by = struct.unpack_from("<I", req.body)[0] result.caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(result) req.result.set_exception(result)
def _process_deserialize_error(self, failure: DeserializationFailure): def _process_deserialize_error(self, failure: DeserializationFailure):
@ -600,5 +609,5 @@ async def generate_auth_key(sender: Sender) -> Sender:
time_offset = finished.time_offset time_offset = finished.time_offset
first_salt = finished.first_salt first_salt = finished.first_salt
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) sender.mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
return sender return sender