mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2026-01-04 16:02:48 +03:00
Add sanitize_name to handle reserved keywords in generated code
This commit is contained in:
parent
bc7a02ca7f
commit
fcf8a7d021
|
|
@ -7,7 +7,7 @@ from .serde.common import (
|
|||
is_computed,
|
||||
param_type_fmt,
|
||||
to_class_name,
|
||||
to_method_name,
|
||||
to_method_name, sanitize_name,
|
||||
)
|
||||
from .serde.deserialization import (
|
||||
function_deserializer_fmt,
|
||||
|
|
@ -112,7 +112,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
)
|
||||
|
||||
# __slots__ = ('params', ...)
|
||||
slots = " ".join(f"'{p.name}'," for p in property_params)
|
||||
slots = " ".join(f"'{sanitize_name(p.name)}'," for p in property_params)
|
||||
writer.write(f" __slots__ = ({slots})")
|
||||
|
||||
# def constructor_id()
|
||||
|
|
@ -123,18 +123,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
# def __init__()
|
||||
if property_params:
|
||||
params = "".join(
|
||||
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
|
||||
f", {sanitize_name(p.name)}: {param_type_fmt(p.ty)}" for p in property_params
|
||||
)
|
||||
writer.write(f" def __init__(_s, *{params}) -> None:")
|
||||
for p in property_params:
|
||||
writer.write(f" _s.{p.name} = {p.name}")
|
||||
writer.write(f" _s.{sanitize_name(p.name)} = {sanitize_name(p.name)}")
|
||||
|
||||
# def _read_from()
|
||||
writer.write(" @classmethod")
|
||||
writer.write(" def _read_from(cls, reader: Reader) -> Self:")
|
||||
writer.indent(2)
|
||||
generate_read(writer, typedef)
|
||||
params = ", ".join(f"{p.name}={param_value_fmt(p)}" for p in property_params)
|
||||
params = ", ".join(f"{sanitize_name(p.name)}={param_value_fmt(p)}" for p in property_params)
|
||||
writer.write(f"return cls({params})")
|
||||
writer.dedent(2)
|
||||
|
||||
|
|
@ -172,7 +172,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
|
||||
# def name(params, ...)
|
||||
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
|
||||
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
|
||||
params = "".join(f", {sanitize_name(p.name)}: {param_type_fmt(p.ty)}" for p in required_params)
|
||||
star = "*" if params else ""
|
||||
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
|
||||
writer.write(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import keyword
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
|
@ -110,3 +111,9 @@ def param_type_fmt(ty: BaseParameter) -> str:
|
|||
res = f"Optional[{res}]"
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def sanitize_name(name: str) -> str:
|
||||
if keyword.iskeyword(name):
|
||||
name += "_"
|
||||
return name
|
||||
|
|
|
|||
|
|
@ -4,42 +4,42 @@ from itertools import groupby
|
|||
|
||||
from ....tl_parser import Definition, FlagsParameter, NormalParameter, Parameter, Type
|
||||
from ..fakefs import SourceWriter
|
||||
from .common import gen_tmp_names, is_computed, is_trivial, trivial_struct_fmt
|
||||
from .common import gen_tmp_names, is_computed, is_trivial, sanitize_name, trivial_struct_fmt
|
||||
|
||||
|
||||
def param_value_expr(param: Parameter) -> str:
|
||||
is_bool = isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool"
|
||||
pre = "0x997275b5 if " if is_bool else ""
|
||||
mid = f"_{param.name}" if is_computed(param.ty) else f"self.{param.name}"
|
||||
mid = f"_{param.name}" if is_computed(param.ty) else f"self.{sanitize_name(param.name)}"
|
||||
suf = " else 0xbc799737" if is_bool else ""
|
||||
return f"{pre}{mid}{suf}"
|
||||
|
||||
|
||||
def generate_buffer_append(
|
||||
writer: SourceWriter, buffer: str, name: str, ty: Type
|
||||
) -> None:
|
||||
def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None:
|
||||
sanitized_name = sanitize_name(name)
|
||||
|
||||
if is_trivial(NormalParameter(ty=ty, flag=None)):
|
||||
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
|
||||
if ty.name == "Bool":
|
||||
writer.write(
|
||||
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))"
|
||||
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {sanitized_name} else 0xbc799737))"
|
||||
)
|
||||
else:
|
||||
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
|
||||
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {sanitized_name})")
|
||||
elif ty.generic_ref or ty.name == "Object":
|
||||
writer.write(f"{buffer} += {name}") # assume previously-serialized
|
||||
writer.write(f"{buffer} += {sanitized_name}") # assume previously-serialized
|
||||
elif ty.name == "string":
|
||||
writer.write(f"serialize_bytes_to({buffer}, {name}.encode('utf-8'))")
|
||||
writer.write(f"serialize_bytes_to({buffer}, {sanitized_name}.encode('utf-8'))")
|
||||
elif ty.name == "bytes":
|
||||
writer.write(f"serialize_bytes_to({buffer}, {name})")
|
||||
writer.write(f"serialize_bytes_to({buffer}, {sanitized_name})")
|
||||
elif ty.name == "int128":
|
||||
writer.write(f"{buffer} += {name}.to_bytes(16)")
|
||||
writer.write(f"{buffer} += {sanitized_name}.to_bytes(16)")
|
||||
elif ty.name == "int256":
|
||||
writer.write(f"{buffer} += {name}.to_bytes(32)")
|
||||
writer.write(f"{buffer} += {sanitized_name}.to_bytes(32)")
|
||||
elif ty.bare:
|
||||
writer.write(f"{name}._write_to({buffer})")
|
||||
writer.write(f"{sanitized_name}._write_to({buffer})")
|
||||
else:
|
||||
writer.write(f"{name}._write_boxed_to({buffer})")
|
||||
writer.write(f"{sanitized_name}._write_boxed_to({buffer})")
|
||||
|
||||
|
||||
def generate_normal_param_write(
|
||||
|
|
@ -52,20 +52,20 @@ def generate_normal_param_write(
|
|||
if param.ty.name == "true":
|
||||
return # special-cased "built-in"
|
||||
|
||||
sanitized_name = sanitize_name(name)
|
||||
|
||||
if param.flag:
|
||||
writer.write(f"if {name} is not None:")
|
||||
writer.write(f"if {sanitized_name} is not None:")
|
||||
writer.indent()
|
||||
|
||||
if param.ty.generic_arg:
|
||||
if param.ty.name not in ("Vector", "vector"):
|
||||
raise ValueError(
|
||||
"generic_arg deserialization for non-vectors is not supported"
|
||||
)
|
||||
raise ValueError("generic_arg deserialization for non-vectors is not supported")
|
||||
|
||||
if param.ty.bare:
|
||||
writer.write(f"{buffer} += struct.pack('<i', len({name}))")
|
||||
writer.write(f"{buffer} += struct.pack('<i', len({sanitized_name}))")
|
||||
else:
|
||||
writer.write(f"{buffer} += struct.pack('<ii', 0x1cb5c415, len({name}))")
|
||||
writer.write(f"{buffer} += struct.pack('<ii', 0x1cb5c415, len({sanitized_name}))")
|
||||
|
||||
generic = NormalParameter(ty=param.ty.generic_arg, flag=None)
|
||||
if is_trivial(generic):
|
||||
|
|
@ -73,15 +73,15 @@ def generate_normal_param_write(
|
|||
if param.ty.generic_arg.name == "Bool":
|
||||
tmp = next(tmp_names)
|
||||
writer.write(
|
||||
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))"
|
||||
f"{buffer} += struct.pack(f'<{{len({sanitized_name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {sanitized_name}))"
|
||||
)
|
||||
else:
|
||||
writer.write(
|
||||
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})"
|
||||
f"{buffer} += struct.pack(f'<{{len({sanitized_name})}}{fmt}', *{sanitized_name})"
|
||||
)
|
||||
else:
|
||||
tmp = next(tmp_names)
|
||||
writer.write(f"for {tmp} in {name}:")
|
||||
writer.write(f"for {tmp} in {sanitized_name}:")
|
||||
writer.indent()
|
||||
generate_buffer_append(writer, buffer, tmp, param.ty.generic_arg)
|
||||
writer.dedent()
|
||||
|
|
@ -105,9 +105,9 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
|
|||
if isinstance(param.ty, FlagsParameter):
|
||||
flags = " | ".join(
|
||||
(
|
||||
f"({1 << p.ty.flag.index} if self.{p.name} else 0)"
|
||||
f"({1 << p.ty.flag.index} if self.{sanitize_name(p.name)} else 0)"
|
||||
if p.ty.ty.name == "true"
|
||||
else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})"
|
||||
else f"(0 if self.{sanitize_name(p.name)} is None else {1 << p.ty.flag.index})"
|
||||
)
|
||||
for p in defn.params
|
||||
if isinstance(p.ty, NormalParameter)
|
||||
|
|
@ -124,7 +124,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
|
|||
if not isinstance(param.ty, NormalParameter):
|
||||
raise RuntimeError("FlagsParameter should be considered trivial")
|
||||
generate_normal_param_write(
|
||||
writer, tmp_names, "buffer", f"self.{param.name}", param.ty
|
||||
writer, tmp_names, "buffer", f"self.{sanitize_name(param.name)}", param.ty
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -143,9 +143,9 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
|
|||
if isinstance(param.ty, FlagsParameter):
|
||||
flags = " | ".join(
|
||||
(
|
||||
f"({1 << p.ty.flag.index} if {p.name} else 0)"
|
||||
f"({1 << p.ty.flag.index} if {sanitize_name(p.name)} else 0)"
|
||||
if p.ty.ty.name == "true"
|
||||
else f"(0 if {p.name} is None else {1 << p.ty.flag.index})"
|
||||
else f"(0 if {sanitize_name(p.name)} is None else {1 << p.ty.flag.index})"
|
||||
)
|
||||
for p in defn.params
|
||||
if isinstance(p.ty, NormalParameter)
|
||||
|
|
@ -154,14 +154,12 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
|
|||
)
|
||||
writer.write(f"{param.name} = {flags or 0}")
|
||||
|
||||
names = ", ".join(p.name for p in group)
|
||||
names = ", ".join(sanitize_name(p.name) for p in group)
|
||||
fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
|
||||
writer.write(f"_buffer += struct.pack('<{fmt}', {names})")
|
||||
else:
|
||||
for param in iter:
|
||||
if not isinstance(param.ty, NormalParameter):
|
||||
raise RuntimeError("FlagsParameter should be considered trivial")
|
||||
generate_normal_param_write(
|
||||
writer, tmp_names, "_buffer", param.name, param.ty
|
||||
)
|
||||
generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty)
|
||||
writer.write("return Request(b'' + _buffer)")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from pytest import mark
|
||||
|
||||
from telethon_generator._impl.codegen.serde.common import (
|
||||
sanitize_name,
|
||||
split_words,
|
||||
to_class_name,
|
||||
to_method_name,
|
||||
|
|
@ -50,3 +51,21 @@ def test_to_class_name(name: str, expected: str) -> None:
|
|||
)
|
||||
def test_to_method_name(name: str, expected: str) -> None:
|
||||
assert to_method_name(name) == expected
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
("name", "expected"),
|
||||
[
|
||||
# Shouldn't be changed
|
||||
# - not keywords
|
||||
("abc", "abc"),
|
||||
# - soft keywords (https://docs.python.org/3/reference/lexical_analysis.html#soft-keywords)
|
||||
("type", "type"),
|
||||
# Must be changed
|
||||
# - keywords
|
||||
("from", "from_"),
|
||||
("return", "return_"),
|
||||
],
|
||||
)
|
||||
def test_sanitize_name(name: str, expected: str) -> None:
|
||||
assert sanitize_name(name) == expected
|
||||
|
|
|
|||
|
|
@ -14,9 +14,7 @@ def gen_py_code(
|
|||
functiondefs: Optional[list[Definition]] = None,
|
||||
) -> str:
|
||||
fs = FakeFs()
|
||||
generate(
|
||||
fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])
|
||||
)
|
||||
generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []))
|
||||
generated = bytearray()
|
||||
for path, data in fs._files.items():
|
||||
if path.stem not in ("__init__", "layer"):
|
||||
|
|
@ -27,9 +25,7 @@ def gen_py_code(
|
|||
|
||||
|
||||
def test_generic_functions_use_bytes_parameters() -> None:
|
||||
definitions = get_definitions(
|
||||
"invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;"
|
||||
)
|
||||
definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;")
|
||||
result = gen_py_code(functiondefs=definitions)
|
||||
assert "invoke_with_layer" in result
|
||||
assert "query: _bytes" in result
|
||||
|
|
@ -112,3 +108,15 @@ def test_bool_mapped_from_int() -> None:
|
|||
assert "_mutual in (0x997275b5, 0xbc799737)" in result
|
||||
assert "=_mutual == 0x997275b5" in result
|
||||
assert "0x997275b5 if self.mutual else 0xbc799737" in result
|
||||
|
||||
|
||||
def test_sanitize_keywords() -> None:
|
||||
definitions = get_definitions(
|
||||
"""
|
||||
forwardedMessage#deadbeef from:long to:long return:Bool = Message;
|
||||
"""
|
||||
)
|
||||
result = gen_py_code(typedefs=definitions)
|
||||
assert "from_" in result
|
||||
assert "to_" not in result and "to" in result
|
||||
assert "return_" in result
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user