Fix deserialization of bools

This commit is contained in:
Lonami Exo 2023-08-30 13:44:42 +02:00
parent 2be75380a3
commit e12845c38b
3 changed files with 24 additions and 2 deletions

View File

@ -10,7 +10,7 @@ from .serde.common import (
to_class_name, to_class_name,
to_method_name, to_method_name,
) )
from .serde.deserialization import generate_read from .serde.deserialization import generate_read, param_value_fmt
from .serde.serialization import generate_function, generate_write from .serde.serialization import generate_function, generate_write
@ -125,7 +125,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer.write(f" def _read_from(cls, reader: Reader) -> Self:") writer.write(f" def _read_from(cls, reader: Reader) -> Self:")
writer.indent(2) writer.indent(2)
generate_read(writer, typedef) generate_read(writer, typedef)
params = ", ".join(f"{p.name}=_{p.name}" for p in property_params) params = ", ".join(f"{p.name}={param_value_fmt(p)}" for p in property_params)
writer.write(f"return cls({params})") writer.write(f"return cls({params})")
writer.dedent(2) writer.dedent(2)

View File

@ -107,8 +107,18 @@ def generate_read(writer: SourceWriter, defn: Definition) -> None:
fmt = "".join(trivial_struct_fmt(param.ty) for param in group) fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
size = struct.calcsize(f"<{fmt}") size = struct.calcsize(f"<{fmt}")
writer.write(f"{names}= reader.read_fmt('<{fmt}', {size})") writer.write(f"{names}= reader.read_fmt('<{fmt}', {size})")
for param in group:
if isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool":
writer.write(f"assert _{param.name} in (0x997275b5, 0xbc799737)")
else: else:
for param in iter: for param in iter:
if not isinstance(param.ty, NormalParameter): if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial") raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_read(writer, param.name, param.ty, defn.id) generate_normal_param_read(writer, param.name, param.ty, defn.id)
def param_value_fmt(param: Parameter) -> str:
if isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool":
return f"_{param.name} == 0x997275b5"
else:
return f"_{param.name}"

View File

@ -100,3 +100,15 @@ def test_object_blob_with_prefix_special_case() -> None:
) )
result = gen_py_code(typedefs=definitions) result = gen_py_code(typedefs=definitions)
assert "reader.read(_bytes)" in result assert "reader.read(_bytes)" in result
def test_bool_mapped_from_int() -> None:
definitions = get_definitions(
"""
contact#145ade0b user_id:long mutual:Bool = Contact;
"""
)
result = gen_py_code(typedefs=definitions)
assert "_mutual in (0x997275b5, 0xbc799737)" in result
assert "=_mutual == 0x997275b5" in result
assert "0x997275b5 if self.mutual else 0xbc799737" in result