diff --git a/client/src/telethon/_impl/mtproto/transport/abcs.py b/client/src/telethon/_impl/mtproto/transport/abcs.py index 432ed322..de24ebc0 100644 --- a/client/src/telethon/_impl/mtproto/transport/abcs.py +++ b/client/src/telethon/_impl/mtproto/transport/abcs.py @@ -1,9 +1,15 @@ from abc import ABC, abstractmethod +from typing import Callable + + +OutFn = Callable[[bytes | bytearray | memoryview], None] class Transport(ABC): + # Python's stream writer has a synchronous write (buffer append) followed + # by drain. The buffer is externally managed, so `write` is used as input. @abstractmethod - def pack(self, input: bytes, output: bytearray) -> None: + def pack(self, input: bytes, write: OutFn) -> None: pass @abstractmethod diff --git a/client/src/telethon/_impl/mtproto/transport/abridged.py b/client/src/telethon/_impl/mtproto/transport/abridged.py index fda75100..9d43fe32 100644 --- a/client/src/telethon/_impl/mtproto/transport/abridged.py +++ b/client/src/telethon/_impl/mtproto/transport/abridged.py @@ -1,6 +1,6 @@ import struct -from .abcs import MissingBytes, Transport +from .abcs import MissingBytes, OutFn, Transport class Abridged(Transport): @@ -22,19 +22,19 @@ class Abridged(Transport): def __init__(self) -> None: self._init = False - def pack(self, input: bytes, output: bytearray) -> None: + def pack(self, input: bytes, write: OutFn) -> None: assert len(input) % 4 == 0 if not self._init: - output += b"\xef" + write(b"\xef") self._init = True length = len(input) // 4 if length < 127: - output += struct.pack(" int: if not input: diff --git a/client/src/telethon/_impl/mtproto/transport/full.py b/client/src/telethon/_impl/mtproto/transport/full.py index 9a6cde7f..fd59860f 100644 --- a/client/src/telethon/_impl/mtproto/transport/full.py +++ b/client/src/telethon/_impl/mtproto/transport/full.py @@ -1,7 +1,7 @@ import struct from zlib import crc32 -from .abcs import MissingBytes, Transport +from .abcs import MissingBytes, OutFn, Transport class Full(Transport): @@ -24,13 +24,15 @@ class Full(Transport): self._send_seq = 0 self._recv_seq = 0 - def pack(self, input: bytes, output: bytearray) -> None: + def pack(self, input: bytes, write: OutFn) -> None: assert len(input) % 4 == 0 length = len(input) + 12 - output += struct.pack(" int: diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index 9dd8e188..241d094f 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -1,6 +1,6 @@ import struct -from .abcs import MissingBytes, Transport +from .abcs import MissingBytes, OutFn, Transport class Intermediate(Transport): @@ -22,15 +22,15 @@ class Intermediate(Transport): def __init__(self) -> None: self._init = False - def pack(self, input: bytes, output: bytearray) -> None: + def pack(self, input: bytes, write: OutFn) -> None: assert len(input) % 4 == 0 if not self._init: - output += b"\xee\xee\xee\xee" + write(b"\xee\xee\xee\xee") self._init = True - output += struct.pack(" int: if len(input) < 4: diff --git a/client/tests/transport/abridged_test.py b/client/tests/transport/abridged_test.py index 7733babf..71af5dc1 100644 --- a/client/tests/transport/abridged_test.py +++ b/client/tests/transport/abridged_test.py @@ -4,9 +4,16 @@ from pytest import raises from telethon._impl.mtproto.transport.abridged import Abridged +class Output(bytearray): + __slots__ = () + + def __call__(self, data: bytes) -> None: + self += data + + def setup_pack(n: int) -> Tuple[Abridged, bytes, bytearray]: input = bytes(x & 0xFF for x in range(n)) - return Abridged(), input, bytearray() + return Abridged(), input, Output() def test_pack_empty() -> None: diff --git a/client/tests/transport/full_test.py b/client/tests/transport/full_test.py index 65d9c28e..b85827fc 100644 --- a/client/tests/transport/full_test.py +++ b/client/tests/transport/full_test.py @@ -4,9 +4,16 @@ from pytest import raises from telethon._impl.mtproto.transport.full import Full +class Output(bytearray): + __slots__ = () + + def __call__(self, data: bytes) -> None: + self += data + + def setup_pack(n: int) -> Tuple[Full, bytes, bytearray]: input = bytes(x & 0xFF for x in range(n)) - return Full(), input, bytearray() + return Full(), input, Output() def setup_unpack(n: int) -> Tuple[bytes, Full, bytes, bytearray]: diff --git a/client/tests/transport/intermediate_test.py b/client/tests/transport/intermediate_test.py index 657f1fe0..b3a6efe8 100644 --- a/client/tests/transport/intermediate_test.py +++ b/client/tests/transport/intermediate_test.py @@ -4,9 +4,16 @@ from pytest import raises from telethon._impl.mtproto.transport.intermediate import Intermediate +class Output(bytearray): + __slots__ = () + + def __call__(self, data: bytes) -> None: + self += data + + def setup_pack(n: int) -> Tuple[Intermediate, bytes, bytearray]: input = bytes(x & 0xFF for x in range(n)) - return Intermediate(), input, bytearray() + return Intermediate(), input, Output() def test_pack_empty() -> None: