From e12845c38bd5a79d857886240051bd35d19b97f0 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 30 Aug 2023 13:44:42 +0200 Subject: [PATCH] Fix deserialization of bools --- .../telethon_generator/_impl/codegen/generator.py | 4 ++-- .../_impl/codegen/serde/deserialization.py | 10 ++++++++++ generator/tests/generator_test.py | 12 ++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/generator/src/telethon_generator/_impl/codegen/generator.py b/generator/src/telethon_generator/_impl/codegen/generator.py index b68b488e..84d2cf8e 100644 --- a/generator/src/telethon_generator/_impl/codegen/generator.py +++ b/generator/src/telethon_generator/_impl/codegen/generator.py @@ -10,7 +10,7 @@ from .serde.common import ( to_class_name, to_method_name, ) -from .serde.deserialization import generate_read +from .serde.deserialization import generate_read, param_value_fmt from .serde.serialization import generate_function, generate_write @@ -125,7 +125,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: writer.write(f" def _read_from(cls, reader: Reader) -> Self:") writer.indent(2) generate_read(writer, typedef) - params = ", ".join(f"{p.name}=_{p.name}" for p in property_params) + params = ", ".join(f"{p.name}={param_value_fmt(p)}" for p in property_params) writer.write(f"return cls({params})") writer.dedent(2) diff --git a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py index 8be53a0b..3c0c7b1d 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py @@ -107,8 +107,18 @@ def generate_read(writer: SourceWriter, defn: Definition) -> None: fmt = "".join(trivial_struct_fmt(param.ty) for param in group) size = struct.calcsize(f"<{fmt}") writer.write(f"{names}= reader.read_fmt('<{fmt}', {size})") + for param in group: + if isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool": + writer.write(f"assert _{param.name} in (0x997275b5, 0xbc799737)") else: for param in iter: if not isinstance(param.ty, NormalParameter): raise RuntimeError("FlagsParameter should be considered trivial") generate_normal_param_read(writer, param.name, param.ty, defn.id) + + +def param_value_fmt(param: Parameter) -> str: + if isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool": + return f"_{param.name} == 0x997275b5" + else: + return f"_{param.name}" diff --git a/generator/tests/generator_test.py b/generator/tests/generator_test.py index b4dbd32c..680244ac 100644 --- a/generator/tests/generator_test.py +++ b/generator/tests/generator_test.py @@ -100,3 +100,15 @@ def test_object_blob_with_prefix_special_case() -> None: ) result = gen_py_code(typedefs=definitions) assert "reader.read(_bytes)" in result + + +def test_bool_mapped_from_int() -> None: + definitions = get_definitions( + """ + contact#145ade0b user_id:long mutual:Bool = Contact; + """ + ) + result = gen_py_code(typedefs=definitions) + assert "_mutual in (0x997275b5, 0xbc799737)" in result + assert "=_mutual == 0x997275b5" in result + assert "0x997275b5 if self.mutual else 0xbc799737" in result