Port tl-gen from grammers

This commit is contained in:
Lonami Exo 2023-07-03 19:19:20 +02:00
parent fc6984d423
commit fed06f40ed
39 changed files with 2207 additions and 40 deletions

8
.gitignore vendored
View File

@ -1,4 +1,12 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
.pytest_cache/
.mypy_cache/
dist/
build/
**/tl/layer.py
**/tl/abcs/
**/tl/functions/
**/tl/types/

136
benches/bench_codegen.py Normal file
View File

@ -0,0 +1,136 @@
import datetime
import io
import struct
import timeit
from typing import Any, Iterator
from .data_codegen import DATA, Obj
ITERATIONS = 50000
def serialize_builtin(value: Any) -> bytes:
    """Serialize a single builtin value into TL-style primitive bytes.

    ``None`` maps to no bytes, ``bytes`` pass through, ``str`` is UTF-8
    encoded, ``int`` becomes a little-endian 32- or 64-bit value, and
    ``datetime`` becomes its 32-bit unix timestamp.

    Raises:
        RuntimeError: for any unsupported type.
    """
    if value is None:
        return b""
    elif isinstance(value, bytes):
        return value
    elif isinstance(value, str):
        return value.encode("utf-8")
    elif isinstance(value, int):
        # Pick 32-bit only when the value actually fits in a *signed* i32.
        # The previous check (value < 2**32) made struct.pack raise for
        # anything in [2**31, 2**32) and for large negative values.
        fits_i32 = -(2**31) <= value < 2**31
        return struct.pack("<i" if fits_i32 else "<q", value)
    elif isinstance(value, datetime.datetime):
        return struct.pack("<i", int(value.timestamp()))
    else:
        raise RuntimeError(f"not a builtin type: {type(value)}")
def overhead(obj: Obj) -> None:
    """Walk *obj* and serialize every leaf without collecting the output.

    Timing baseline: every strategy below pays this traversal and
    serialization cost, so it is subtracted from their measurements.
    """
    for v in obj.__dict__.values():
        for x in v if isinstance(v, list) else [v]:
            if isinstance(x, Obj):
                overhead(x)
            else:
                serialize_builtin(x)
def strategy_concat(obj: Obj) -> bytes:
    """Build the output via repeated immutable ``bytes`` concatenation."""
    res = b""
    for v in obj.__dict__.values():
        for x in v if isinstance(v, list) else [v]:
            if isinstance(x, Obj):
                res += strategy_concat(x)
            else:
                res += serialize_builtin(x)
    return res
def strategy_append(obj: Obj) -> bytes:
    """Build the output in a fresh ``bytearray`` per recursion level."""
    res = bytearray()
    for v in obj.__dict__.values():
        for x in v if isinstance(v, list) else [v]:
            if isinstance(x, Obj):
                res += strategy_append(x)
            else:
                res += serialize_builtin(x)
    return bytes(res)
def strategy_append_reuse(obj: Obj) -> bytes:
    """Build the output into a single shared ``bytearray`` (no per-level buffers)."""

    def do_append(o: Obj, res: bytearray) -> None:
        # Recursive helper that appends into the one shared buffer.
        for v in o.__dict__.values():
            for x in v if isinstance(v, list) else [v]:
                if isinstance(x, Obj):
                    do_append(x, res)
                else:
                    res += serialize_builtin(x)

    buffer = bytearray()
    do_append(obj, buffer)
    return bytes(buffer)
def strategy_join(obj: Obj) -> bytes:
    """Build the output with one ``b"".join`` per recursion level."""
    return b"".join(
        strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
        for v in obj.__dict__.values()
        for x in (v if isinstance(v, list) else [v])
    )
def strategy_join_flat(obj: Obj) -> bytes:
    """Build the output with a single ``b"".join`` over a flattened generator."""

    def flatten(o: Obj) -> Iterator[bytes]:
        # Yield the serialized leaves in depth-first order.
        for v in o.__dict__.values():
            for x in v if isinstance(v, list) else [v]:
                if isinstance(x, Obj):
                    yield from flatten(x)
                else:
                    yield serialize_builtin(x)

    return b"".join(flatten(obj))
def strategy_write(obj: Obj) -> bytes:
    """Build the output by writing into a single shared ``io.BytesIO``."""

    def do_write(o: Obj, buffer: io.BytesIO) -> None:
        # Recursive helper that writes into the one shared stream.
        for v in o.__dict__.values():
            for x in v if isinstance(v, list) else [v]:
                if isinstance(x, Obj):
                    do_write(x, buffer)
                else:
                    buffer.write(serialize_builtin(x))

    buffer = io.BytesIO()
    do_write(obj, buffer)
    return buffer.getvalue()
def main() -> None:
    """Cross-check that all strategies agree, then time each of them."""
    # Collect every strategy_* function in stable (name-sorted) order.
    strategies = [
        v
        for _, v in sorted(
            ((k, v) for k, v in globals().items() if k.startswith("strategy_")),
            key=lambda t: t[0],
        )
    ]
    # All strategies must produce identical output before being timed.
    for a, b in zip(strategies[:-1], strategies[1:]):
        if a(DATA) != b(DATA):
            raise ValueError("strategies produce different output")
    print("measuring overhead...", end="", flush=True)
    overhead_duration = timeit.timeit(
        "strategy(DATA)",
        number=ITERATIONS,
        globals={"strategy": overhead, "DATA": DATA},
    )
    print(f" {overhead_duration:.04f}s")
    for strategy in strategies:
        duration = timeit.timeit(
            "strategy(DATA)",
            number=ITERATIONS,
            globals={"strategy": strategy, "DATA": DATA},
        )
        # Report each strategy net of the shared traversal overhead.
        print(f"{strategy.__name__:.>30} took {duration - overhead_duration:.04f}s")


if __name__ == "__main__":
    main()

58
benches/bench_truthy.py Normal file
View File

@ -0,0 +1,58 @@
import timeit
from typing import Type
ITERATIONS = 100000000
DATA = 42
def overhead(n: int) -> None:
    """Timing baseline: function call plus argument access, no conversion."""
    n


def strategy_bool(n: int) -> bool:
    """Convert via the ``bool`` builtin (global name lookup per call)."""
    return bool(n)


def strategy_bool_cache(n: int, _bool: Type[bool] = bool) -> bool:
    """Like strategy_bool, but with ``bool`` pre-bound as a default (local lookup)."""
    return _bool(n)


def strategy_non_zero(n: int) -> bool:
    """Convert via an explicit comparison against zero."""
    return n != 0


def strategy_not_not(n: int) -> bool:
    """Convert via double negation."""
    return not not n
def main() -> None:
    """Cross-check that all strategies agree, then time each of them."""
    # Collect every strategy_* function in stable (name-sorted) order.
    strategies = [
        v
        for _, v in sorted(
            ((k, v) for k, v in globals().items() if k.startswith("strategy_")),
            key=lambda t: t[0],
        )
    ]
    # All strategies must produce identical output before being timed.
    for a, b in zip(strategies[:-1], strategies[1:]):
        if a(DATA) != b(DATA):
            raise ValueError("strategies produce different output")
    print("measuring overhead...", end="", flush=True)
    overhead_duration = timeit.timeit(
        "strategy(DATA)",
        number=ITERATIONS,
        globals={"strategy": overhead, "DATA": DATA},
    )
    print(f" {overhead_duration:.04f}s")
    for strategy in strategies:
        duration = timeit.timeit(
            "strategy(DATA)",
            number=ITERATIONS,
            globals={"strategy": strategy, "DATA": DATA},
        )
        # Report each strategy net of the call/argument overhead.
        print(f"{strategy.__name__:.>30} took {duration - overhead_duration:.04f}s")


if __name__ == "__main__":
    main()

1141
benches/data_codegen.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
isort .
black .
isort . --profile black --gitignore
black . --extend-exclude "tl/(abcs|functions|types)/\w+.py"
mypy --strict .
pytest .

21
generator/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016-Present LonamiWebs
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

5
generator/README.md Normal file
View File

@ -0,0 +1,5 @@
# telethon_generator
Code generator for [Telethon].
[Telethon]: https://pypi.org/project/Telethon/

35
generator/pyproject.toml Normal file
View File

@ -0,0 +1,35 @@
[project]
name = "telethon_generator"
description = "Code generator for Telethon"
authors = [
{ name="Lonami", email="totufals@hotmail.com" },
]
readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">=3.8"
keywords = ["telegram", "parser", "codegen", "telethon"]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Topic :: File Formats",
"Topic :: Software Development :: Code Generators",
"Typing :: Typed",
]
dynamic = ["version"]
[project.urls]
"Homepage" = "https://telethon.dev/"
"Source" = "https://telethon.dev/code/"
"Documentation" = "https://telethon.dev/docs/"
"Bug Tracker" = "https://telethon.dev/issues/"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools.dynamic]
version = {attr = "telethon_generator.__version__"}

View File

@ -0,0 +1,4 @@
from . import codegen, tl_parser
from .version import __version__
__all__ = ["codegen", "tl_parser"]

View File

@ -0,0 +1,5 @@
from .fakefs import FakeFs, SourceWriter
from .generator import generate
from .loader import ParsedTl, load_tl_file
__all__ = ["FakeFs", "SourceWriter", "generate", "ParsedTl", "load_tl_file"]

View File

@ -0,0 +1,44 @@
import weakref
from pathlib import Path
from typing import Dict
class FakeFs:
    """In-memory file system used to accumulate generated source code.

    Files are stored as ``path -> bytearray`` and only written to the
    real file system when :meth:`materialize` is called.
    """

    def __init__(self) -> None:
        self._files: Dict[Path, bytearray] = {}

    def open(self, path: Path) -> "SourceWriter":
        """Return a :class:`SourceWriter` that appends lines to *path*."""
        return SourceWriter(self, path)

    def write(self, path: Path, line: str) -> None:
        """Append *line* (UTF-8 encoded) to the in-memory file at *path*."""
        file = self._files.get(path)
        if file is None:
            self._files[path] = file = bytearray()
        file += line.encode("utf-8")

    def materialize(self, root: Path) -> None:
        """Write every in-memory file under *root* on the real file system."""
        for stem, data in self._files.items():
            path = root / stem
            # parents=True: generated paths are nested (e.g. tl/abcs/x.py)
            # and the intermediate directories may not exist yet; without
            # it mkdir raises FileNotFoundError for the missing parents.
            path.parent.mkdir(parents=True, exist_ok=True)
            with path.open("wb") as fd:
                fd.write(data)

    def __contains__(self, path: Path) -> bool:
        return path in self._files
class SourceWriter:
    """Indentation-aware line writer on top of a :class:`FakeFs` file.

    Holds only a weak reference to the owning FakeFs so that lingering
    writers do not keep the fake file system alive.
    """

    def __init__(self, fs: FakeFs, path: Path) -> None:
        self._fs = weakref.ref(fs)
        self._path = path
        self._indent = ""

    def write(self, string: str) -> None:
        """Write one line prefixed with the current indentation."""
        if fs := self._fs():
            fs.write(self._path, f"{self._indent}{string}\n")

    def indent(self, n: int = 1) -> None:
        # Two spaces per level so that dedent() (which strips 2 * n
        # characters) stays symmetric; adding a single space per level
        # made indent/dedent drift apart.
        self._indent += "  " * n

    def dedent(self, n: int = 1) -> None:
        self._indent = self._indent[: -2 * n]

View File

@ -0,0 +1,178 @@
from pathlib import Path
from typing import Set
from .fakefs import FakeFs, SourceWriter
from .loader import ParsedTl
from .serde.common import (
inner_type_fmt,
is_computed,
param_type_fmt,
to_class_name,
to_method_name,
)
from .serde.deserialization import generate_read
from .serde.serialization import generate_function, generate_write
def generate_init(writer: SourceWriter, namespaces: Set[str]) -> None:
    """Write an ``__init__.py`` body re-exporting every namespace module."""
    exported = sorted(namespaces)
    if not exported:
        return
    writer.write("from ._nons import *")
    writer.write(f"from . import {', '.join(exported)}")
    all_entries = ", ".join(repr(ns) for ns in exported)
    writer.write(f"__all__ = [{all_entries}]")
def generate(fs: FakeFs, tl: ParsedTl) -> None:
    """Generate the whole Python package for *tl* into the fake file system.

    Produces abstract base classes under ``abcs/``, concrete types under
    ``types/``, request-building functions under ``functions/``, the
    package ``__init__`` files, and ``layer.py`` with the constructor-id
    mapping used by the deserializer.
    """
    generated_types = {
        "True",
        "Bool",
    }  # initial set is considered to be "compiler built-ins"
    ignored_types = {"true", "boolTrue", "boolFalse"}  # also "compiler built-ins"
    abc_namespaces = set()
    type_namespaces = set()
    function_namespaces = set()
    generated_type_names = []
    # Pass 1: type definitions -> abstract base classes + concrete classes.
    for typedef in tl.typedefs:
        if typedef.ty.full_name not in generated_types:
            # First time this boxed type is seen: emit its abstract base.
            if len(typedef.ty.namespace) >= 2:
                raise NotImplementedError("nested abc-namespaces are not supported")
            elif len(typedef.ty.namespace) == 1:
                abc_namespaces.add(typedef.ty.namespace[0])
                abc_path = (Path("abcs") / typedef.ty.namespace[0]).with_suffix(".py")
            else:
                abc_path = Path("abcs/_nons.py")
            if abc_path not in fs:
                fs.write(abc_path, "from abc import ABCMeta\n")
                fs.write(abc_path, "from ..core.serializable import Serializable\n")
            fs.write(
                abc_path,
                f"class {to_class_name(typedef.ty.name)}(Serializable, metaclass=ABCMeta): pass\n",
            )
            generated_types.add(typedef.ty.full_name)
        if typedef.name in ignored_types:
            continue
        # Computed (flags) parameters are not exposed as properties.
        property_params = [p for p in typedef.params if not is_computed(p.ty)]
        if len(typedef.namespace) >= 2:
            raise NotImplementedError("nested type-namespaces are not supported")
        elif len(typedef.namespace) == 1:
            type_namespaces.add(typedef.namespace[0])
            type_path = (Path("types") / typedef.namespace[0]).with_suffix(".py")
        else:
            type_path = Path("types/_nons.py")
        writer = fs.open(type_path)
        if type_path not in fs:
            # Per-module header, written once per generated file.
            writer.write(f"import struct")
            writer.write(f"from typing import List, Optional, Self")
            writer.write(f"from .. import abcs")
            writer.write(f"from ..core.reader import Reader")
            writer.write(f"from ..core.serializable import serialize_bytes_to")
        ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
        generated_type_names.append(f"{ns}{to_class_name(typedef.name)}")
        # class Type(BaseType)
        writer.write(
            f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):"
        )
        # __slots__ = ('params', ...)
        slots = " ".join(f"'{p.name}'," for p in property_params)
        writer.write(f"  __slots__ = ({slots})")
        # def constructor_id()
        writer.write(f"  @classmethod")
        writer.write(f"  def constructor_id(_) -> int:")
        writer.write(f"    return {hex(typedef.id)}")
        # def __init__()
        if property_params:
            params = "".join(
                f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
            )
            writer.write(f"  def __init__(_s{params}) -> None:")
            for p in property_params:
                writer.write(f"    _s.{p.name} = {p.name}")
        # def _read_from()
        writer.write(f"  @classmethod")
        writer.write(f"  def _read_from(cls, reader: Reader) -> Self:")
        writer.indent(2)
        generate_read(writer, typedef)
        params = ", ".join(f"{p.name}=_{p.name}" for p in property_params)
        writer.write(f"return cls({params})")
        writer.dedent(2)
        # def _write_to()
        writer.write(f"  def _write_to(self, buffer: bytearray) -> None:")
        if typedef.params:
            writer.indent(2)
            generate_write(writer, typedef)
            writer.dedent(2)
        else:
            writer.write(f"    pass")
    # Pass 2: function definitions -> request-building free functions.
    for functiondef in tl.functiondefs:
        required_params = [p for p in functiondef.params if not is_computed(p.ty)]
        if len(functiondef.namespace) >= 2:
            raise NotImplementedError("nested function-namespaces are not supported")
        elif len(functiondef.namespace) == 1:
            function_namespaces.add(functiondef.namespace[0])
            function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(
                ".py"
            )
        else:
            function_path = Path("functions/_nons.py")
        writer = fs.open(function_path)
        if function_path not in fs:
            writer.write(f"import struct")
            writer.write(f"from typing import List, Optional, Self")
            writer.write(f"from .. import abcs")
            writer.write(f"from ..core.request import Request")
            writer.write(f"from ..core.serializable import serialize_bytes_to")
        # def name(params, ...)
        params = ", ".join(f"{p.name}: {param_type_fmt(p.ty)}" for p in required_params)
        writer.write(f"def {to_method_name(functiondef.name)}({params}) -> Request:")
        writer.indent(2)
        generate_function(writer, functiondef)
        writer.dedent(2)
    generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces)
    generate_init(fs.open(Path("types/__init__.py")), type_namespaces)
    generate_init(fs.open(Path("functions/__init__.py")), function_namespaces)
    generated_type_names.sort()
    # layer.py: LAYER constant + constructor-id -> type mapping.
    writer = fs.open(Path("layer.py"))
    writer.write(f"from . import types")
    writer.write(f"from .core import Serializable, Reader")
    writer.write(f"from typing import cast, Tuple, Type")
    writer.write(f"LAYER = {tl.layer!r}")
    writer.write(
        "TYPE_MAPPING = {t.constructor_id(): t for t in cast(Tuple[Type[Serializable]], ("
    )
    for name in generated_type_names:
        writer.write(f"  types.{name},")
    writer.write("))}")
    writer.write(
        "Reader._get_ty = TYPE_MAPPING.get # type: ignore [method-assign, assignment]"
    )
    writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING']")

View File

@ -0,0 +1,39 @@
import re
from dataclasses import dataclass
from typing import List, Optional
from ...tl_parser import Definition, FunctionDef, TypeDef, parse_tl_file
@dataclass
class ParsedTl:
    """Result of loading a ``.tl`` file: layer number plus all definitions."""

    # Layer number found in a `// LAYER N` comment, or None if absent.
    layer: Optional[int]
    typedefs: List[Definition]
    functiondefs: List[Definition]
def load_tl_file(path: str) -> ParsedTl:
    """Parse the ``.tl`` file at *path* into type and function definitions.

    The layer number is taken from a ``// LAYER N`` comment when present.
    Parse failures flagged as "not implemented" (e.g. generic vector
    definitions) are skipped; any other parse failure is raised.
    """
    typedefs, functiondefs = [], []
    with open(path, "r", encoding="utf-8") as fd:
        contents = fd.read()
    if m := re.search(r"//\s*LAYER\s+(\d+)", contents):
        layer = int(m[1])
    else:
        layer = None
    for definition in parse_tl_file(contents):
        if isinstance(definition, Exception):
            # generic types (such as vector) are known to not be implemented
            if definition.args[0] != "not implemented":
                # A bare `raise` here has no active exception to re-raise
                # and itself raises RuntimeError; raise the recorded parse
                # error instead.
                raise definition
        elif isinstance(definition, TypeDef):
            typedefs.append(definition)
        elif isinstance(definition, FunctionDef):
            functiondefs.append(definition)
        else:
            raise TypeError(f"unexpected type: {type(definition)}")
    return ParsedTl(
        layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
    )

View File

@ -0,0 +1,102 @@
import re
from typing import Iterator
from ....tl_parser import BaseParameter, FlagsParameter, NormalParameter, Type
def to_class_name(name: str) -> str:
    """Convert a TL name such as ``input_peer`` to ``InputPeer``.

    Every lowercase letter at the start of the name or immediately after
    an underscore is uppercased, and that underscore is consumed.
    """

    def capitalize(match: "re.Match[str]") -> str:
        return match.group(1).upper()

    return re.sub(r"(?:^|_)([a-z])", capitalize, name)
def to_method_name(name: str) -> str:
    """Convert a TL name such as ``invokeWithLayer`` to ``invoke_with_layer``."""

    def lower_word(match: "re.Match[str]") -> str:
        # Prefix each word with "_" and strip any underscores it carried.
        return "_" + match.group(0).replace("_", "").lower()

    snake = re.sub(r"_+[A-Za-z]+|[A-Z]*[a-z]+", lower_word, name)
    return snake.strip("_")
def gen_tmp_names() -> Iterator[str]:
    """Yield an endless stream of unique temporary names: ``_t0``, ``_t1``, ..."""
    counter = 0
    while True:
        yield f"_t{counter}"
        counter += 1
def is_computed(ty: BaseParameter) -> bool:
    # Flags parameters are computed from the presence of the parameters
    # that depend on them, so they never surface as properties/arguments.
    return isinstance(ty, FlagsParameter)
def is_trivial(ty: BaseParameter) -> bool:
    """Return True when the parameter maps directly to a struct format code.

    Note the `or`/`and` precedence: flags parameters are always trivial;
    normal parameters are trivial only when unconditional (no flag) and
    of a fixed-size primitive type.
    """
    return (
        isinstance(ty, FlagsParameter)
        or isinstance(ty, NormalParameter)
        and not ty.flag
        and ty.ty.name in ("int", "long", "double", "Bool")
    )
# Struct format codes for the fixed-size ("trivial") TL primitives.
_TRIVIAL_STRUCT_MAP = {"int": "i", "long": "q", "double": "d", "Bool": "I"}


def trivial_struct_fmt(ty: BaseParameter) -> str:
    """Return the struct format character for a trivial parameter.

    Flags parameters serialize as an unsigned 32-bit value ("I").

    Raises:
        ValueError: when the parameter is not trivial (see is_trivial).
    """
    try:
        return (
            _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
        )
    except KeyError:
        # Suppress the KeyError context: callers only care that the
        # parameter was not trivial, not about the failed dict lookup.
        raise ValueError("input param was not trivial") from None
# Builtin TL primitives mapped to the Python type used to represent them.
_INNER_TYPE_MAP = {
    "Bool": "bool",
    "true": "bool",
    "int": "int",
    "long": "int",
    "int128": "int",
    "int256": "int",
    "double": "float",
    "bytes": "bytes",
    "string": "str",
}


def inner_type_fmt(ty: Type) -> str:
    """Return the Python annotation (as a string) for a single TL type *ty*.

    Builtins map via ``_INNER_TYPE_MAP``; bare types map to their concrete
    generated class; generic references (``!X``) are pre-serialized bytes;
    everything else is an abstract base class under ``abcs``.
    """
    builtin_ty = _INNER_TYPE_MAP.get(ty.name)
    if builtin_ty:
        return builtin_ty
    elif ty.bare:
        return to_class_name(ty.name)
    elif ty.generic_ref:
        return "bytes"
    else:
        ns = (".".join(ty.namespace) + ".") if ty.namespace else ""
        return f"abcs.{ns}{to_class_name(ty.name)}"
def param_type_fmt(ty: BaseParameter) -> str:
    """Return the Python annotation (as a string) for a parameter type.

    Vectors become ``List[...]``; conditional (flagged) parameters other
    than ``true`` flags become ``Optional[...]``.

    Raises:
        TypeError: when *ty* is neither a flags nor a normal parameter.
        NotImplementedError: for generic arguments on non-vector types.
    """
    if isinstance(ty, FlagsParameter):
        return "int"
    elif not isinstance(ty, NormalParameter):
        # The f-prefix was missing here, so the message always showed the
        # literal text "{ty}" instead of the offending value.
        raise TypeError(f"unexpected input type {ty}")
    inner_ty: Type
    if ty.ty.generic_arg:
        if ty.ty.name not in ("Vector", "vector"):
            raise NotImplementedError(
                "generic_arg type for non-vectors not implemented"
            )
        inner_ty = ty.ty.generic_arg
    else:
        inner_ty = ty.ty
    res = inner_type_fmt(inner_ty)
    if ty.ty.generic_arg:
        res = f"List[{res}]"
    if ty.flag and ty.ty.name != "true":
        res = f"Optional[{res}]"
    return res

View File

@ -0,0 +1,100 @@
import struct
from itertools import groupby
from typing import Optional, Tuple
from ....tl_parser import Definition, NormalParameter, Parameter, Type
from ..fakefs import SourceWriter
from .common import inner_type_fmt, is_trivial, to_class_name, trivial_struct_fmt
def reader_read_fmt(ty: Type) -> Tuple[str, Optional[str]]:
    """Return ``(expression, mypy_ignore_code)`` reading one value of *ty*.

    The expression is spliced into generated ``_read_from`` bodies; the
    second element is the ``type: ignore`` code to attach, if any.
    """
    if is_trivial(NormalParameter(ty=ty, flag=None)):
        fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
        size = struct.calcsize(f"<{fmt}")
        # No f-prefix in the *generated* expression: the format is fully
        # interpolated here, so an f-string in the output was a no-op.
        return f"reader.read_fmt('<{fmt}', {size})[0]", None
    elif ty.name == "string":
        return "str(reader.read_bytes(), 'utf-8', 'replace')", None
    elif ty.name == "bytes":
        return "reader.read_bytes()", None
    elif ty.name == "int128":
        return "int.from_bytes(reader.read(16), 'little', signed=True)", None
    elif ty.name == "int256":
        return "int.from_bytes(reader.read(32), 'little', signed=True)", None
    elif ty.bare:
        # Bare types read only their body (no boxed constructor id).
        return f"{to_class_name(ty.name)}._read_from(reader)", None
    else:
        # Abstract result: mypy needs a type-abstract ignore at the call site.
        return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract"
def generate_normal_param_read(
    writer: SourceWriter, name: str, param: NormalParameter
) -> None:
    """Emit code reading parameter *name* into a local ``_{name}``.

    Conditional (flagged) parameters read ``None`` when their bit is not
    set; ``true`` flags become plain booleans; vectors read a bare or
    boxed (0x1cb5c415) header followed by their items.
    """
    flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None
    if param.ty.name == "true":
        if not flag_check:
            raise NotImplementedError("true parameter is expected to be a flag")
        # True flags are not serialized; their value is the flag bit itself.
        writer.write(f"_{name} = ({flag_check}) != 0")
    elif param.ty.generic_ref:
        raise NotImplementedError("generic_ref deserialization not implemented")
    else:
        if flag_check:
            writer.write(f"if {flag_check}:")
            writer.indent()
        if param.ty.generic_arg:
            if param.ty.name not in ("Vector", "vector"):
                raise NotImplementedError(
                    "generic_arg deserialization for non-vectors not implemented"
                )
            if param.ty.bare:
                writer.write("__len = reader.read_fmt('<i', 4)[0]")
                writer.write("assert __len >= 0")
            else:
                writer.write("__vid, __len = reader.read_fmt('<ii', 8)")
                writer.write("assert __vid == 0x1cb5c415 and __len >= 0")
            generic = NormalParameter(ty=param.ty.generic_arg, flag=None)
            if is_trivial(generic):
                fmt = trivial_struct_fmt(generic)
                size = struct.calcsize(f"<{fmt}")
                # Keep the whole unpacked vector; the previous [0] index
                # discarded everything but the first element.
                writer.write(
                    f"_{name} = list(reader.read_fmt(f'<{{__len}}{fmt}', __len * {size}))"
                )
                if param.ty.generic_arg.name == "Bool":
                    # 0x997275b5 / 0xbc799737 are the boolTrue / boolFalse
                    # ids (previously emitted as the broken literal 0x0x…,
                    # a syntax error in the generated module).
                    writer.write(
                        f"assert all(__x in (0xbc799737, 0x997275b5) for __x in _{name})"
                    )
                    writer.write(
                        f"_{name} = [__x == 0x997275b5 for __x in _{name}]"
                    )
            else:
                fmt_read, type_ignore = reader_read_fmt(param.ty.generic_arg)
                comment = f"  # type: ignore [{type_ignore}]" if type_ignore else ""
                writer.write(f"_{name} = [{fmt_read} for _ in range(__len)]{comment}")
        else:
            fmt_read, type_ignore = reader_read_fmt(param.ty)
            comment = f"  # type: ignore [{type_ignore}]" if type_ignore else ""
            writer.write(f"_{name} = {fmt_read}{comment}")
        if flag_check:
            writer.dedent()
            writer.write("else:")
            writer.write(f"  _{name} = None")
def generate_read(writer: SourceWriter, defn: Definition) -> None:
    """Emit the body of ``_read_from`` for *defn*.

    Consecutive trivial parameters are batched into one struct-based
    read; the remaining parameters are read one at a time.
    """
    # `group_iter` (not `iter`): avoid shadowing the builtin.
    for trivial, group_iter in groupby(
        defn.params,
        key=lambda p: is_trivial(p.ty),
    ):
        if trivial:
            # As an optimization, struct.unpack can handle more than one element at a time.
            group = list(group_iter)
            names = "".join(f"_{param.name}, " for param in group)
            fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
            size = struct.calcsize(f"<{fmt}")
            writer.write(f"{names}= reader.read_fmt('<{fmt}', {size})")
        else:
            for param in group_iter:
                if not isinstance(param.ty, NormalParameter):
                    raise RuntimeError("FlagsParameter should be considered trivial")
                generate_normal_param_read(writer, param.name, param.ty)

View File

@ -0,0 +1,161 @@
from itertools import groupby
from typing import Iterator
from ....tl_parser import Definition, FlagsParameter, NormalParameter, Parameter, Type
from ..fakefs import SourceWriter
from .common import gen_tmp_names, is_computed, is_trivial, trivial_struct_fmt
def param_value_expr(param: Parameter) -> str:
    """Return the expression serializing *param*'s value in generated code.

    Computed (flags) parameters live in locals (``_name``); everything
    else reads from ``self``. Bools map to their constructor ids.
    """
    source = f"_{param.name}" if is_computed(param.ty) else f"self.{param.name}"
    if isinstance(param.ty, NormalParameter) and param.ty.ty.name == "Bool":
        return f"0x997275b5 if {source} else 0xbc799737"
    return source
def generate_buffer_append(
    writer: SourceWriter, buffer: str, name: str, ty: Type
) -> None:
    """Emit code appending the serialized value *name* (of type *ty*) to *buffer*."""
    if is_trivial(NormalParameter(ty=ty, flag=None)):
        fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
        if ty.name == "Bool":
            # Bools serialize as the boolTrue/boolFalse constructor ids.
            writer.write(
                f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))"
            )
        else:
            writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
    elif ty.generic_ref:
        writer.write(f"{buffer} += {name}")  # assume previously-serialized
    elif ty.name == "string":
        writer.write(f"serialize_bytes_to({buffer}, {name}.encode('utf-8'))")
    elif ty.name == "bytes":
        writer.write(f"serialize_bytes_to({buffer}, {name})")
    elif ty.name == "int128":
        writer.write(f"{buffer} += {name}.to_bytes(16, 'little', signed=True)")
    elif ty.name == "int256":
        writer.write(f"{buffer} += {name}.to_bytes(32, 'little', signed=True)")
    elif ty.bare:
        # Bare types write only their body; boxed types also write their id.
        writer.write(f"{name}._write_to({buffer})")
    else:
        writer.write(f"{name}._write_boxed_to({buffer})")
def generate_normal_param_write(
    writer: SourceWriter,
    tmp_names: Iterator[str],
    buffer: str,
    name: str,
    param: NormalParameter,
) -> None:
    """Emit code appending parameter *name* to *buffer*.

    Conditional (flagged) parameters are guarded by an ``is not None``
    check; vectors write a bare or boxed (0x1cb5c415) header followed by
    their items. ``true`` flags are never serialized.
    """
    if param.ty.name == "true":
        return  # special-cased "built-in"
    if param.flag:
        writer.write(f"if {name} is not None:")
        writer.indent()
    if param.ty.generic_arg:
        if param.ty.name not in ("Vector", "vector"):
            raise NotImplementedError(
                "generic_arg deserialization for non-vectors not implemented"
            )
        if param.ty.bare:
            writer.write(f"{buffer} += struct.pack('<i', len({name}))")
        else:
            writer.write(f"{buffer} += struct.pack('<ii', 0x1cb5c415, len({name}))")
        generic = NormalParameter(ty=param.ty.generic_arg, flag=None)
        if is_trivial(generic):
            fmt = trivial_struct_fmt(generic)
            if param.ty.generic_arg.name == "Bool":
                # Map each bool to its constructor id before packing.
                tmp = next(tmp_names)
                writer.write(
                    f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))"
                )
            else:
                writer.write(
                    f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})"
                )
        else:
            # Non-trivial items: emit a loop appending each element.
            tmp = next(tmp_names)
            writer.write(f"for {tmp} in {name}:")
            writer.indent()
            generate_buffer_append(writer, buffer, tmp, param.ty.generic_arg)
            writer.dedent()
    else:
        generate_buffer_append(writer, buffer, f"{name}", param.ty)
    if param.flag:
        writer.dedent()
def generate_write(writer: SourceWriter, defn: Definition) -> None:
    """Emit the body of ``_write_to`` for *defn*.

    Runs of trivial parameters are packed with a single struct call, and
    flags values are computed from the presence of their dependents.
    """
    tmp_names = gen_tmp_names()
    # `group_iter` (not `iter`): avoid shadowing the builtin.
    for trivial, group_iter in groupby(
        defn.params,
        key=lambda p: is_trivial(p.ty),
    ):
        if trivial:
            # As an optimization, struct.pack can handle more than one element at a time.
            group = list(group_iter)
            for param in group:
                if isinstance(param.ty, FlagsParameter):
                    # OR one bit per parameter that is gated on this flags field.
                    flags = " | ".join(
                        f"({1 << p.ty.flag.index} if self.{p.name} else 0)"
                        if p.ty.ty.name == "true"
                        else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})"
                        for p in defn.params
                        if isinstance(p.ty, NormalParameter)
                        and p.ty.flag
                        and p.ty.flag.name == param.name
                    )
                    writer.write(f"_{param.name} = {flags or 0}")
            names = ", ".join(map(param_value_expr, group))
            fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
            writer.write(f"buffer += struct.pack('<{fmt}', {names})")
        else:
            for param in group_iter:
                if not isinstance(param.ty, NormalParameter):
                    raise RuntimeError("FlagsParameter should be considered trivial")
                generate_normal_param_write(
                    writer, tmp_names, "buffer", f"self.{param.name}", param.ty
                )
def generate_function(writer: SourceWriter, defn: Definition) -> None:
    """Emit the body of the request-building function for *defn*.

    Mirrors generate_write, but parameters are plain function arguments
    (no ``self.``) and the serialized bytes are wrapped in a Request.
    """
    tmp_names = gen_tmp_names()
    writer.write("_buffer = bytearray()")
    # `group_iter` (not `iter`): avoid shadowing the builtin.
    for trivial, group_iter in groupby(
        defn.params,
        key=lambda p: is_trivial(p.ty),
    ):
        if trivial:
            # As an optimization, struct.pack can handle more than one element at a time.
            group = list(group_iter)
            for param in group:
                if isinstance(param.ty, FlagsParameter):
                    # OR one bit per parameter that is gated on this flags field.
                    flags = " | ".join(
                        f"({1 << p.ty.flag.index} if {p.name} else 0)"
                        if p.ty.ty.name == "true"
                        else f"(0 if {p.name} is None else {1 << p.ty.flag.index})"
                        for p in defn.params
                        if isinstance(p.ty, NormalParameter)
                        and p.ty.flag
                        and p.ty.flag.name == param.name
                    )
                    writer.write(f"{param.name} = {flags or 0}")
            names = ", ".join(p.name for p in group)
            fmt = "".join(trivial_struct_fmt(param.ty) for param in group)
            writer.write(f"_buffer += struct.pack('<{fmt}', {names})")
        else:
            for param in group_iter:
                if not isinstance(param.ty, NormalParameter):
                    raise RuntimeError("FlagsParameter should be considered trivial")
                generate_normal_param_write(
                    writer, tmp_names, "_buffer", param.name, param.ty
                )
    # b'' + bytearray yields immutable bytes for the Request payload.
    writer.write("return Request(b'' + _buffer)")

View File

@ -37,6 +37,11 @@ class Type:
generic_arg=generic_arg,
)
@property
def full_name(self) -> str:
    """Return the name qualified by its dot-separated namespace, if any."""
    ns = ".".join(self.namespace) + "." if self.namespace else ""
    return f"{ns}{self.name}"
def __str__(self) -> str:
res = ""
for ns in self.namespace:

View File

@ -0,0 +1,3 @@
from .._impl.codegen import FakeFs, ParsedTl, generate
__all__ = ["FakeFs", "ParsedTl", "generate"]

View File

@ -0,0 +1,31 @@
import sys
from pathlib import Path
from .._impl.codegen import FakeFs, generate, load_tl_file
# Usage text shown when the wrong number of arguments is given.
HELP = f"""
USAGE:
python -m {__package__} <TL_FILE> <OUT_DIR>
ARGS:
<TL_FILE>
The path to the `.tl' file to generate Python code from.
<OUT_DIR>
The directory where the generated code will be written to.
""".strip()


def main() -> None:
    """CLI entry point: generate Python code from a ``.tl`` file."""
    if len(sys.argv) != 3:
        print(HELP)
        sys.exit(1)
    tl, out = sys.argv[1:]
    fs = FakeFs()
    # Generate everything in memory first, then flush to disk at once.
    generate(fs, load_tl_file(tl))
    fs.materialize(Path(out))


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,25 @@
from .._impl.tl_parser.tl.definition import Definition
from .._impl.tl_parser.tl.flag import Flag
from .._impl.tl_parser.tl.parameter import Parameter, TypeDefNotImplemented
from .._impl.tl_parser.tl.parameter_type import (
BaseParameter,
FlagsParameter,
NormalParameter,
)
from .._impl.tl_parser.tl.ty import Type
from .._impl.tl_parser.tl_iterator import FunctionDef, TypeDef
from .._impl.tl_parser.tl_iterator import iterate as parse_tl_file
__all__ = [
"Definition",
"Flag",
"Parameter",
"TypeDefNotImplemented",
"BaseParameter",
"FlagsParameter",
"NormalParameter",
"Type",
"FunctionDef",
"TypeDef",
"parse_tl_file",
]

View File

@ -0,0 +1 @@
__version__ = "0.1.0"

View File

@ -1,10 +1,12 @@
from pytest import mark, raises
from .definition import Definition
from .flag import Flag
from .parameter import Parameter
from .parameter_type import FlagsParameter, NormalParameter
from .ty import Type
from telethon_generator.tl_parser import (
Definition,
Flag,
FlagsParameter,
NormalParameter,
Parameter,
Type,
)
def test_parse_empty_def() -> None:

View File

@ -0,0 +1,82 @@
from typing import List, Optional
from telethon_generator.codegen import FakeFs, ParsedTl, generate
from telethon_generator.tl_parser import Definition, parse_tl_file
def get_definitions(contents: str) -> List[Definition]:
    """Parse *contents* and keep only the successfully-parsed definitions."""
    parsed = parse_tl_file(contents)
    return [defn for defn in parsed if not isinstance(defn, Exception)]
def gen_py_code(
    *,
    typedefs: Optional[List[Definition]] = None,
    functiondefs: Optional[List[Definition]] = None,
) -> str:
    """Run the code generator in memory and return the generated source.

    ``__init__`` and ``layer`` modules are skipped since the tests only
    inspect the generated types and functions.
    """
    fs = FakeFs()
    generate(
        fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])
    )
    generated = bytearray()
    for path, data in fs._files.items():
        if path.stem not in ("__init__", "layer"):
            generated += f"# {path}\n".encode("utf-8")
            generated += data
            # Separate files with a blank line. The previous code did
            # `data += b"\n"`, mutating the loop variable with no effect
            # on the collected output.
            generated += b"\n"
    return str(generated, "utf-8")
def test_generic_functions_use_bytes_parameters() -> None:
    # Generic references (!X) must surface as pre-serialized bytes args.
    definitions = get_definitions(
        "invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;"
    )
    result = gen_py_code(functiondefs=definitions)
    assert "invoke_with_layer" in result
    assert "query: bytes" in result
    assert "buffer += query" in result


def test_recursive_direct() -> None:
    # A type referencing its own boxed type must go through the abc.
    definitions = get_definitions("textBold#6724abc4 text:RichText = RichText;")
    result = gen_py_code(typedefs=definitions)
    assert "text: abcs.RichText" in result
    assert "read_serializable" in result
    assert "write_boxed_to" in result


def test_recursive_indirect() -> None:
    # Mutually-recursive types must serialize boxed, never bare.
    definitions = get_definitions(
        """
messageExtendedMedia#ee479c64 media:MessageMedia = MessageExtendedMedia;
messageMediaInvoice#f6a548d3 flags:# extended_media:flags.4?MessageExtendedMedia = MessageMedia;
"""
    )
    result = gen_py_code(typedefs=definitions)
    assert "media: abcs.MessageMedia" in result
    assert "extended_media: Optional[abcs.MessageExtendedMedia])" in result
    assert "write_boxed_to" in result
    assert "._write_to" not in result
    assert "read_serializable" in result


def test_recursive_no_hang() -> None:
    # Indirect recursion must not make the generator loop forever.
    definitions = get_definitions(
        """
inputUserFromMessage#1da448e2 peer:InputPeer msg_id:int user_id:long = InputUser;
inputPeerUserFromMessage#a87b0a1c peer:InputPeer msg_id:int user_id:long = InputPeer;
"""
    )
    gen_py_code(typedefs=definitions)


def test_recursive_vec() -> None:
    # Vectors of recursive types must produce List[...] annotations.
    definitions = get_definitions(
        """
jsonObjectValue#c0de1bd9 key:string value:JSONValue = JSONObjectValue;
jsonArray#f7444763 value:Vector<JSONValue> = JSONValue;
jsonObject#99c1d49d value:Vector<JSONObjectValue> = JSONValue;
"""
    )
    result = gen_py_code(typedefs=definitions)
    assert "value: List[abcs.JSONObjectValue]" in result

View File

@ -1,9 +1,12 @@
from pytest import mark, raises
from .flag import Flag
from .parameter import Parameter, TypeDefNotImplemented
from .parameter_type import FlagsParameter, NormalParameter
from .ty import Type
from telethon_generator.tl_parser import (
Flag,
FlagsParameter,
NormalParameter,
Parameter,
Type,
TypeDefNotImplemented,
)
@mark.parametrize("param", [":noname", "notype:", ":"])

View File

@ -1,18 +1,17 @@
from pytest import raises
from .tl_iterator import FunctionDef, TypeDef, iterate
from telethon_generator.tl_parser import FunctionDef, TypeDef, parse_tl_file
def test_parse_bad_separator() -> None:
with raises(ValueError) as e:
for _ in iterate("---foo---"):
for _ in parse_tl_file("---foo---"):
pass
e.match("bad separator")
def test_parse_file() -> None:
items = list(
iterate(
parse_tl_file(
"""
// leading; comment
first#1 = t; // inline comment

View File

@ -1,8 +1,7 @@
from typing import Optional
from pytest import mark, raises
from .ty import Type
from telethon_generator.tl_parser import Type
def test_empty_simple() -> None:

View File

@ -1,4 +1,4 @@
from .utils import infer_id, remove_tl_comments
from telethon_generator._impl.tl_parser.utils import infer_id, remove_tl_comments
def test_remove_comments_noop() -> None:

View File

@ -1,20 +0,0 @@
from ._impl.tl.definition import Definition
from ._impl.tl.flag import Flag
from ._impl.tl.parameter import Parameter, TypeDefNotImplemented
from ._impl.tl.parameter_type import FlagsParameter, NormalParameter
from ._impl.tl.ty import Type
from ._impl.tl_iterator import FunctionDef, TypeDef
from ._impl.tl_iterator import iterate as parse_tl_file
__all__ = [
"Definition",
"Flag",
"Parameter",
"TypeDefNotImplemented",
"FlagsParameter",
"NormalParameter",
"Type",
"FunctionDef",
"TypeDef",
"parse_tl_file",
]