mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-25 19:03:46 +03:00
Improve codegen
Avoid on-import modification of classes. This makes it possible to have multiple namespaces work together. Implement equality on all generated types. This enables support in tests as well feeling similar to dataclasses. Make generated code constructors keyword-only. This increases readability and reduces risk of breakage during upgrades.
This commit is contained in:
parent
7b707cfc6c
commit
e74332de75
|
@ -1,5 +1,13 @@
|
|||
Code generation:
|
||||
|
||||
```sh
|
||||
pip install -e generator/
|
||||
python -m telethon_generator.codegen api.tl telethon/src/_impl/tl
|
||||
python -m telethon_generator.codegen mtproto.tl telethon/src/_impl/tl/mtproto
|
||||
python -m telethon_generator.codegen api.tl client/src/telethon/_impl/tl
|
||||
python -m telethon_generator.codegen mtproto.tl client/src/telethon/_impl/tl/mtproto
|
||||
```
|
||||
|
||||
Formatting, type-checking and testing:
|
||||
|
||||
```
|
||||
./check.sh
|
||||
```
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .serializable import Serializable
|
||||
|
@ -8,6 +8,26 @@ if TYPE_CHECKING:
|
|||
T = TypeVar("T", bound="Serializable")
|
||||
|
||||
|
||||
def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
|
||||
# Lazy import because generate code depends on the Reader.
|
||||
# After the first call, the class method is replaced with direct access.
|
||||
if Reader._get_ty is _bootstrap_get_ty:
|
||||
from ..layer import TYPE_MAPPING as API_TYPES
|
||||
from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES
|
||||
|
||||
if API_TYPES.keys() & MTPROTO_TYPES.keys():
|
||||
raise RuntimeError(
|
||||
"generated api and mtproto schemas cannot have colliding constructor identifiers"
|
||||
)
|
||||
ALL_TYPES = API_TYPES | MTPROTO_TYPES
|
||||
|
||||
# Signatures don't fully match, but this is a private method
|
||||
# and all previous uses are compatible with `dict.get`.
|
||||
Reader._get_ty = ALL_TYPES.get # type: ignore [assignment]
|
||||
|
||||
return Reader._get_ty(constructor_id)
|
||||
|
||||
|
||||
class Reader:
|
||||
__slots__ = ("_buffer", "_pos", "_view")
|
||||
|
||||
|
@ -44,11 +64,7 @@ class Reader:
|
|||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _get_ty(_: int) -> Type["Serializable"]:
|
||||
# Implementation replaced during import to prevent cycles,
|
||||
# without the performance hit of having the import inside.
|
||||
raise NotImplementedError
|
||||
_get_ty = staticmethod(_bootstrap_get_ty)
|
||||
|
||||
def read_serializable(self, cls: Type[T]) -> T:
|
||||
# Calls to this method likely need to ignore "type-abstract".
|
||||
|
|
|
@ -38,6 +38,13 @@ class Serializable(abc.ABC):
|
|||
attrs = ", ".join(repr(getattr(self, attr)) for attr in self.__slots__)
|
||||
return f"{self.__class__.__name__}({attrs})"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return all(
|
||||
getattr(self, attr) == getattr(other, attr) for attr in self.__slots__
|
||||
)
|
||||
|
||||
|
||||
def serialize_bytes_to(buffer: bytearray, data: bytes) -> None:
|
||||
length = len(data)
|
||||
|
|
4
client/src/telethon/_impl/tl/mtproto/__init__.py
Normal file
4
client/src/telethon/_impl/tl/mtproto/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from . import abcs, functions, types
|
||||
from .layer import TYPE_MAPPING
|
||||
|
||||
__all__ = ["abcs", "functions", "types", "TYPE_MAPPING"]
|
|
@ -1,5 +1,10 @@
|
|||
import struct
|
||||
|
||||
from pytest import mark
|
||||
from telethon._impl.tl.core import Reader
|
||||
from telethon._impl.tl.core.serializable import Serializable
|
||||
from telethon._impl.tl.mtproto.types import BadServerSalt
|
||||
from telethon._impl.tl.types import GeoPoint
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
|
@ -24,3 +29,21 @@ sentence made it past!",
|
|||
def test_string(string: str, prefix: bytes, suffix: bytes) -> None:
|
||||
data = prefix + string.encode("ascii") + suffix
|
||||
assert str(Reader(data).read_bytes(), "ascii") == string
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
"obj",
|
||||
[
|
||||
GeoPoint(long=12.34, lat=56.78, access_hash=123123, accuracy_radius=100),
|
||||
BadServerSalt(
|
||||
bad_msg_id=1234,
|
||||
bad_msg_seqno=5678,
|
||||
error_code=9876,
|
||||
new_server_salt=5432,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_generated_object(obj: Serializable) -> None:
|
||||
assert bytes(obj)[:4] == struct.pack("<I", obj.constructor_id())
|
||||
assert type(obj)._read_from(Reader(bytes(obj)[4:])) == obj
|
||||
assert Reader(bytes(obj)).read_serializable(type(obj)) == obj
|
||||
|
|
|
@ -115,7 +115,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
params = "".join(
|
||||
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
|
||||
)
|
||||
writer.write(f" def __init__(_s{params}) -> None:")
|
||||
writer.write(f" def __init__(_s, *{params}) -> None:")
|
||||
for p in property_params:
|
||||
writer.write(f" _s.{p.name} = {p.name}")
|
||||
|
||||
|
@ -183,7 +183,4 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
for name in sorted(generated_type_names):
|
||||
writer.write(f" types.{name},")
|
||||
writer.write("))}")
|
||||
writer.write(
|
||||
"Reader._get_ty = TYPE_MAPPING.get # type: ignore [method-assign, assignment]"
|
||||
)
|
||||
writer.write(f"__all__ = ['LAYER', 'TYPE_MAPPING']")
|
||||
|
|
Loading…
Reference in New Issue
Block a user