mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-22 09:26:37 +03:00
Generate deserializers for requests
This commit is contained in:
parent
2e1321b6c9
commit
9ba6e2ded6
|
@ -1,5 +1,24 @@
|
||||||
from .reader import Reader
|
from .reader import (
|
||||||
|
Reader,
|
||||||
|
deserialize_bool,
|
||||||
|
deserialize_i32_list,
|
||||||
|
deserialize_i64_list,
|
||||||
|
deserialize_identity,
|
||||||
|
list_deserializer,
|
||||||
|
single_deserializer,
|
||||||
|
)
|
||||||
from .request import Request
|
from .request import Request
|
||||||
from .serializable import Serializable, serialize_bytes_to
|
from .serializable import Serializable, serialize_bytes_to
|
||||||
|
|
||||||
__all__ = ["Reader", "Request", "Serializable", "serialize_bytes_to"]
|
__all__ = [
|
||||||
|
"Reader",
|
||||||
|
"deserialize_bool",
|
||||||
|
"deserialize_i32_list",
|
||||||
|
"deserialize_i64_list",
|
||||||
|
"deserialize_identity",
|
||||||
|
"list_deserializer",
|
||||||
|
"single_deserializer",
|
||||||
|
"Request",
|
||||||
|
"Serializable",
|
||||||
|
"serialize_bytes_to",
|
||||||
|
]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
import functools
|
||||||
import struct
|
import struct
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
|
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type, TypeVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .serializable import Serializable
|
from .serializable import Serializable
|
||||||
|
@ -85,3 +86,47 @@ class Reader:
|
||||||
raise ValueError(f"No type found for constructor ID: {cid:x}")
|
raise ValueError(f"No type found for constructor ID: {cid:x}")
|
||||||
assert issubclass(ty, cls)
|
assert issubclass(ty, cls)
|
||||||
return ty._read_from(self)
|
return ty._read_from(self)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def single_deserializer(cls: Type[T]) -> Callable[[bytes], T]:
|
||||||
|
def deserializer(body: bytes) -> T:
|
||||||
|
return Reader(body).read_serializable(cls)
|
||||||
|
|
||||||
|
return deserializer
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def list_deserializer(cls: Type[T]) -> Callable[[bytes], List[T]]:
|
||||||
|
def deserializer(body: bytes) -> List[T]:
|
||||||
|
reader = Reader(body)
|
||||||
|
vec_id, length = reader.read_fmt("<ii", 8)
|
||||||
|
assert vec_id == 0x1CB5C415 and length >= 0
|
||||||
|
return [reader.read_serializable(cls) for _ in range(length)]
|
||||||
|
|
||||||
|
return deserializer
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_i64_list(body: bytes) -> List[int]:
|
||||||
|
reader = Reader(body)
|
||||||
|
vec_id, length = reader.read_fmt("<ii", 8)
|
||||||
|
assert vec_id == 0x1CB5C415 and length >= 0
|
||||||
|
return [*reader.read_fmt(f"<{length}q", length * 8)]
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_i32_list(body: bytes) -> List[int]:
|
||||||
|
reader = Reader(body)
|
||||||
|
vec_id, length = reader.read_fmt("<ii", 8)
|
||||||
|
assert vec_id == 0x1CB5C415 and length >= 0
|
||||||
|
return [*reader.read_fmt(f"<{length}i", length * 4)]
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_identity(body: bytes) -> bytes:
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_bool(body: bytes) -> bool:
|
||||||
|
reader = Reader(body)
|
||||||
|
bool_id = reader.read_fmt("<I", 4)[0]
|
||||||
|
assert isinstance(bool_id, int) and bool_id in (0x997275B5, 0xBC799737)
|
||||||
|
return bool_id == 0x997275B5
|
||||||
|
|
|
@ -1,9 +1,28 @@
|
||||||
import struct
|
import struct
|
||||||
from typing import Generic, TypeVar
|
from typing import Any, Callable, Generic, Optional, TypeVar
|
||||||
|
|
||||||
Return = TypeVar("Return")
|
Return = TypeVar("Return")
|
||||||
|
|
||||||
|
|
||||||
|
def _bootstrap_get_deserializer(
|
||||||
|
constructor_id: int,
|
||||||
|
) -> Optional[Callable[[bytes], Any]]:
|
||||||
|
# Similar to Reader's bootstrapping.
|
||||||
|
if Request._get_deserializer is _bootstrap_get_deserializer:
|
||||||
|
from ..layer import RESPONSE_MAPPING as API_DESER
|
||||||
|
from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER
|
||||||
|
|
||||||
|
if API_DESER.keys() & MTPROTO_DESER.keys():
|
||||||
|
raise RuntimeError(
|
||||||
|
"generated api and mtproto schemas cannot have colliding constructor identifiers"
|
||||||
|
)
|
||||||
|
ALL_DESER = API_DESER | MTPROTO_DESER
|
||||||
|
|
||||||
|
Request._get_deserializer = ALL_DESER.get # type: ignore [assignment]
|
||||||
|
|
||||||
|
return Request._get_deserializer(constructor_id)
|
||||||
|
|
||||||
|
|
||||||
class Request(bytes, Generic[Return]):
|
class Request(bytes, Generic[Return]):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
|
@ -16,5 +35,12 @@ class Request(bytes, Generic[Return]):
|
||||||
except struct.error:
|
except struct.error:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
_get_deserializer = staticmethod(_bootstrap_get_deserializer)
|
||||||
|
|
||||||
|
def deserialize_response(self, response: bytes) -> Return:
|
||||||
|
deserializer = self._get_deserializer(self.constructor_id)
|
||||||
|
assert deserializer is not None
|
||||||
|
return deserializer(response) # type: ignore [no-any-return]
|
||||||
|
|
||||||
def debug_name(self) -> str:
|
def debug_name(self) -> str:
|
||||||
return f"request#{self.constructor_id:x}"
|
return f"request#{self.constructor_id:x}"
|
||||||
|
|
|
@ -11,7 +11,11 @@ from .serde.common import (
|
||||||
to_class_name,
|
to_class_name,
|
||||||
to_method_name,
|
to_method_name,
|
||||||
)
|
)
|
||||||
from .serde.deserialization import generate_read, param_value_fmt
|
from .serde.deserialization import (
|
||||||
|
function_deserializer_fmt,
|
||||||
|
generate_read,
|
||||||
|
param_value_fmt,
|
||||||
|
)
|
||||||
from .serde.serialization import generate_function, generate_write
|
from .serde.serialization import generate_function, generate_write
|
||||||
|
|
||||||
|
|
||||||
|
@ -178,8 +182,10 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
writer = fs.open(Path("layer.py"))
|
writer = fs.open(Path("layer.py"))
|
||||||
writer.write(f"from . import types")
|
writer.write(f"from . import abcs, types")
|
||||||
writer.write(f"from .core import Serializable, Reader")
|
writer.write(
|
||||||
|
f"from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer"
|
||||||
|
)
|
||||||
writer.write(f"from typing import cast, Tuple, Type")
|
writer.write(f"from typing import cast, Tuple, Type")
|
||||||
writer.write(f"LAYER = {tl.layer!r}")
|
writer.write(f"LAYER = {tl.layer!r}")
|
||||||
writer.write(
|
writer.write(
|
||||||
|
@ -188,4 +194,10 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
for name in sorted(generated_type_names):
|
for name in sorted(generated_type_names):
|
||||||
writer.write(f" types.{name},")
|
writer.write(f" types.{name},")
|
||||||
writer.write("))}")
|
writer.write("))}")
|
||||||
writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING']")
|
writer.write("RESPONSE_MAPPING = {")
|
||||||
|
for functiondef in tl.functiondefs:
|
||||||
|
writer.write(
|
||||||
|
f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},"
|
||||||
|
)
|
||||||
|
writer.write("}")
|
||||||
|
writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']")
|
||||||
|
|
|
@ -122,3 +122,43 @@ def param_value_fmt(param: Parameter) -> str:
|
||||||
return f"_{param.name} == 0x997275b5"
|
return f"_{param.name} == 0x997275b5"
|
||||||
else:
|
else:
|
||||||
return f"_{param.name}"
|
return f"_{param.name}"
|
||||||
|
|
||||||
|
|
||||||
|
def function_deserializer_fmt(defn: Definition) -> str:
|
||||||
|
if defn.ty.generic_arg:
|
||||||
|
if defn.ty.name != ("Vector"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"generic_arg return for non-boxed-vectors not implemented"
|
||||||
|
)
|
||||||
|
elif defn.ty.generic_ref:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"return for generic refs inside vector not implemented"
|
||||||
|
)
|
||||||
|
elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)):
|
||||||
|
if defn.ty.generic_arg.name == "int":
|
||||||
|
return "deserialize_i32_list"
|
||||||
|
elif defn.ty.generic_arg.name == "long":
|
||||||
|
return "deserialize_i64_list"
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"return for trivial arg {defn.ty.generic_arg} not implemented"
|
||||||
|
)
|
||||||
|
elif defn.ty.generic_arg.bare:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"return for non-boxed serializables inside a vector not implemented"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})"
|
||||||
|
elif defn.ty.generic_ref:
|
||||||
|
return "deserialize_identity"
|
||||||
|
elif is_trivial(NormalParameter(ty=defn.ty, flag=None)):
|
||||||
|
if defn.ty.name == "Bool":
|
||||||
|
return "deserialize_bool"
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"return for trivial arg {defn.ty} not implemented"
|
||||||
|
)
|
||||||
|
elif defn.ty.bare:
|
||||||
|
raise NotImplementedError("return for non-boxed serializables not implemented")
|
||||||
|
else:
|
||||||
|
return f"single_deserializer({inner_type_fmt(defn.ty)})"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user