Improve codegen

Avoid on-import modification of classes.
This makes it possible to have multiple namespaces work together.

Implement equality on all generated types.
This enables support in tests as well feeling similar to dataclasses.

Make generated code constructors keyword-only.
This increases readability and reduces risk of breakage during upgrades.
This commit is contained in:
Lonami Exo 2023-07-08 12:15:11 +02:00
parent 7b707cfc6c
commit e74332de75
6 changed files with 67 additions and 12 deletions

View File

@ -1,5 +1,13 @@
Code generation:
```sh
pip install -e generator/
python -m telethon_generator.codegen api.tl telethon/src/_impl/tl
python -m telethon_generator.codegen mtproto.tl telethon/src/_impl/tl/mtproto
python -m telethon_generator.codegen api.tl client/src/telethon/_impl/tl
python -m telethon_generator.codegen mtproto.tl client/src/telethon/_impl/tl/mtproto
```
Formatting, type-checking and testing:
```
./check.sh
```

View File

@ -1,5 +1,5 @@
import struct
from typing import TYPE_CHECKING, Any, Type, TypeVar
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
if TYPE_CHECKING:
from .serializable import Serializable
@ -8,6 +8,26 @@ if TYPE_CHECKING:
T = TypeVar("T", bound="Serializable")
def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
# Lazy import because generate code depends on the Reader.
# After the first call, the class method is replaced with direct access.
if Reader._get_ty is _bootstrap_get_ty:
from ..layer import TYPE_MAPPING as API_TYPES
from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES
if API_TYPES.keys() & MTPROTO_TYPES.keys():
raise RuntimeError(
"generated api and mtproto schemas cannot have colliding constructor identifiers"
)
ALL_TYPES = API_TYPES | MTPROTO_TYPES
# Signatures don't fully match, but this is a private method
# and all previous uses are compatible with `dict.get`.
Reader._get_ty = ALL_TYPES.get # type: ignore [assignment]
return Reader._get_ty(constructor_id)
class Reader:
__slots__ = ("_buffer", "_pos", "_view")
@ -44,11 +64,7 @@ class Reader:
return data
@staticmethod
def _get_ty(_: int) -> Type["Serializable"]:
# Implementation replaced during import to prevent cycles,
# without the performance hit of having the import inside.
raise NotImplementedError
_get_ty = staticmethod(_bootstrap_get_ty)
def read_serializable(self, cls: Type[T]) -> T:
# Calls to this method likely need to ignore "type-abstract".

View File

@ -38,6 +38,13 @@ class Serializable(abc.ABC):
attrs = ", ".join(repr(getattr(self, attr)) for attr in self.__slots__)
return f"{self.__class__.__name__}({attrs})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return all(
getattr(self, attr) == getattr(other, attr) for attr in self.__slots__
)
def serialize_bytes_to(buffer: bytearray, data: bytes) -> None:
length = len(data)

View File

@ -0,0 +1,4 @@
from . import abcs, functions, types
from .layer import TYPE_MAPPING
__all__ = ["abcs", "functions", "types", "TYPE_MAPPING"]

View File

@ -1,5 +1,10 @@
import struct
from pytest import mark
from telethon._impl.tl.core import Reader
from telethon._impl.tl.core.serializable import Serializable
from telethon._impl.tl.mtproto.types import BadServerSalt
from telethon._impl.tl.types import GeoPoint
@mark.parametrize(
@ -24,3 +29,21 @@ sentence made it past!",
def test_string(string: str, prefix: bytes, suffix: bytes) -> None:
data = prefix + string.encode("ascii") + suffix
assert str(Reader(data).read_bytes(), "ascii") == string
@mark.parametrize(
"obj",
[
GeoPoint(long=12.34, lat=56.78, access_hash=123123, accuracy_radius=100),
BadServerSalt(
bad_msg_id=1234,
bad_msg_seqno=5678,
error_code=9876,
new_server_salt=5432,
),
],
)
def test_generated_object(obj: Serializable) -> None:
assert bytes(obj)[:4] == struct.pack("<I", obj.constructor_id())
assert type(obj)._read_from(Reader(bytes(obj)[4:])) == obj
assert Reader(bytes(obj)).read_serializable(type(obj)) == obj

View File

@ -115,7 +115,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
params = "".join(
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
)
writer.write(f" def __init__(_s{params}) -> None:")
writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params:
writer.write(f" _s.{p.name} = {p.name}")
@ -183,7 +183,4 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
for name in sorted(generated_type_names):
writer.write(f" types.{name},")
writer.write("))}")
writer.write(
"Reader._get_ty = TYPE_MAPPING.get # type: ignore [method-assign, assignment]"
)
writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING']")