From ecfb263e41ed599ac21edfe63660a98799f6c44a Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Mon, 15 Sep 2025 13:06:01 +0500 Subject: [PATCH] Add 'caused_by' property to RpcError and BadMessageError; refactor Sender to use it --- .../src/telethon/_impl/mtproto/mtp/types.py | 22 ++++++++++ client/src/telethon/_impl/mtsender/sender.py | 43 +++++++++++-------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 4d4b43bc..18578e2e 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -96,6 +96,20 @@ class RpcError(ValueError): """ return self._value + @property + def caused_by(self) -> Optional[int]: + """ + Constructor identifier of the request that caused the error, if known. + """ + return self._caused_by + + @caused_by.setter + def caused_by(self, value: int) -> None: + """ + Constructor identifier of the request that caused the error, if known. + """ + self._caused_by = value + @classmethod def _from_mtproto_error(cls, error: GeneratedRpcError) -> Self: if m := re.search(r"-?\d+", error.error_message): @@ -175,6 +189,14 @@ class BadMessageError(ValueError): def fatal(self) -> bool: return self._code not in NON_FATAL_MSG_IDS + @property + def caused_by(self) -> Optional[int]: + return self._caused_by + + @caused_by.setter + def caused_by(self, value: int) -> None: + self._caused_by = value + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index bdfa5e96..c93bb223 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -181,6 +181,14 @@ class Sender: _read_buffer: bytearray _write_drain_pending: bool + @property + def mtp(self) -> Mtp: + return self._mtp + + @mtp.setter + def mtp(self, value: Mtp) -> None: + self._mtp = value + @classmethod async def connect( cls, @@ -412,20 +420,21 @@ class Sender: results = self._mtp.deserialize(self._mtp_buffer) for result in results: - if isinstance(result, Update): - self._process_update(result.body) - elif isinstance(result, RpcResult): - self._process_result(result) - elif isinstance(result, RpcError): - self._process_error(result) - elif isinstance(result, BadMessageError): - self._process_bad_message(result) - elif isinstance(result, DeserializationFailure): - self._process_deserialize_error(result) - else: - raise RuntimeError( - f"unexpected result type {type(result).__name__}: {result}" - ) + match result: + case Update(body=body): + self._process_update(body) + case RpcResult(): + self._process_result(result) + case RpcError(): + self._process_error(result) + case BadMessageError(): + self._process_bad_message(result) + case DeserializationFailure(): + self._process_deserialize_error(result) + case _: + raise RuntimeError( + f"unexpected result type {type(result).__name__}: {result}" + ) def _process_update(self, update: bytes | bytearray | memoryview) -> None: try: @@ -484,7 +493,7 @@ class Sender: req = self._pop_request(result.msg_id) if req: - result._caused_by = struct.unpack_from(" Sender: time_offset = finished.time_offset first_salt = finished.first_salt - sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) + sender.mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) return sender