Change transports to pack into a write fn

This commit is contained in:
Lonami Exo 2023-08-30 16:37:04 +02:00
parent e12845c38b
commit d5e6dbe36b
7 changed files with 49 additions and 20 deletions

View File

@ -1,9 +1,15 @@
from abc import ABC, abstractmethod
from typing import Callable
OutFn = Callable[[bytes | bytearray | memoryview], None]
class Transport(ABC):
# Python's stream writer has a synchronous write (buffer append) followed
# by drain. The buffer is externally managed, so `write` is used as input.
@abstractmethod
def pack(self, input: bytes, output: bytearray) -> None:
def pack(self, input: bytes, write: OutFn) -> None:
pass
@abstractmethod

View File

@ -1,6 +1,6 @@
import struct
from .abcs import MissingBytes, Transport
from .abcs import MissingBytes, OutFn, Transport
class Abridged(Transport):
@ -22,19 +22,19 @@ class Abridged(Transport):
def __init__(self) -> None:
self._init = False
def pack(self, input: bytes, output: bytearray) -> None:
def pack(self, input: bytes, write: OutFn) -> None:
assert len(input) % 4 == 0
if not self._init:
output += b"\xef"
write(b"\xef")
self._init = True
length = len(input) // 4
if length < 127:
output += struct.pack("<b", length)
write(struct.pack("<b", length))
else:
output += struct.pack("<i", 0x7F | (length << 8))
output += input
write(struct.pack("<i", 0x7F | (length << 8)))
write(input)
def unpack(self, input: bytes, output: bytearray) -> int:
if not input:

View File

@ -1,7 +1,7 @@
import struct
from zlib import crc32
from .abcs import MissingBytes, Transport
from .abcs import MissingBytes, OutFn, Transport
class Full(Transport):
@ -24,13 +24,15 @@ class Full(Transport):
self._send_seq = 0
self._recv_seq = 0
def pack(self, input: bytes, output: bytearray) -> None:
def pack(self, input: bytes, write: OutFn) -> None:
assert len(input) % 4 == 0
length = len(input) + 12
output += struct.pack("<ii", length, self._send_seq)
output += input
output += struct.pack("<i", crc32(memoryview(output)[-(length - 4) :]))
# Unfortunately there's no hasher that can be updated multiple times,
# so a temporary buffer must be used to hash it all in one go.
tmp = struct.pack("<ii", length, self._send_seq) + input
write(tmp)
write(struct.pack("<I", crc32(tmp)))
self._send_seq += 1
def unpack(self, input: bytes, output: bytearray) -> int:

View File

@ -1,6 +1,6 @@
import struct
from .abcs import MissingBytes, Transport
from .abcs import MissingBytes, OutFn, Transport
class Intermediate(Transport):
@ -22,15 +22,15 @@ class Intermediate(Transport):
def __init__(self) -> None:
self._init = False
def pack(self, input: bytes, output: bytearray) -> None:
def pack(self, input: bytes, write: OutFn) -> None:
assert len(input) % 4 == 0
if not self._init:
output += b"\xee\xee\xee\xee"
write(b"\xee\xee\xee\xee")
self._init = True
output += struct.pack("<i", len(input))
output += input
write(struct.pack("<i", len(input)))
write(input)
def unpack(self, input: bytes, output: bytearray) -> int:
if len(input) < 4:

View File

@ -4,9 +4,16 @@ from pytest import raises
from telethon._impl.mtproto.transport.abridged import Abridged
class Output(bytearray):
__slots__ = ()
def __call__(self, data: bytes) -> None:
self += data
def setup_pack(n: int) -> Tuple[Abridged, bytes, bytearray]:
input = bytes(x & 0xFF for x in range(n))
return Abridged(), input, bytearray()
return Abridged(), input, Output()
def test_pack_empty() -> None:

View File

@ -4,9 +4,16 @@ from pytest import raises
from telethon._impl.mtproto.transport.full import Full
class Output(bytearray):
__slots__ = ()
def __call__(self, data: bytes) -> None:
self += data
def setup_pack(n: int) -> Tuple[Full, bytes, bytearray]:
input = bytes(x & 0xFF for x in range(n))
return Full(), input, bytearray()
return Full(), input, Output()
def setup_unpack(n: int) -> Tuple[bytes, Full, bytes, bytearray]:

View File

@ -4,9 +4,16 @@ from pytest import raises
from telethon._impl.mtproto.transport.intermediate import Intermediate
class Output(bytearray):
__slots__ = ()
def __call__(self, data: bytes) -> None:
self += data
def setup_pack(n: int) -> Tuple[Intermediate, bytes, bytearray]:
input = bytes(x & 0xFF for x in range(n))
return Intermediate(), input, bytearray()
return Intermediate(), input, Output()
def test_pack_empty() -> None: