From c7d1a369697aeb5a313e2d3196b8408f6ec6021b Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sat, 16 Mar 2024 14:06:12 +0100 Subject: [PATCH] Return serialized container MsgId on finalize --- .../telethon/_impl/mtproto/mtp/encrypted.py | 21 ++++++------- .../src/telethon/_impl/mtproto/mtp/plain.py | 9 ++++-- .../src/telethon/_impl/mtproto/mtp/types.py | 19 +++++++++--- client/src/telethon/_impl/mtsender/sender.py | 5 ++-- client/tests/mtproto_test.py | 30 ++++++++++++------- 5 files changed, 54 insertions(+), 30 deletions(-) diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 4031aa7f..05af288c 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -204,9 +204,9 @@ class Encrypted(Mtp): def _get_current_salt(self) -> int: return self._salts[-1].salt if self._salts else 0 - def _finalize_plain(self) -> bytes: + def _finalize_plain(self) -> Optional[Tuple[MsgId, bytes]]: if not self._msg_count: - return b"" + return None if self._msg_count == 1: del self._buffer[:CONTAINER_HEADER_LEN] @@ -216,7 +216,7 @@ class Encrypted(Mtp): ) if self._msg_count == 1: - container_msg_id: Union[Type[Single], int] = Single + container_msg_id = self._last_msg_id else: container_msg_id = self._get_new_msg_id() self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack( @@ -235,7 +235,7 @@ class Encrypted(Mtp): self._msg_count = 0 result = bytes(self._buffer) self._buffer.clear() - return result + return MsgId(container_msg_id), result def _process_message(self, message: Message) -> None: if message_requires_ack(message): @@ -465,12 +465,13 @@ class Encrypted(Mtp): return self._serialize_msg(body, True) - def finalize(self) -> bytes: - buffer = self._finalize_plain() - if not buffer: - return buffer - else: - return encrypt_data_v2(buffer, self._auth_key) + def finalize(self) -> Optional[Tuple[MsgId, bytes]]: + result = self._finalize_plain() + if not result: + return None + + msg_id, buffer = result + return msg_id, encrypt_data_v2(buffer, self._auth_key) def deserialize(self, payload: bytes) -> Deserialization: check_message_buffer(payload) diff --git a/client/src/telethon/_impl/mtproto/mtp/plain.py b/client/src/telethon/_impl/mtproto/mtp/plain.py index 0bd13201..89dda0d3 100644 --- a/client/src/telethon/_impl/mtproto/mtp/plain.py +++ b/client/src/telethon/_impl/mtproto/mtp/plain.py @@ -1,5 +1,5 @@ import struct -from typing import Optional +from typing import Optional, Tuple from ..utils import check_message_buffer from .types import Deserialization, MsgId, Mtp @@ -23,10 +23,13 @@ class Plain(Mtp): self._buffer += request # message_data return msg_id - def finalize(self) -> bytes: + def finalize(self) -> Optional[Tuple[MsgId, bytes]]: + if not self._buffer: + return None + result = bytes(self._buffer) self._buffer.clear() - return result + return MsgId(0), result def deserialize(self, payload: bytes) -> Deserialization: check_message_buffer(payload) diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 4aea6297..819b0b42 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -165,12 +165,23 @@ class Deserialization: class Mtp(ABC): @abstractmethod def push(self, request: bytes) -> Optional[MsgId]: - pass + """ + Push a request's body to the internal buffer. + + On success, return the serialized message identifier. + """ @abstractmethod - def finalize(self) -> bytes: - pass + def finalize(self) -> Optional[Tuple[MsgId, bytes]]: + """ + Finalize the buffer of serialized requests. + + If the buffer is empty, :data:`None` is returned. + Otherwise, the message identifier for the entire buffer and the serialized buffer are returned. + """ @abstractmethod def deserialize(self, payload: bytes) -> Deserialization: - pass + """ + Deserialize incoming buffer payload. + """ diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 1a470910..57de56f0 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -271,8 +271,9 @@ class Sender: else: break - mtp_buffer = self._mtp.finalize() - if mtp_buffer: + result = self._mtp.finalize() + if result: + _, mtp_buffer = result self._transport.pack(mtp_buffer, self._writer.write) self._write_drain_pending = True diff --git a/client/tests/mtproto_test.py b/client/tests/mtproto_test.py index ff65932f..4f0fd518 100644 --- a/client/tests/mtproto_test.py +++ b/client/tests/mtproto_test.py @@ -1,8 +1,10 @@ import struct +from typing import Optional, Tuple from pytest import raises from telethon._impl.crypto import AuthKey from telethon._impl.mtproto import Encrypted, Plain, RpcError +from telethon._impl.mtproto.mtp.types import MsgId from telethon._impl.tl.mtproto.types import RpcError as GeneratedRpcError @@ -47,14 +49,20 @@ def test_rpc_error_parsing() -> None: PLAIN_REQUEST = b"Hey!" +def unwrap_finalize(finalized: Optional[Tuple[MsgId, bytes]]) -> bytes: + assert finalized is not None + _, buffer = finalized + return buffer + + def test_plain_finalize_clears_buffer() -> None: mtp = Plain() mtp.push(PLAIN_REQUEST) - assert len(mtp.finalize()) == 24 + assert len(unwrap_finalize(mtp.finalize())) == 24 mtp.push(PLAIN_REQUEST) - assert len(mtp.finalize()) == 24 + assert len(unwrap_finalize(mtp.finalize())) == 24 def test_plain_only_one_push_allowed() -> None: @@ -90,7 +98,7 @@ def test_serialization_has_salt_client_id() -> None: mtp = Encrypted(auth_key()) mtp.push(REQUEST) - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) # salt assert buffer[0:8] == bytes(8) @@ -104,7 +112,7 @@ def test_correct_single_serialization() -> None: mtp = Encrypted(auth_key()) assert mtp.push(REQUEST) is not None - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) ensure_buffer_is_message(buffer[MESSAGE_PREFIX_LEN:], REQUEST, 1) @@ -114,7 +122,7 @@ def test_correct_multi_serialization() -> None: assert mtp.push(REQUEST) is not None assert mtp.push(REQUEST_B) is not None - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) buffer = buffer[MESSAGE_PREFIX_LEN:] # container msg_id @@ -138,7 +146,7 @@ def test_correct_single_large_serialization() -> None: data = bytes(0x7F for _ in range(768 * 1024)) assert mtp.push(data) is not None - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) buffer = buffer[MESSAGE_PREFIX_LEN:] assert len(buffer) == 16 + len(data) @@ -151,7 +159,7 @@ def test_correct_multi_large_serialization() -> None: assert mtp.push(data) is not None assert mtp.push(data) is None - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) buffer = buffer[MESSAGE_PREFIX_LEN:] assert len(buffer) == 16 + len(data) @@ -173,22 +181,22 @@ def test_non_padded_payload_panics() -> None: def test_no_compression_is_honored() -> None: mtp = Encrypted(auth_key(), compression_threshold=None) mtp.push(bytes(512 * 1024)) - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) assert GZIP_PACKED_HEADER not in buffer def test_some_compression() -> None: mtp = Encrypted(auth_key(), compression_threshold=768 * 1024) mtp.push(bytes(512 * 1024)) - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) assert GZIP_PACKED_HEADER not in buffer mtp = Encrypted(auth_key(), compression_threshold=256 * 1024) mtp.push(bytes(512 * 1024)) - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) assert GZIP_PACKED_HEADER in buffer mtp = Encrypted(auth_key()) mtp.push(bytes(512 * 1024)) - buffer = mtp._finalize_plain() + buffer = unwrap_finalize(mtp._finalize_plain()) assert GZIP_PACKED_HEADER in buffer