Update transports to report read length

This commit is contained in:
Lonami Exo 2023-08-30 13:43:35 +02:00
parent 269ee4f05f
commit 2be75380a3
4 changed files with 24 additions and 16 deletions

View File

@ -7,5 +7,10 @@ class Transport(ABC):
pass
@abstractmethod
def unpack(self, input: bytes, output: bytearray) -> None:
def unpack(self, input: bytes, output: bytearray) -> int:
pass
class MissingBytes(ValueError):
def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"missing bytes, expected: {expected}, got: {got}")

View File

@ -1,6 +1,6 @@
import struct
from .abcs import Transport
from .abcs import MissingBytes, Transport
class Abridged(Transport):
@ -36,23 +36,22 @@ class Abridged(Transport):
output += struct.pack("<i", 0x7F | (length << 8))
output += input
def unpack(self, input: bytes, output: bytearray) -> None:
def unpack(self, input: bytes, output: bytearray) -> int:
if not input:
raise ValueError("missing bytes, expected: 1, got: 0")
raise MissingBytes(expected=1, got=0)
length = input[0]
if length < 127:
header_len = 1
elif len(input) < 4:
raise ValueError(f"missing bytes, expected: 4, got: {len(input)}")
raise MissingBytes(expected=4, got=len(input))
else:
header_len = 4
length = struct.unpack_from("<i", input)[0] >> 8
length *= 4
if len(input) < header_len + length:
raise ValueError(
f"missing bytes, expected: {header_len + length}, got: {len(input)}"
)
raise MissingBytes(expected=header_len + length, got=len(input))
output += memoryview(input)[header_len : header_len + length]
return header_len + length

View File

@ -1,7 +1,7 @@
import struct
from zlib import crc32
from .abcs import Transport
from .abcs import MissingBytes, Transport
class Full(Transport):
@ -33,16 +33,17 @@ class Full(Transport):
output += struct.pack("<i", crc32(memoryview(output)[-(length - 4) :]))
self._send_seq += 1
def unpack(self, input: bytes, output: bytearray) -> None:
def unpack(self, input: bytes, output: bytearray) -> int:
if len(input) < 4:
raise ValueError(f"missing bytes, expected: 4, got: {len(input)}")
raise MissingBytes(expected=4, got=len(input))
length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int)
if length < 12:
raise ValueError(f"bad length, expected > 12, got: {length}")
if len(input) < length:
raise ValueError(f"missing bytes, expected: {length}, got: {len(input)}")
raise MissingBytes(expected=length, got=len(input))
seq = struct.unpack_from("<i", input, 4)[0]
if seq != self._recv_seq:
@ -55,3 +56,4 @@ class Full(Transport):
self._recv_seq += 1
output += memoryview(input)[8:-4]
return length

View File

@ -1,6 +1,6 @@
import struct
from .abcs import Transport
from .abcs import MissingBytes, Transport
class Intermediate(Transport):
@ -32,12 +32,14 @@ class Intermediate(Transport):
output += struct.pack("<i", len(input))
output += input
def unpack(self, input: bytes, output: bytearray) -> None:
def unpack(self, input: bytes, output: bytearray) -> int:
if len(input) < 4:
raise ValueError(f"missing bytes, expected: {4}, got: {len(input)}")
raise MissingBytes(expected=4, got=len(input))
length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int)
if len(input) < length:
raise ValueError(f"missing bytes, expected: {length}, got: {len(input)}")
raise MissingBytes(expected=length, got=len(input))
output += memoryview(input)[4 : 4 + length]
return length + 4