# Mirror of https://github.com/LonamiWebs/Telethon.git
# (synced 2025-11-04 09:57:29 +03:00; Python, 207 lines, 5.4 KiB)
import struct
 | 
						|
from typing import Optional
 | 
						|
 | 
						|
from pytest import raises
 | 
						|
 | 
						|
from telethon._impl.crypto import AuthKey
 | 
						|
from telethon._impl.mtproto import Encrypted, Plain, RpcError
 | 
						|
from telethon._impl.mtproto.mtp.types import MsgId
 | 
						|
from telethon._impl.tl.mtproto.types import RpcError as GeneratedRpcError
 | 
						|
 | 
						|
 | 
						|
def test_rpc_error_parsing() -> None:
    """Raw MTProto errors are split into a name, an optional number and a code.

    A number embedded in the error message, whether trailing (``FLOOD_WAIT_31``)
    or in the middle (``INTERDC_2_CALL_ERROR``), is moved into ``value`` and
    removed from ``name``. Messages without a number get ``value=None``.
    """
    cases = [
        # (error_code, raw error_message, expected name, expected value)
        (400, "CHAT_INVALID", "CHAT_INVALID", None),
        (420, "FLOOD_WAIT_31", "FLOOD_WAIT", 31),
        (500, "INTERDC_2_CALL_ERROR", "INTERDC_CALL_ERROR", 2),
    ]
    for code, message, name, value in cases:
        parsed = RpcError._from_mtproto_error(
            GeneratedRpcError(error_code=code, error_message=message)
        )
        assert parsed == RpcError(code=code, name=name, value=value, caused_by=None)
 | 
						|
 | 
						|
 | 
						|
# Arbitrary 4-byte request body pushed into ``Plain`` by the tests that follow.
PLAIN_REQUEST = b"Hey!"
 | 
						|
 | 
						|
 | 
						|
def unwrap_finalize(finalized: Optional[tuple[MsgId, bytes] | bytes]) -> bytes:
    """Return the serialized buffer from a ``finalize``-style result.

    The result is either the buffer itself or a ``(msg_id, buffer)`` pair.
    Only the buffer is returned. A ``None`` result fails the calling test
    through the assertion below.
    """
    assert finalized is not None
    if not isinstance(finalized, tuple):
        return finalized
    _msg_id, payload = finalized
    return payload
 | 
						|
 | 
						|
 | 
						|
def test_plain_finalize_clears_buffer() -> None:
    """Each ``Plain.finalize`` drains the request pushed before it.

    Two push/finalize rounds must each produce a 24-byte frame, so the first
    request is not carried over into the second frame. The 24 bytes are
    presumably auth_key_id (8) + msg_id (8) + length (4) + the 4-byte body.
    TODO confirm the breakdown against the ``Plain`` implementation.
    """
    mtp = Plain()

    for _round in range(2):
        mtp.push(PLAIN_REQUEST)
        frame = unwrap_finalize(mtp.finalize())
        assert len(frame) == 24
 | 
						|
 | 
						|
 | 
						|
def test_plain_only_one_push_allowed() -> None:
    """``Plain`` accepts one pending request; a second ``push`` returns ``None``."""
    mtp = Plain()

    first = mtp.push(PLAIN_REQUEST)
    assert first is not None

    second = mtp.push(PLAIN_REQUEST)
    assert second is None
 | 
						|
 | 
						|
 | 
						|
# Bytes preceding the first message in a plaintext encrypted payload:
# an 8-byte salt followed by the 8-byte client (session) id.
MESSAGE_PREFIX_LEN = 8 + 8

# Constructor IDs serialized little-endian, as they appear on the wire.
GZIP_PACKED_HEADER = (0x3072CFA1).to_bytes(4, "little")  # gzip_packed
MSG_CONTAINER_HEADER = (0x73F1F8DC).to_bytes(4, "little")  # msg_container

# Two distinct 4-byte request bodies. A length that is a multiple of 4 matters,
# because a 3-byte body is rejected by ``push``.
REQUEST = b"Hey!"
REQUEST_B = b"Bye!"
 | 
						|
 | 
						|
 | 
						|
def auth_key() -> AuthKey:
    """Build a deterministic all-zero 256-byte ``AuthKey`` for tests."""
    zero_key = bytes(256)
    return AuthKey.from_bytes(zero_key)
 | 
						|
 | 
						|
 | 
						|
def ensure_buffer_is_message(buffer: bytes, body: bytes, seq_no: int) -> None:
    """Assert that *buffer* is exactly one serialized message wrapping *body*.

    Expected layout (integers are signed 32-bit little-endian)::

        [0:8]    msg_id  -- time-based, so never all zero bytes
        [8:12]   seq_no  -- must equal *seq_no*
        [12:16]  length  -- must equal ``len(body)``
        [16:]    body    -- must equal *body*; no trailing bytes allowed

    Raises ``AssertionError`` at the first field that does not match.
    """
    msg_id = buffer[0:8]
    seq_field = buffer[8:12]
    length_field = buffer[12:16]
    payload = buffer[16:]

    assert msg_id != bytes(8)
    assert seq_field == struct.pack("<i", seq_no)
    assert length_field == struct.pack("<i", len(body))
    assert payload == body
 | 
						|
 | 
						|
 | 
						|
def test_serialization_has_salt_client_id() -> None:
    """The plaintext payload is salt + client id, followed by the message.

    On a freshly built ``Encrypted`` the salt is still all zeros, while the
    client id must already be non-zero.
    """
    mtp = Encrypted(auth_key())
    mtp.push(REQUEST)
    buffer = unwrap_finalize(mtp._finalize_plain())

    salt, client_id = buffer[0:8], buffer[8:16]
    assert salt == bytes(8)
    assert client_id != bytes(8)

    ensure_buffer_is_message(buffer[MESSAGE_PREFIX_LEN:], REQUEST, 1)
 | 
						|
 | 
						|
 | 
						|
def test_correct_single_serialization() -> None:
    """A lone request is serialized as one bare message with seq_no 1."""
    mtp = Encrypted(auth_key())

    pushed = mtp.push(REQUEST)
    assert pushed is not None

    payload = unwrap_finalize(mtp._finalize_plain())[MESSAGE_PREFIX_LEN:]
    ensure_buffer_is_message(payload, REQUEST, 1)
 | 
						|
 | 
						|
 | 
						|
def test_correct_multi_serialization() -> None:
    """Two requests are wrapped together in a single ``msg_container``.

    After the salt/client-id prefix, the layout is:

    * container message header: msg_id, seq_no = 4, length = 0x30
    * container body: constructor id, message count (2), then the two inner
      messages at 20 bytes each (16-byte header + 4-byte body). That gives
      4 + 4 + 2 * 20 = 0x30 bytes.
    """
    mtp = Encrypted(auth_key(), compression_threshold=None)

    for request in (REQUEST, REQUEST_B):
        assert mtp.push(request) is not None

    container = unwrap_finalize(mtp._finalize_plain())[MESSAGE_PREFIX_LEN:]
    msg_id = container[0:8]
    seq_field = container[8:12]
    length_field = container[12:16]
    constructor = container[16:20]
    count_field = container[20:24]
    first, second = container[24:44], container[44:]

    assert msg_id != bytes(8)
    # Inner messages take content-related seq_no 1 and 3; the container gets 4.
    assert seq_field == b"\x04\0\0\0"
    assert length_field == b"\x30\0\0\0"
    assert constructor == MSG_CONTAINER_HEADER
    assert count_field == b"\x02\0\0\0"

    ensure_buffer_is_message(first, REQUEST, 1)
    ensure_buffer_is_message(second, REQUEST_B, 3)
 | 
						|
 | 
						|
 | 
						|
def test_correct_single_large_serialization() -> None:
    """A single 768 KiB request is emitted as-is behind a 16-byte message header."""
    mtp = Encrypted(auth_key(), compression_threshold=None)
    data = b"\x7f" * (768 * 1024)

    assert mtp.push(data) is not None

    message = unwrap_finalize(mtp._finalize_plain())[MESSAGE_PREFIX_LEN:]
    assert len(message) == 16 + len(data)
 | 
						|
 | 
						|
 | 
						|
def test_correct_multi_large_serialization() -> None:
    """A second 768 KiB request does not fit once the first one is queued.

    The second ``push`` returns ``None``. The finalized payload therefore
    holds just the first message: a 16-byte header plus the data.
    """
    mtp = Encrypted(auth_key(), compression_threshold=None)
    data = b"\x7f" * (768 * 1024)

    accepted = mtp.push(data)
    assert accepted is not None
    rejected = mtp.push(data)
    assert rejected is None

    message = unwrap_finalize(mtp._finalize_plain())[MESSAGE_PREFIX_LEN:]
    assert len(message) == 16 + len(data)
 | 
						|
 | 
						|
 | 
						|
def test_large_payload_panics() -> None:
    """Pushing a 2 MiB request trips an assertion inside ``push``."""
    mtp = Encrypted(auth_key())
    too_big = bytes(2 * 1024 * 1024)

    with raises(AssertionError):
        mtp.push(too_big)
 | 
						|
 | 
						|
 | 
						|
def test_non_padded_payload_panics() -> None:
    """A 3-byte (non-padded) request trips an assertion inside ``push``.

    Presumably request bodies must be 4-byte aligned. TODO confirm in ``Encrypted.push``.
    """
    mtp = Encrypted(auth_key())
    unaligned = b"\x01\x02\x03"

    with raises(AssertionError):
        mtp.push(unaligned)
 | 
						|
 | 
						|
 | 
						|
def test_no_compression_is_honored() -> None:
    """``compression_threshold=None`` disables gzip, even for a 512 KiB body."""
    mtp = Encrypted(auth_key(), compression_threshold=None)
    mtp.push(bytes(512 * 1024))

    payload = unwrap_finalize(mtp._finalize_plain())
    assert GZIP_PACKED_HEADER not in payload
 | 
						|
 | 
						|
 | 
						|
def test_some_compression() -> None:
    """A 512 KiB body is gzipped depending on the configured threshold.

    * 768 KiB threshold (above the body size): not compressed.
    * 256 KiB threshold (below the body size): compressed.
    * Default ``Encrypted`` settings: compressed.
    """
    scenarios = [
        ({"compression_threshold": 768 * 1024}, False),
        ({"compression_threshold": 256 * 1024}, True),
        ({}, True),  # rely on the constructor's default threshold
    ]
    for kwargs, expect_gzip in scenarios:
        mtp = Encrypted(auth_key(), **kwargs)
        mtp.push(bytes(512 * 1024))
        payload = unwrap_finalize(mtp._finalize_plain())
        if expect_gzip:
            assert GZIP_PACKED_HEADER in payload
        else:
            assert GZIP_PACKED_HEADER not in payload
 |