Add sanitize_name to handle reserved keywords in generated code

This commit is contained in:
Alice Kits 2025-10-17 11:27:08 +05:00
parent bc7a02ca7f
commit fcf8a7d021
No known key found for this signature in database
GPG Key ID: 8D5EE63C035D3D3A
5 changed files with 76 additions and 44 deletions

View File

@ -7,7 +7,7 @@ from .serde.common import (
is_computed,
param_type_fmt,
to_class_name,
to_method_name,
to_method_name, sanitize_name,
)
from .serde.deserialization import (
function_deserializer_fmt,
@ -112,7 +112,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
)
# __slots__ = ('params', ...)
slots = " ".join(f"'{p.name}'," for p in property_params)
slots = " ".join(f"'{sanitize_name(p.name)}'," for p in property_params)
writer.write(f" __slots__ = ({slots})")
# def constructor_id()
@ -123,18 +123,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def __init__()
if property_params:
params = "".join(
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
f", {sanitize_name(p.name)}: {param_type_fmt(p.ty)}" for p in property_params
)
writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params:
writer.write(f" _s.{p.name} = {p.name}")
writer.write(f" _s.{sanitize_name(p.name)} = {sanitize_name(p.name)}")
# def _read_from()
writer.write(" @classmethod")
writer.write(" def _read_from(cls, reader: Reader) -> Self:")
writer.indent(2)
generate_read(writer, typedef)
params = ", ".join(f"{p.name}={param_value_fmt(p)}" for p in property_params)
params = ", ".join(f"{sanitize_name(p.name)}={param_value_fmt(p)}" for p in property_params)
writer.write(f"return cls({params})")
writer.dedent(2)
@ -172,7 +172,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def name(params, ...)
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
params = "".join(f", {sanitize_name(p.name)}: {param_type_fmt(p.ty)}" for p in required_params)
star = "*" if params else ""
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
writer.write(

View File

@ -1,3 +1,4 @@
import keyword
import re
from collections.abc import Iterator
@ -110,3 +111,9 @@ def param_type_fmt(ty: BaseParameter) -> str:
res = f"Optional[{res}]"
return res
def sanitize_name(name: str) -> str:
if keyword.iskeyword(name):
name += "_"
return name

View File

@ -4,42 +4,42 @@ from itertools import groupby
from ....tl_parser import Definition, FlagsParameter, NormalParameter, Parameter, Type
from ..fakefs import SourceWriter
from .common import gen_tmp_names, is_computed, is_trivial, trivial_struct_fmt
from .common import gen_tmp_names, is_computed, is_trivial, sanitize_name, trivial_struct_fmt
def param_value_expr(param: Parameter) -> str:
is_bool = isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool"
pre = "0x997275b5 if " if is_bool else ""
mid = f"_{param.name}" if is_computed(param.ty) else f"self.{param.name}"
mid = f"_{param.name}" if is_computed(param.ty) else f"self.{sanitize_name(param.name)}"
suf = " else 0xbc799737" if is_bool else ""
return f"{pre}{mid}{suf}"
def generate_buffer_append(
writer: SourceWriter, buffer: str, name: str, ty: Type
) -> None:
def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None:
sanitized_name = sanitize_name(name)
if is_trivial(NormalParameter(ty=ty, flag=None)):
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
if ty.name == "Bool":
writer.write(
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))"
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {sanitized_name} else 0xbc799737))"
)
else:
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {sanitized_name})")
elif ty.generic_ref or ty.name == "Object":
writer.write(f"{buffer} += {name}") # assume previously-serialized
writer.write(f"{buffer} += {sanitized_name}") # assume previously-serialized
elif ty.name == "string":
writer.write(f"serialize_bytes_to({buffer}, {name}.encode('utf-8'))")
writer.write(f"serialize_bytes_to({buffer}, {sanitized_name}.encode('utf-8'))")
elif ty.name == "bytes":
writer.write(f"serialize_bytes_to({buffer}, {name})")
writer.write(f"serialize_bytes_to({buffer}, {sanitized_name})")
elif ty.name == "int128":
writer.write(f"{buffer} += {name}.to_bytes(16)")
writer.write(f"{buffer} += {sanitized_name}.to_bytes(16)")
elif ty.name == "int256":
writer.write(f"{buffer} += {name}.to_bytes(32)")
writer.write(f"{buffer} += {sanitized_name}.to_bytes(32)")
elif ty.bare:
writer.write(f"{name}._write_to({buffer})")
writer.write(f"{sanitized_name}._write_to({buffer})")
else:
writer.write(f"{name}._write_boxed_to({buffer})")
writer.write(f"{sanitized_name}._write_boxed_to({buffer})")
def generate_normal_param_write(
@ -52,20 +52,20 @@ def generate_normal_param_write(
if param.ty.name == "true":
return # special-cased "built-in"
sanitized_name = sanitize_name(name)
if param.flag:
writer.write(f"if {name} is not None:")
writer.write(f"if {sanitized_name} is not None:")
writer.indent()
if param.ty.generic_arg:
if param.ty.name not in ("Vector", "vector"):
raise ValueError(
"generic_arg deserialization for non-vectors is not supported"
)
raise ValueError("generic_arg deserialization for non-vectors is not supported")
if param.ty.bare:
writer.write(f"{buffer} += struct.pack('<i', len({name}))")
writer.write(f"{buffer} += struct.pack('<i', len({sanitized_name}))")
else:
writer.write(f"{buffer} += struct.pack('<ii', 0x1cb5c415, len({name}))")
writer.write(f"{buffer} += struct.pack('<ii', 0x1cb5c415, len({sanitized_name}))")
generic = NormalParameter(ty=param.ty.generic_arg, flag=None)
if is_trivial(generic):
@ -73,15 +73,15 @@ def generate_normal_param_write(
if param.ty.generic_arg.name == "Bool":
tmp = next(tmp_names)
writer.write(
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))"
f"{buffer} += struct.pack(f'<{{len({sanitized_name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {sanitized_name}))"
)
else:
writer.write(
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})"
f"{buffer} += struct.pack(f'<{{len({sanitized_name})}}{fmt}', *{sanitized_name})"
)
else:
tmp = next(tmp_names)
writer.write(f"for {tmp} in {name}:")
writer.write(f"for {tmp} in {sanitized_name}:")
writer.indent()
generate_buffer_append(writer, buffer, tmp, param.ty.generic_arg)
writer.dedent()
@ -105,9 +105,9 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
if isinstance(param.ty, FlagsParameter):
flags = " | ".join(
(
f"({1 << p.ty.flag.index} if self.{p.name} else 0)"
f"({1 << p.ty.flag.index} if self.{sanitize_name(p.name)} else 0)"
if p.ty.ty.name == "true"
else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})"
else f"(0 if self.{sanitize_name(p.name)} is None else {1 << p.ty.flag.index})"
)
for p in defn.params
if isinstance(p.ty, NormalParameter)
@ -124,7 +124,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write(
writer, tmp_names, "buffer", f"self.{param.name}", param.ty
writer, tmp_names, "buffer", f"self.{sanitize_name(param.name)}", param.ty
)
@ -143,9 +143,9 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
if isinstance(param.ty, FlagsParameter):
flags = " | ".join(
(
f"({1 << p.ty.flag.index} if {p.name} else 0)"
f"({1 << p.ty.flag.index} if {sanitize_name(p.name)} else 0)"
if p.ty.ty.name == "true"
else f"(0 if {p.name} is None else {1 << p.ty.flag.index})"
else f"(0 if {sanitize_name(p.name)} is None else {1 << p.ty.flag.index})"
)
for p in defn.params
if isinstance(p.ty, NormalParameter)
@ -154,14 +154,12 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
)
writer.write(f"{param.name} = {flags or 0}")
names = ", ".join(p.name for p in group)
names = ", ".join(sanitize_name(p.name) for p in group)
fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
writer.write(f"_buffer += struct.pack('<{fmt}', {names})")
else:
for param in iter:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write(
writer, tmp_names, "_buffer", param.name, param.ty
)
generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty)
writer.write("return Request(b'' + _buffer)")

View File

@ -1,6 +1,7 @@
from pytest import mark
from telethon_generator._impl.codegen.serde.common import (
sanitize_name,
split_words,
to_class_name,
to_method_name,
@ -50,3 +51,21 @@ def test_to_class_name(name: str, expected: str) -> None:
)
def test_to_method_name(name: str, expected: str) -> None:
assert to_method_name(name) == expected
@mark.parametrize(
("name", "expected"),
[
# Shouldn't be changed
# - not keywords
("abc", "abc"),
# - soft keywords (https://docs.python.org/3/reference/lexical_analysis.html#soft-keywords)
("type", "type"),
# Must be changed
# - keywords
("from", "from_"),
("return", "return_"),
],
)
def test_sanitize_name(name: str, expected: str) -> None:
assert sanitize_name(name) == expected

View File

@ -14,9 +14,7 @@ def gen_py_code(
functiondefs: Optional[list[Definition]] = None,
) -> str:
fs = FakeFs()
generate(
fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])
)
generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []))
generated = bytearray()
for path, data in fs._files.items():
if path.stem not in ("__init__", "layer"):
@ -27,9 +25,7 @@ def gen_py_code(
def test_generic_functions_use_bytes_parameters() -> None:
definitions = get_definitions(
"invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;"
)
definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;")
result = gen_py_code(functiondefs=definitions)
assert "invoke_with_layer" in result
assert "query: _bytes" in result
@ -112,3 +108,15 @@ def test_bool_mapped_from_int() -> None:
assert "_mutual in (0x997275b5, 0xbc799737)" in result
assert "=_mutual == 0x997275b5" in result
assert "0x997275b5 if self.mutual else 0xbc799737" in result
def test_sanitize_keywords() -> None:
definitions = get_definitions(
"""
forwardedMessage#deadbeef from:long to:long return:Bool = Message;
"""
)
result = gen_py_code(typedefs=definitions)
assert "from_" in result
assert "to_" not in result and "to" in result
assert "return_" in result