mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-09-18 18:02:51 +03:00
Add 'caused_by' property to RpcError and BadMessageError; refactor Sender to use it
This commit is contained in:
parent
5fe17a17e2
commit
ecfb263e41
|
@ -96,6 +96,20 @@ class RpcError(ValueError):
|
||||||
"""
|
"""
|
||||||
return self._value
|
return self._value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def caused_by(self) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Constructor identifier of the request that caused the error, if known.
|
||||||
|
"""
|
||||||
|
return self._caused_by
|
||||||
|
|
||||||
|
@caused_by.setter
|
||||||
|
def caused_by(self, value: int) -> None:
|
||||||
|
"""
|
||||||
|
Constructor identifier of the request that caused the error, if known.
|
||||||
|
"""
|
||||||
|
self._caused_by = value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_mtproto_error(cls, error: GeneratedRpcError) -> Self:
|
def _from_mtproto_error(cls, error: GeneratedRpcError) -> Self:
|
||||||
if m := re.search(r"-?\d+", error.error_message):
|
if m := re.search(r"-?\d+", error.error_message):
|
||||||
|
@ -175,6 +189,14 @@ class BadMessageError(ValueError):
|
||||||
def fatal(self) -> bool:
|
def fatal(self) -> bool:
|
||||||
return self._code not in NON_FATAL_MSG_IDS
|
return self._code not in NON_FATAL_MSG_IDS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def caused_by(self) -> Optional[int]:
|
||||||
|
return self._caused_by
|
||||||
|
|
||||||
|
@caused_by.setter
|
||||||
|
def caused_by(self, value: int) -> None:
|
||||||
|
self._caused_by = value
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, self.__class__):
|
if not isinstance(other, self.__class__):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
|
@ -181,6 +181,14 @@ class Sender:
|
||||||
_read_buffer: bytearray
|
_read_buffer: bytearray
|
||||||
_write_drain_pending: bool
|
_write_drain_pending: bool
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mtp(self) -> Mtp:
|
||||||
|
return self._mtp
|
||||||
|
|
||||||
|
@mtp.setter
|
||||||
|
def mtp(self, value: Mtp) -> None:
|
||||||
|
self._mtp = value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def connect(
|
async def connect(
|
||||||
cls,
|
cls,
|
||||||
|
@ -412,20 +420,21 @@ class Sender:
|
||||||
results = self._mtp.deserialize(self._mtp_buffer)
|
results = self._mtp.deserialize(self._mtp_buffer)
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
if isinstance(result, Update):
|
match result:
|
||||||
self._process_update(result.body)
|
case Update(body=body):
|
||||||
elif isinstance(result, RpcResult):
|
self._process_update(body)
|
||||||
self._process_result(result)
|
case RpcResult():
|
||||||
elif isinstance(result, RpcError):
|
self._process_result(result)
|
||||||
self._process_error(result)
|
case RpcError():
|
||||||
elif isinstance(result, BadMessageError):
|
self._process_error(result)
|
||||||
self._process_bad_message(result)
|
case BadMessageError():
|
||||||
elif isinstance(result, DeserializationFailure):
|
self._process_bad_message(result)
|
||||||
self._process_deserialize_error(result)
|
case DeserializationFailure():
|
||||||
else:
|
self._process_deserialize_error(result)
|
||||||
raise RuntimeError(
|
case _:
|
||||||
f"unexpected result type {type(result).__name__}: {result}"
|
raise RuntimeError(
|
||||||
)
|
f"unexpected result type {type(result).__name__}: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
def _process_update(self, update: bytes | bytearray | memoryview) -> None:
|
def _process_update(self, update: bytes | bytearray | memoryview) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -484,7 +493,7 @@ class Sender:
|
||||||
req = self._pop_request(result.msg_id)
|
req = self._pop_request(result.msg_id)
|
||||||
|
|
||||||
if req:
|
if req:
|
||||||
result._caused_by = struct.unpack_from("<I", req.body)[0]
|
result.caused_by = struct.unpack_from("<I", req.body)[0]
|
||||||
req.result.set_exception(result)
|
req.result.set_exception(result)
|
||||||
else:
|
else:
|
||||||
self._logger.warning(
|
self._logger.warning(
|
||||||
|
@ -511,7 +520,7 @@ class Sender:
|
||||||
result.msg_id,
|
result.msg_id,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
result._caused_by = struct.unpack_from("<I", req.body)[0]
|
result.caused_by = struct.unpack_from("<I", req.body)[0]
|
||||||
req.result.set_exception(result)
|
req.result.set_exception(result)
|
||||||
|
|
||||||
def _process_deserialize_error(self, failure: DeserializationFailure):
|
def _process_deserialize_error(self, failure: DeserializationFailure):
|
||||||
|
@ -600,5 +609,5 @@ async def generate_auth_key(sender: Sender) -> Sender:
|
||||||
time_offset = finished.time_offset
|
time_offset = finished.time_offset
|
||||||
first_salt = finished.first_salt
|
first_salt = finished.first_salt
|
||||||
|
|
||||||
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
|
sender.mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
|
||||||
return sender
|
return sender
|
||||||
|
|
Loading…
Reference in New Issue
Block a user