From fcf8a7d02190d4f7976b60d255dd066310df358b Mon Sep 17 00:00:00 2001 From: Alice Kits Date: Fri, 17 Oct 2025 11:27:08 +0500 Subject: [PATCH] Add `sanitize_name` to handle reserved keywords in generated code --- .../_impl/codegen/generator.py | 12 ++-- .../_impl/codegen/serde/common.py | 7 +++ .../_impl/codegen/serde/serialization.py | 62 +++++++++---------- generator/tests/common_test.py | 19 ++++++ generator/tests/generator_test.py | 20 ++++-- 5 files changed, 76 insertions(+), 44 deletions(-) diff --git a/generator/src/telethon_generator/_impl/codegen/generator.py b/generator/src/telethon_generator/_impl/codegen/generator.py index 713318a9..ed8d49d3 100644 --- a/generator/src/telethon_generator/_impl/codegen/generator.py +++ b/generator/src/telethon_generator/_impl/codegen/generator.py @@ -7,7 +7,7 @@ from .serde.common import ( is_computed, param_type_fmt, to_class_name, - to_method_name, + to_method_name, sanitize_name, ) from .serde.deserialization import ( function_deserializer_fmt, @@ -112,7 +112,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: ) # __slots__ = ('params', ...) - slots = " ".join(f"'{p.name}'," for p in property_params) + slots = " ".join(f"'{sanitize_name(p.name)}'," for p in property_params) writer.write(f" __slots__ = ({slots})") # def constructor_id() @@ -123,18 +123,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: # def __init__() if property_params: params = "".join( - f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params + f", {sanitize_name(p.name)}: {param_type_fmt(p.ty)}" for p in property_params ) writer.write(f" def __init__(_s, *{params}) -> None:") for p in property_params: - writer.write(f" _s.{p.name} = {p.name}") + writer.write(f" _s.{sanitize_name(p.name)} = {sanitize_name(p.name)}") # def _read_from() writer.write(" @classmethod") writer.write(" def _read_from(cls, reader: Reader) -> Self:") writer.indent(2) generate_read(writer, typedef) - params = ", ".join(f"{p.name}={param_value_fmt(p)}" for p in property_params) + params = ", ".join(f"{sanitize_name(p.name)}={param_value_fmt(p)}" for p in property_params) writer.write(f"return cls({params})") writer.dedent(2) @@ -172,7 +172,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: # def name(params, ...) required_params = [p for p in functiondef.params if not is_computed(p.ty)] - params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params) + params = "".join(f", {sanitize_name(p.name)}: {param_type_fmt(p.ty)}" for p in required_params) star = "*" if params else "" return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None)) writer.write( diff --git a/generator/src/telethon_generator/_impl/codegen/serde/common.py b/generator/src/telethon_generator/_impl/codegen/serde/common.py index 0269af53..8e71764d 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/common.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/common.py @@ -1,3 +1,4 @@ +import keyword import re from collections.abc import Iterator @@ -110,3 +111,9 @@ def param_type_fmt(ty: BaseParameter) -> str: res = f"Optional[{res}]" return res + + +def sanitize_name(name: str) -> str: + if keyword.iskeyword(name): + name += "_" + return name diff --git a/generator/src/telethon_generator/_impl/codegen/serde/serialization.py b/generator/src/telethon_generator/_impl/codegen/serde/serialization.py index 792893c9..361f96a3 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/serialization.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/serialization.py @@ -4,42 +4,42 @@ from itertools import groupby from ....tl_parser import Definition, FlagsParameter, NormalParameter, Parameter, Type from ..fakefs import SourceWriter -from .common import gen_tmp_names, is_computed, is_trivial, trivial_struct_fmt +from .common import gen_tmp_names, is_computed, is_trivial, sanitize_name, trivial_struct_fmt def param_value_expr(param: Parameter) -> str: is_bool = isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool" pre = "0x997275b5 if " if is_bool else "" - mid = f"_{param.name}" if is_computed(param.ty) else f"self.{param.name}" + mid = f"_{param.name}" if is_computed(param.ty) else f"self.{sanitize_name(param.name)}" suf = " else 0xbc799737" if is_bool else "" return f"{pre}{mid}{suf}" -def generate_buffer_append( - writer: SourceWriter, buffer: str, name: str, ty: Type -) -> None: +def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None: + sanitized_name = sanitize_name(name) + if is_trivial(NormalParameter(ty=ty, flag=None)): fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None)) if ty.name == "Bool": writer.write( - f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))" + f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {sanitized_name} else 0xbc799737))" ) else: - writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})") + writer.write(f"{buffer} += struct.pack(f'<{fmt}', {sanitized_name})") elif ty.generic_ref or ty.name == "Object": - writer.write(f"{buffer} += {name}") # assume previously-serialized + writer.write(f"{buffer} += {sanitized_name}") # assume previously-serialized elif ty.name == "string": - writer.write(f"serialize_bytes_to({buffer}, {name}.encode('utf-8'))") + writer.write(f"serialize_bytes_to({buffer}, {sanitized_name}.encode('utf-8'))") elif ty.name == "bytes": - writer.write(f"serialize_bytes_to({buffer}, {name})") + writer.write(f"serialize_bytes_to({buffer}, {sanitized_name})") elif ty.name == "int128": - writer.write(f"{buffer} += {name}.to_bytes(16)") + writer.write(f"{buffer} += {sanitized_name}.to_bytes(16)") elif ty.name == "int256": - writer.write(f"{buffer} += {name}.to_bytes(32)") + writer.write(f"{buffer} += {sanitized_name}.to_bytes(32)") elif ty.bare: - writer.write(f"{name}._write_to({buffer})") + writer.write(f"{sanitized_name}._write_to({buffer})") else: - writer.write(f"{name}._write_boxed_to({buffer})") + writer.write(f"{sanitized_name}._write_boxed_to({buffer})") def generate_normal_param_write( @@ -52,20 +52,20 @@ def generate_normal_param_write( if param.ty.name == "true": return # special-cased "built-in" + sanitized_name = sanitize_name(name) + if param.flag: - writer.write(f"if {name} is not None:") + writer.write(f"if {sanitized_name} is not None:") writer.indent() if param.ty.generic_arg: if param.ty.name not in ("Vector", "vector"): - raise ValueError( - "generic_arg deserialization for non-vectors is not supported" - ) + raise ValueError("generic_arg deserialization for non-vectors is not supported") if param.ty.bare: - writer.write(f"{buffer} += struct.pack(' None: if isinstance(param.ty, FlagsParameter): flags = " | ".join( ( - f"({1 << p.ty.flag.index} if self.{p.name} else 0)" + f"({1 << p.ty.flag.index} if self.{sanitize_name(p.name)} else 0)" if p.ty.ty.name == "true" - else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})" + else f"(0 if self.{sanitize_name(p.name)} is None else {1 << p.ty.flag.index})" ) for p in defn.params if isinstance(p.ty, NormalParameter) @@ -124,7 +124,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None: if not isinstance(param.ty, NormalParameter): raise RuntimeError("FlagsParameter should be considered trivial") generate_normal_param_write( - writer, tmp_names, "buffer", f"self.{param.name}", param.ty + writer, tmp_names, "buffer", f"self.{sanitize_name(param.name)}", param.ty ) @@ -143,9 +143,9 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None: if isinstance(param.ty, FlagsParameter): flags = " | ".join( ( - f"({1 << p.ty.flag.index} if {p.name} else 0)" + f"({1 << p.ty.flag.index} if {sanitize_name(p.name)} else 0)" if p.ty.ty.name == "true" - else f"(0 if {p.name} is None else {1 << p.ty.flag.index})" + else f"(0 if {sanitize_name(p.name)} is None else {1 << p.ty.flag.index})" ) for p in defn.params if isinstance(p.ty, NormalParameter) @@ -154,14 +154,12 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None: ) writer.write(f"{param.name} = {flags or 0}") - names = ", ".join(p.name for p in group) + names = ", ".join(sanitize_name(p.name) for p in group) fmt = "".join(trivial_struct_fmt(param.ty) for param in group) writer.write(f"_buffer += struct.pack('<{fmt}', {names})") else: for param in iter: if not isinstance(param.ty, NormalParameter): raise RuntimeError("FlagsParameter should be considered trivial") - generate_normal_param_write( - writer, tmp_names, "_buffer", param.name, param.ty - ) + generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty) writer.write("return Request(b'' + _buffer)") diff --git a/generator/tests/common_test.py b/generator/tests/common_test.py index 616bf073..84a9bd1c 100644 --- a/generator/tests/common_test.py +++ b/generator/tests/common_test.py @@ -1,6 +1,7 @@ from pytest import mark from telethon_generator._impl.codegen.serde.common import ( + sanitize_name, split_words, to_class_name, to_method_name, @@ -50,3 +51,21 @@ def test_to_class_name(name: str, expected: str) -> None: ) def test_to_method_name(name: str, expected: str) -> None: assert to_method_name(name) == expected + + +@mark.parametrize( + ("name", "expected"), + [ + # Shouldn't be changed + # - not keywords + ("abc", "abc"), + # - soft keywords (https://docs.python.org/3/reference/lexical_analysis.html#soft-keywords) + ("type", "type"), + # Must be changed + # - keywords + ("from", "from_"), + ("return", "return_"), + ], +) +def test_sanitize_name(name: str, expected: str) -> None: + assert sanitize_name(name) == expected diff --git a/generator/tests/generator_test.py b/generator/tests/generator_test.py index 949a3179..eef606c0 100644 --- a/generator/tests/generator_test.py +++ b/generator/tests/generator_test.py @@ -14,9 +14,7 @@ def gen_py_code( functiondefs: Optional[list[Definition]] = None, ) -> str: fs = FakeFs() - generate( - fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []) - ) + generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])) generated = bytearray() for path, data in fs._files.items(): if path.stem not in ("__init__", "layer"): @@ -27,9 +25,7 @@ def gen_py_code( def test_generic_functions_use_bytes_parameters() -> None: - definitions = get_definitions( - "invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;" - ) + definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;") result = gen_py_code(functiondefs=definitions) assert "invoke_with_layer" in result assert "query: _bytes" in result @@ -112,3 +108,15 @@ def test_bool_mapped_from_int() -> None: assert "_mutual in (0x997275b5, 0xbc799737)" in result assert "=_mutual == 0x997275b5" in result assert "0x997275b5 if self.mutual else 0xbc799737" in result + + +def test_sanitize_keywords() -> None: + definitions = get_definitions( + """ + forwardedMessage#deadbeef from:long to:long return:Bool = Message; + """ + ) + result = gen_py_code(typedefs=definitions) + assert "from_" in result + assert "to_" not in result and "to" in result + assert "return_" in result