Add 'caused_by' property to RpcError and BadMessageError; refactor Sender to use it

This commit is contained in:
Jahongir Qurbonov 2025-09-15 13:06:01 +05:00
parent 5fe17a17e2
commit ecfb263e41
No known key found for this signature in database
GPG Key ID: 256976CED13D5F2D
2 changed files with 48 additions and 17 deletions

View File

@ -96,6 +96,20 @@ class RpcError(ValueError):
"""
return self._value
@property
def caused_by(self) -> Optional[int]:
"""
Constructor identifier of the request that caused the error, if known.
"""
return self._caused_by
@caused_by.setter
def caused_by(self, value: int) -> None:
"""
Constructor identifier of the request that caused the error, if known.
"""
self._caused_by = value
@classmethod
def _from_mtproto_error(cls, error: GeneratedRpcError) -> Self:
if m := re.search(r"-?\d+", error.error_message):
@ -175,6 +189,14 @@ class BadMessageError(ValueError):
def fatal(self) -> bool:
return self._code not in NON_FATAL_MSG_IDS
@property
def caused_by(self) -> Optional[int]:
return self._caused_by
@caused_by.setter
def caused_by(self, value: int) -> None:
self._caused_by = value
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented

View File

@ -181,6 +181,14 @@ class Sender:
_read_buffer: bytearray
_write_drain_pending: bool
@property
def mtp(self) -> Mtp:
return self._mtp
@mtp.setter
def mtp(self, value: Mtp) -> None:
self._mtp = value
@classmethod
async def connect(
cls,
@ -412,20 +420,21 @@ class Sender:
results = self._mtp.deserialize(self._mtp_buffer)
for result in results:
if isinstance(result, Update):
self._process_update(result.body)
elif isinstance(result, RpcResult):
self._process_result(result)
elif isinstance(result, RpcError):
self._process_error(result)
elif isinstance(result, BadMessageError):
self._process_bad_message(result)
elif isinstance(result, DeserializationFailure):
self._process_deserialize_error(result)
else:
raise RuntimeError(
f"unexpected result type {type(result).__name__}: {result}"
)
match result:
case Update(body=body):
self._process_update(body)
case RpcResult():
self._process_result(result)
case RpcError():
self._process_error(result)
case BadMessageError():
self._process_bad_message(result)
case DeserializationFailure():
self._process_deserialize_error(result)
case _:
raise RuntimeError(
f"unexpected result type {type(result).__name__}: {result}"
)
def _process_update(self, update: bytes | bytearray | memoryview) -> None:
try:
@ -484,7 +493,7 @@ class Sender:
req = self._pop_request(result.msg_id)
if req:
result._caused_by = struct.unpack_from("<I", req.body)[0]
result.caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(result)
else:
self._logger.warning(
@ -511,7 +520,7 @@ class Sender:
result.msg_id,
result,
)
result._caused_by = struct.unpack_from("<I", req.body)[0]
result.caused_by = struct.unpack_from("<I", req.body)[0]
req.result.set_exception(result)
def _process_deserialize_error(self, failure: DeserializationFailure):
@ -600,5 +609,5 @@ async def generate_auth_key(sender: Sender) -> Sender:
time_offset = finished.time_offset
first_salt = finished.first_salt
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
sender.mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
return sender