mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-11-04 01:47:27 +03:00 
			
		
		
		
	Return serialized container MsgId on finalize
This commit is contained in:
		
							parent
							
								
									6fd3eb2ee6
								
							
						
					
					
						commit
						c7d1a36969
					
				| 
						 | 
					@ -204,9 +204,9 @@ class Encrypted(Mtp):
 | 
				
			||||||
    def _get_current_salt(self) -> int:
 | 
					    def _get_current_salt(self) -> int:
 | 
				
			||||||
        return self._salts[-1].salt if self._salts else 0
 | 
					        return self._salts[-1].salt if self._salts else 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _finalize_plain(self) -> bytes:
 | 
					    def _finalize_plain(self) -> Optional[Tuple[MsgId, bytes]]:
 | 
				
			||||||
        if not self._msg_count:
 | 
					        if not self._msg_count:
 | 
				
			||||||
            return b""
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self._msg_count == 1:
 | 
					        if self._msg_count == 1:
 | 
				
			||||||
            del self._buffer[:CONTAINER_HEADER_LEN]
 | 
					            del self._buffer[:CONTAINER_HEADER_LEN]
 | 
				
			||||||
| 
						 | 
					@ -216,7 +216,7 @@ class Encrypted(Mtp):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self._msg_count == 1:
 | 
					        if self._msg_count == 1:
 | 
				
			||||||
            container_msg_id: Union[Type[Single], int] = Single
 | 
					            container_msg_id = self._last_msg_id
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            container_msg_id = self._get_new_msg_id()
 | 
					            container_msg_id = self._get_new_msg_id()
 | 
				
			||||||
            self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
 | 
					            self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
 | 
				
			||||||
| 
						 | 
					@ -235,7 +235,7 @@ class Encrypted(Mtp):
 | 
				
			||||||
        self._msg_count = 0
 | 
					        self._msg_count = 0
 | 
				
			||||||
        result = bytes(self._buffer)
 | 
					        result = bytes(self._buffer)
 | 
				
			||||||
        self._buffer.clear()
 | 
					        self._buffer.clear()
 | 
				
			||||||
        return result
 | 
					        return MsgId(container_msg_id), result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _process_message(self, message: Message) -> None:
 | 
					    def _process_message(self, message: Message) -> None:
 | 
				
			||||||
        if message_requires_ack(message):
 | 
					        if message_requires_ack(message):
 | 
				
			||||||
| 
						 | 
					@ -465,12 +465,13 @@ class Encrypted(Mtp):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return self._serialize_msg(body, True)
 | 
					        return self._serialize_msg(body, True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def finalize(self) -> bytes:
 | 
					    def finalize(self) -> Optional[Tuple[MsgId, bytes]]:
 | 
				
			||||||
        buffer = self._finalize_plain()
 | 
					        result = self._finalize_plain()
 | 
				
			||||||
        if not buffer:
 | 
					        if not result:
 | 
				
			||||||
            return buffer
 | 
					            return None
 | 
				
			||||||
        else:
 | 
					
 | 
				
			||||||
            return encrypt_data_v2(buffer, self._auth_key)
 | 
					        msg_id, buffer = result
 | 
				
			||||||
 | 
					        return msg_id, encrypt_data_v2(buffer, self._auth_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def deserialize(self, payload: bytes) -> Deserialization:
 | 
					    def deserialize(self, payload: bytes) -> Deserialization:
 | 
				
			||||||
        check_message_buffer(payload)
 | 
					        check_message_buffer(payload)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,5 @@
 | 
				
			||||||
import struct
 | 
					import struct
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..utils import check_message_buffer
 | 
					from ..utils import check_message_buffer
 | 
				
			||||||
from .types import Deserialization, MsgId, Mtp
 | 
					from .types import Deserialization, MsgId, Mtp
 | 
				
			||||||
| 
						 | 
					@ -23,10 +23,13 @@ class Plain(Mtp):
 | 
				
			||||||
        self._buffer += request  # message_data
 | 
					        self._buffer += request  # message_data
 | 
				
			||||||
        return msg_id
 | 
					        return msg_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def finalize(self) -> bytes:
 | 
					    def finalize(self) -> Optional[Tuple[MsgId, bytes]]:
 | 
				
			||||||
 | 
					        if not self._buffer:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        result = bytes(self._buffer)
 | 
					        result = bytes(self._buffer)
 | 
				
			||||||
        self._buffer.clear()
 | 
					        self._buffer.clear()
 | 
				
			||||||
        return result
 | 
					        return MsgId(0), result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def deserialize(self, payload: bytes) -> Deserialization:
 | 
					    def deserialize(self, payload: bytes) -> Deserialization:
 | 
				
			||||||
        check_message_buffer(payload)
 | 
					        check_message_buffer(payload)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -165,12 +165,23 @@ class Deserialization:
 | 
				
			||||||
class Mtp(ABC):
 | 
					class Mtp(ABC):
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def push(self, request: bytes) -> Optional[MsgId]:
 | 
					    def push(self, request: bytes) -> Optional[MsgId]:
 | 
				
			||||||
        pass
 | 
					        """
 | 
				
			||||||
 | 
					        Push a request's body to the internal buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        On success, return the serialized message identifier.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def finalize(self) -> bytes:
 | 
					    def finalize(self) -> Optional[Tuple[MsgId, bytes]]:
 | 
				
			||||||
        pass
 | 
					        """
 | 
				
			||||||
 | 
					        Finalize the buffer of serialized requests.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        If the buffer is empty, :data:`None` is returned.
 | 
				
			||||||
 | 
					        Otherwise, the message identifier for the entire buffer and the serialized buffer are returned.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def deserialize(self, payload: bytes) -> Deserialization:
 | 
					    def deserialize(self, payload: bytes) -> Deserialization:
 | 
				
			||||||
        pass
 | 
					        """
 | 
				
			||||||
 | 
					        Deserialize incoming buffer payload.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -271,8 +271,9 @@ class Sender:
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    break
 | 
					                    break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        mtp_buffer = self._mtp.finalize()
 | 
					        result = self._mtp.finalize()
 | 
				
			||||||
        if mtp_buffer:
 | 
					        if result:
 | 
				
			||||||
 | 
					            _, mtp_buffer = result
 | 
				
			||||||
            self._transport.pack(mtp_buffer, self._writer.write)
 | 
					            self._transport.pack(mtp_buffer, self._writer.write)
 | 
				
			||||||
            self._write_drain_pending = True
 | 
					            self._write_drain_pending = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,8 +1,10 @@
 | 
				
			||||||
import struct
 | 
					import struct
 | 
				
			||||||
 | 
					from typing import Optional, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from pytest import raises
 | 
					from pytest import raises
 | 
				
			||||||
from telethon._impl.crypto import AuthKey
 | 
					from telethon._impl.crypto import AuthKey
 | 
				
			||||||
from telethon._impl.mtproto import Encrypted, Plain, RpcError
 | 
					from telethon._impl.mtproto import Encrypted, Plain, RpcError
 | 
				
			||||||
 | 
					from telethon._impl.mtproto.mtp.types import MsgId
 | 
				
			||||||
from telethon._impl.tl.mtproto.types import RpcError as GeneratedRpcError
 | 
					from telethon._impl.tl.mtproto.types import RpcError as GeneratedRpcError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -47,14 +49,20 @@ def test_rpc_error_parsing() -> None:
 | 
				
			||||||
PLAIN_REQUEST = b"Hey!"
 | 
					PLAIN_REQUEST = b"Hey!"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def unwrap_finalize(finalized: Optional[Tuple[MsgId, bytes]]) -> bytes:
 | 
				
			||||||
 | 
					    assert finalized is not None
 | 
				
			||||||
 | 
					    _, buffer = finalized
 | 
				
			||||||
 | 
					    return buffer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_plain_finalize_clears_buffer() -> None:
 | 
					def test_plain_finalize_clears_buffer() -> None:
 | 
				
			||||||
    mtp = Plain()
 | 
					    mtp = Plain()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mtp.push(PLAIN_REQUEST)
 | 
					    mtp.push(PLAIN_REQUEST)
 | 
				
			||||||
    assert len(mtp.finalize()) == 24
 | 
					    assert len(unwrap_finalize(mtp.finalize())) == 24
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mtp.push(PLAIN_REQUEST)
 | 
					    mtp.push(PLAIN_REQUEST)
 | 
				
			||||||
    assert len(mtp.finalize()) == 24
 | 
					    assert len(unwrap_finalize(mtp.finalize())) == 24
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_plain_only_one_push_allowed() -> None:
 | 
					def test_plain_only_one_push_allowed() -> None:
 | 
				
			||||||
| 
						 | 
					@ -90,7 +98,7 @@ def test_serialization_has_salt_client_id() -> None:
 | 
				
			||||||
    mtp = Encrypted(auth_key())
 | 
					    mtp = Encrypted(auth_key())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mtp.push(REQUEST)
 | 
					    mtp.push(REQUEST)
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # salt
 | 
					    # salt
 | 
				
			||||||
    assert buffer[0:8] == bytes(8)
 | 
					    assert buffer[0:8] == bytes(8)
 | 
				
			||||||
| 
						 | 
					@ -104,7 +112,7 @@ def test_correct_single_serialization() -> None:
 | 
				
			||||||
    mtp = Encrypted(auth_key())
 | 
					    mtp = Encrypted(auth_key())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert mtp.push(REQUEST) is not None
 | 
					    assert mtp.push(REQUEST) is not None
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ensure_buffer_is_message(buffer[MESSAGE_PREFIX_LEN:], REQUEST, 1)
 | 
					    ensure_buffer_is_message(buffer[MESSAGE_PREFIX_LEN:], REQUEST, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -114,7 +122,7 @@ def test_correct_multi_serialization() -> None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert mtp.push(REQUEST) is not None
 | 
					    assert mtp.push(REQUEST) is not None
 | 
				
			||||||
    assert mtp.push(REQUEST_B) is not None
 | 
					    assert mtp.push(REQUEST_B) is not None
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
    buffer = buffer[MESSAGE_PREFIX_LEN:]
 | 
					    buffer = buffer[MESSAGE_PREFIX_LEN:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # container msg_id
 | 
					    # container msg_id
 | 
				
			||||||
| 
						 | 
					@ -138,7 +146,7 @@ def test_correct_single_large_serialization() -> None:
 | 
				
			||||||
    data = bytes(0x7F for _ in range(768 * 1024))
 | 
					    data = bytes(0x7F for _ in range(768 * 1024))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert mtp.push(data) is not None
 | 
					    assert mtp.push(data) is not None
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    buffer = buffer[MESSAGE_PREFIX_LEN:]
 | 
					    buffer = buffer[MESSAGE_PREFIX_LEN:]
 | 
				
			||||||
    assert len(buffer) == 16 + len(data)
 | 
					    assert len(buffer) == 16 + len(data)
 | 
				
			||||||
| 
						 | 
					@ -151,7 +159,7 @@ def test_correct_multi_large_serialization() -> None:
 | 
				
			||||||
    assert mtp.push(data) is not None
 | 
					    assert mtp.push(data) is not None
 | 
				
			||||||
    assert mtp.push(data) is None
 | 
					    assert mtp.push(data) is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
    buffer = buffer[MESSAGE_PREFIX_LEN:]
 | 
					    buffer = buffer[MESSAGE_PREFIX_LEN:]
 | 
				
			||||||
    assert len(buffer) == 16 + len(data)
 | 
					    assert len(buffer) == 16 + len(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -173,22 +181,22 @@ def test_non_padded_payload_panics() -> None:
 | 
				
			||||||
def test_no_compression_is_honored() -> None:
 | 
					def test_no_compression_is_honored() -> None:
 | 
				
			||||||
    mtp = Encrypted(auth_key(), compression_threshold=None)
 | 
					    mtp = Encrypted(auth_key(), compression_threshold=None)
 | 
				
			||||||
    mtp.push(bytes(512 * 1024))
 | 
					    mtp.push(bytes(512 * 1024))
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
    assert GZIP_PACKED_HEADER not in buffer
 | 
					    assert GZIP_PACKED_HEADER not in buffer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_some_compression() -> None:
 | 
					def test_some_compression() -> None:
 | 
				
			||||||
    mtp = Encrypted(auth_key(), compression_threshold=768 * 1024)
 | 
					    mtp = Encrypted(auth_key(), compression_threshold=768 * 1024)
 | 
				
			||||||
    mtp.push(bytes(512 * 1024))
 | 
					    mtp.push(bytes(512 * 1024))
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
    assert GZIP_PACKED_HEADER not in buffer
 | 
					    assert GZIP_PACKED_HEADER not in buffer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mtp = Encrypted(auth_key(), compression_threshold=256 * 1024)
 | 
					    mtp = Encrypted(auth_key(), compression_threshold=256 * 1024)
 | 
				
			||||||
    mtp.push(bytes(512 * 1024))
 | 
					    mtp.push(bytes(512 * 1024))
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
    assert GZIP_PACKED_HEADER in buffer
 | 
					    assert GZIP_PACKED_HEADER in buffer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mtp = Encrypted(auth_key())
 | 
					    mtp = Encrypted(auth_key())
 | 
				
			||||||
    mtp.push(bytes(512 * 1024))
 | 
					    mtp.push(bytes(512 * 1024))
 | 
				
			||||||
    buffer = mtp._finalize_plain()
 | 
					    buffer = unwrap_finalize(mtp._finalize_plain())
 | 
				
			||||||
    assert GZIP_PACKED_HEADER in buffer
 | 
					    assert GZIP_PACKED_HEADER in buffer
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user