finalize does not need to return last_msg_id

This commit is contained in:
Lonami Exo 2024-06-07 22:14:40 +02:00
parent 4536667a6a
commit 94048d9102
2 changed files with 10 additions and 11 deletions

View File

@ -196,7 +196,7 @@ class Encrypted(Mtp):
def _get_current_salt(self) -> int: def _get_current_salt(self) -> int:
return self._salts[-1].salt if self._salts else 0 return self._salts[-1].salt if self._salts else 0
def _finalize_plain(self) -> Optional[tuple[MsgId, bytes]]: def _finalize_plain(self) -> Optional[bytes]:
if not self._msg_count: if not self._msg_count:
return None return None
@ -207,13 +207,10 @@ class Encrypted(Mtp):
"<qq", self._get_current_salt(), self._client_id "<qq", self._get_current_salt(), self._client_id
) )
if self._msg_count == 1: if self._msg_count != 1:
container_msg_id = self._last_msg_id
else:
container_msg_id = self._get_new_msg_id()
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack( self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
"<qiiIi", "<qiiIi",
container_msg_id, self._get_new_msg_id(),
self._get_seq_no(False), self._get_seq_no(False),
len(self._buffer) - HEADER_LEN - CONTAINER_HEADER_LEN + 8, len(self._buffer) - HEADER_LEN - CONTAINER_HEADER_LEN + 8,
MsgContainer.constructor_id(), MsgContainer.constructor_id(),
@ -223,7 +220,7 @@ class Encrypted(Mtp):
self._msg_count = 0 self._msg_count = 0
result = bytes(self._buffer) result = bytes(self._buffer)
self._buffer.clear() self._buffer.clear()
return MsgId(container_msg_id), result return result
def _process_message(self, message: Message) -> None: def _process_message(self, message: Message) -> None:
if message_requires_ack(message): if message_requires_ack(message):
@ -436,8 +433,7 @@ class Encrypted(Mtp):
if not result: if not result:
return None return None
msg_id, buffer = result return MsgId(self._last_msg_id), encrypt_data_v2(result, self._auth_key)
return msg_id, encrypt_data_v2(buffer, self._auth_key)
def deserialize( def deserialize(
self, payload: bytes | bytearray | memoryview self, payload: bytes | bytearray | memoryview

View File

@ -49,9 +49,12 @@ def test_rpc_error_parsing() -> None:
PLAIN_REQUEST = b"Hey!" PLAIN_REQUEST = b"Hey!"
def unwrap_finalize(finalized: Optional[tuple[MsgId, bytes]]) -> bytes: def unwrap_finalize(finalized: Optional[tuple[MsgId, bytes] | bytes]) -> bytes:
assert finalized is not None assert finalized is not None
_, buffer = finalized if isinstance(finalized, tuple):
_, buffer = finalized
else:
buffer = finalized
return buffer return buffer