Port tl_parser from grammers

This commit is contained in:
Lonami Exo 2023-06-12 23:05:53 +02:00
parent fb41cc0546
commit fc6984d423
16 changed files with 822 additions and 0 deletions

4
check.sh Normal file
View File

@ -0,0 +1,4 @@
isort .
black .
mypy --strict .
pytest .

View File

@ -0,0 +1,20 @@
from ._impl.tl.definition import Definition
from ._impl.tl.flag import Flag
from ._impl.tl.parameter import Parameter, TypeDefNotImplemented
from ._impl.tl.parameter_type import FlagsParameter, NormalParameter
from ._impl.tl.ty import Type
from ._impl.tl_iterator import FunctionDef, TypeDef
from ._impl.tl_iterator import iterate as parse_tl_file
__all__ = [
"Definition",
"Flag",
"Parameter",
"TypeDefNotImplemented",
"FlagsParameter",
"NormalParameter",
"Type",
"FunctionDef",
"TypeDef",
"parse_tl_file",
]

View File

@ -0,0 +1,118 @@
from dataclasses import dataclass
from typing import List, Self, Set
from ..utils import infer_id
from .parameter import Parameter, TypeDefNotImplemented
from .parameter_type import FlagsParameter, NormalParameter
from .ty import Type
@dataclass
class Definition:
namespace: List[str]
name: str
id: int
params: List[Parameter]
ty: Type
@classmethod
def from_str(cls, definition: str) -> Self:
if not definition or definition.isspace():
raise ValueError("empty")
parts = definition.split("=")
if len(parts) < 2:
raise ValueError("missing type")
left, ty_str, *_ = map(str.strip, parts)
try:
ty = Type.from_str(ty_str)
except ValueError as e:
if e.args[0] == "empty":
raise ValueError("missing type")
else:
raise
if (pos := left.find(" ")) != -1:
name, middle = left[:pos], left[pos:].strip()
else:
name, middle = left.strip(), ""
parts = name.split("#")
if len(parts) < 2:
name, id_str = parts[0], None
else:
name, id_str, *_ = parts
namespace = name.split(".")
if not all(namespace):
raise ValueError("missing name")
name = namespace.pop()
if id_str is None:
id = infer_id(definition)
else:
try:
id = int(id_str, 16)
except ValueError:
raise ValueError("invalid id")
type_defs: List[str] = []
flag_defs = []
params = []
for param_str in middle.split():
try:
param = Parameter.from_str(param_str)
except TypeDefNotImplemented as e:
type_defs.append(e.name)
continue
if isinstance(param.ty, FlagsParameter):
flag_defs.append(param.name)
elif not isinstance(param.ty, NormalParameter):
raise NotImplementedError
elif param.ty.ty.generic_ref and param.ty.ty.name not in type_defs:
raise ValueError("missing def")
elif param.ty.flag and param.ty.flag.name not in flag_defs:
raise ValueError("missing def")
params.append(param)
if ty.name in type_defs:
ty.generic_ref = True
return cls(
namespace=namespace,
name=name,
id=id,
params=params,
ty=ty,
)
@property
def full_name(self) -> str:
ns = ".".join(self.namespace) + "." if self.namespace else ""
return f"{ns}{self.name}"
def __str__(self) -> str:
res = ""
for ns in self.namespace:
res += f"{ns}."
res += f"{self.name}#{self.id:x}"
def_set: Set[str] = set()
for param in self.params:
if isinstance(param.ty, NormalParameter):
def_set.update(param.ty.ty.find_generic_refs())
type_defs = list(sorted(def_set))
for type_def in type_defs:
res += f" {{{type_def}:Type}}"
for param in self.params:
res += f" {param}"
res += f" = {self.ty}"
return res

View File

@ -0,0 +1,198 @@
from pytest import mark, raises
from .definition import Definition
from .flag import Flag
from .parameter import Parameter
from .parameter_type import FlagsParameter, NormalParameter
from .ty import Type
def test_parse_empty_def() -> None:
with raises(ValueError) as e:
Definition.from_str("")
e.match("empty")
@mark.parametrize("defn", ["foo#bar = baz", "foo#? = baz", "foo# = baz"])
def test_parse_bad_id(defn: str) -> None:
with raises(ValueError) as e:
Definition.from_str(defn)
e.match("invalid id")
def test_parse_no_name() -> None:
with raises(ValueError) as e:
Definition.from_str(" = foo")
e.match("missing name")
@mark.parametrize("defn", ["foo", "foo ="])
def test_parse_no_type(defn: str) -> None:
with raises(ValueError) as e:
Definition.from_str(defn)
e.match("missing type")
def test_parse_unimplemented() -> None:
with raises(ValueError) as e:
Definition.from_str("int ? = Int")
e.match("not implemented")
@mark.parametrize(
("defn", "id"),
[
(
"rpc_answer_dropped msg_id:long seq_no:int bytes:int = RpcDropAnswer",
0xA43AD8B7,
),
(
"rpc_answer_dropped#123456 msg_id:long seq_no:int bytes:int = RpcDropAnswer",
0x123456,
),
],
)
def test_parse_override_id(defn: str, id: int) -> None:
assert Definition.from_str(defn).id == id
def test_parse_valid_definition() -> None:
defn = Definition.from_str("a#1=d")
assert defn.name == "a"
assert defn.id == 1
assert len(defn.params) == 0
assert defn.ty == Type(
namespace=[],
name="d",
bare=True,
generic_ref=False,
generic_arg=None,
)
defn = Definition.from_str("a=d<e>")
assert defn.name == "a"
assert defn.id != 0
assert len(defn.params) == 0
assert defn.ty == Type(
namespace=[],
name="d",
bare=True,
generic_ref=False,
generic_arg=Type.from_str("e"),
)
defn = Definition.from_str("a b:c = d")
assert defn.name == "a"
assert defn.id != 0
assert len(defn.params) == 1
assert defn.ty == Type(
namespace=[],
name="d",
bare=True,
generic_ref=False,
generic_arg=None,
)
defn = Definition.from_str("a#1 {b:Type} c:!b = d")
assert defn.name, "a"
assert defn.id, 1
assert len(defn.params), 1
assert isinstance(defn.params[0].ty, NormalParameter)
assert defn.params[0].ty.ty.generic_ref
assert defn.ty == Type(
namespace=[],
name="d",
bare=True,
generic_ref=False,
generic_arg=None,
)
def test_parse_multiline_definition() -> None:
defn = """
first#1 lol:param
= t;
"""
assert Definition.from_str(defn).id, 1
defn = """
second#2
lol:String
= t;
"""
assert Definition.from_str(defn).id, 2
defn = """
third#3
lol:String
=
t;
"""
assert Definition.from_str(defn).id, 3
def test_parse_complete() -> None:
defn = "ns1.name#123 {X:Type} flags:# pname:flags.10?ns2.Vector<!X> = ns3.Type"
assert Definition.from_str(defn) == Definition(
namespace=["ns1"],
name="name",
id=0x123,
params=[
Parameter(
name="flags",
ty=FlagsParameter(),
),
Parameter(
name="pname",
ty=NormalParameter(
ty=Type(
namespace=["ns2"],
name="Vector",
bare=False,
generic_ref=False,
generic_arg=Type(
namespace=[],
name="X",
bare=False,
generic_ref=True,
generic_arg=None,
),
),
flag=Flag(name="flags", index=10),
),
),
],
ty=Type(
namespace=["ns3"],
name="Type",
bare=False,
generic_ref=False,
generic_arg=None,
),
)
@mark.parametrize(
"defn",
[
"name param:!X = Type",
"name {X:Type} param:!Y = Type",
"name param:flags.0?true = Type",
"name foo:# param:flags.0?true = Type",
],
)
def test_parse_missing_def(defn: str) -> None:
with raises(ValueError) as e:
Definition.from_str(defn)
e.match("missing def")
def test_test_to_string() -> None:
defn = "ns1.name#123 {X:Type} flags:# pname:flags.10?ns2.Vector<!X> = ns3.Type"
assert str(Definition.from_str(defn)), defn

View File

@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Self
@dataclass
class Flag:
name: str
index: int
@classmethod
def from_str(cls, ty: str) -> Self:
if (dot_pos := ty.find(".")) != -1:
try:
index = int(ty[dot_pos + 1 :])
except ValueError:
raise ValueError("invalid flag")
else:
return cls(name=ty[:dot_pos], index=index)
else:
raise ValueError("invalid flag")
def __str__(self) -> str:
return f"{self.name}.{self.index}"

View File

@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Self
from .parameter_type import BaseParameter
class TypeDefNotImplemented(NotImplementedError):
def __init__(self, name: str):
super().__init__(f"typedef not implemented: {name}")
self.name = name
@dataclass
class Parameter:
name: str
ty: BaseParameter
@classmethod
def from_str(cls, param: str) -> Self:
if param.startswith("{"):
if param.endswith(":Type}"):
raise TypeDefNotImplemented(param[1 : param.index(":")])
else:
raise ValueError("missing def")
parts = param.split(":")
if not parts:
raise ValueError("empty")
elif len(parts) == 1:
raise ValueError("not implemented")
else:
name, ty, *_ = parts
if not name:
raise ValueError("empty")
return cls(name=name, ty=BaseParameter.from_str(ty))
def __str__(self) -> str:
return f"{self.name}:{self.ty}"

View File

@ -0,0 +1,97 @@
from pytest import mark, raises
from .flag import Flag
from .parameter import Parameter, TypeDefNotImplemented
from .parameter_type import FlagsParameter, NormalParameter
from .ty import Type
@mark.parametrize("param", [":noname", "notype:", ":"])
def test_empty_param(param: str) -> None:
with raises(ValueError) as e:
Parameter.from_str(param)
e.match("empty")
@mark.parametrize("param", ["", "no colon", "colonless"])
def test_unknown_param(param: str) -> None:
with raises(ValueError) as e:
Parameter.from_str(param)
e.match("not implemented")
@mark.parametrize("param", ["foo:bar?", "foo:?bar", "foo:bar?baz", "foo:bar.baz?qux"])
def test_bad_flags(param: str) -> None:
with raises(ValueError) as e:
Parameter.from_str(param)
e.match("invalid flag")
@mark.parametrize("param", ["foo:<bar", "foo:bar<"])
def test_bad_generics(param: str) -> None:
with raises(ValueError) as e:
Parameter.from_str(param)
e.match("invalid generic")
def test_type_def_param() -> None:
with raises(TypeDefNotImplemented) as e:
Parameter.from_str("{a:Type}")
e.match("typedef not implemented: a")
def test_unknown_def_param() -> None:
with raises(ValueError) as e:
Parameter.from_str("{a:foo}")
e.match("missing def")
def test_valid_param() -> None:
assert Parameter.from_str("foo:#") == Parameter(name="foo", ty=FlagsParameter())
assert Parameter.from_str("foo:!bar") == Parameter(
name="foo",
ty=NormalParameter(
ty=Type(
namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None
),
flag=None,
),
)
assert Parameter.from_str("foo:bar.1?baz") == Parameter(
name="foo",
ty=NormalParameter(
ty=Type(
namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None
),
flag=Flag(
name="bar",
index=1,
),
),
)
assert Parameter.from_str("foo:bar<baz>") == Parameter(
name="foo",
ty=NormalParameter(
ty=Type(
namespace=[],
name="bar",
bare=True,
generic_ref=False,
generic_arg=Type.from_str("baz"),
),
flag=None,
),
)
assert Parameter.from_str("foo:bar.1?baz<qux>") == Parameter(
name="foo",
ty=NormalParameter(
ty=Type(
namespace=[],
name="baz",
bare=True,
generic_ref=False,
generic_arg=Type.from_str("qux"),
),
flag=Flag(name="bar", index=1),
),
)

View File

@ -0,0 +1,39 @@
from abc import ABC
from dataclasses import dataclass
from typing import Optional, Union
from .flag import Flag
from .ty import Type
class BaseParameter(ABC):
@staticmethod
def from_str(ty: str) -> Union["FlagsParameter", "NormalParameter"]:
if not ty:
raise ValueError("empty")
if ty == "#":
return FlagsParameter()
if (pos := ty.find("?")) != -1:
ty, flag = ty[pos + 1 :], Flag.from_str(ty[:pos])
else:
flag = None
return NormalParameter(ty=Type.from_str(ty), flag=flag)
@dataclass
class FlagsParameter(BaseParameter):
def __str__(self) -> str:
return "#"
@dataclass
class NormalParameter(BaseParameter):
ty: Type
flag: Optional[Flag]
def __str__(self) -> str:
res = ""
if self.flag is not None:
res += f"{self.flag}?"
res += str(self.ty)
return res

View File

@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional, Self
@dataclass
class Type:
namespace: List[str]
name: str
bare: bool
generic_ref: bool
generic_arg: Optional[Self]
@classmethod
def from_str(cls, ty: str) -> Self:
stripped = ty.lstrip("!")
ty, generic_ref = stripped, stripped != ty
if (pos := ty.find("<")) != -1:
if not ty.endswith(">"):
raise ValueError("invalid generic")
ty, generic_arg = ty[:pos], Type.from_str(ty[pos + 1 : -1])
else:
generic_arg = None
namespace = ty.split(".")
if not all(namespace):
raise ValueError("empty")
name = namespace.pop()
bare = name[0].islower()
return cls(
namespace=namespace,
name=name,
bare=bare,
generic_ref=generic_ref,
generic_arg=generic_arg,
)
def __str__(self) -> str:
res = ""
for ns in self.namespace:
res += f"{ns}."
if self.generic_ref:
res += "!"
res += self.name
if self.generic_arg is not None:
res += f"<{self.generic_arg}>"
return res
def find_generic_refs(self) -> Iterator[str]:
if self.generic_ref:
yield self.name
if self.generic_arg is not None:
yield from self.generic_arg.find_generic_refs()

View File

@ -0,0 +1,91 @@
from typing import Optional
from pytest import mark, raises
from .ty import Type
def test_empty_simple() -> None:
with raises(ValueError) as e:
Type.from_str("")
e.match("empty")
def test_simple() -> None:
assert Type.from_str("foo") == Type(
namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None
)
@mark.parametrize("ty", [".", "..", ".foo", "foo.", "foo..foo", ".foo."])
def test_check_empty_namespaced(ty: str) -> None:
with raises(ValueError) as e:
Type.from_str(ty)
e.match("empty")
def test_check_namespaced() -> None:
assert Type.from_str("foo.bar.baz") == Type(
namespace=["foo", "bar"],
name="baz",
bare=True,
generic_ref=False,
generic_arg=None,
)
@mark.parametrize(
"ty",
[
"foo",
"Foo.bar",
"!bar",
],
)
def test_bare(ty: str) -> None:
assert Type.from_str(ty).bare
@mark.parametrize(
"ty",
[
"Foo",
"Foo.Bar",
"!foo.Bar",
],
)
def test_bare_not(ty: str) -> None:
assert not Type.from_str(ty).bare
@mark.parametrize(
"ty",
[
"!f",
"!Foo",
"!X",
],
)
def test_generic_ref(ty: str) -> None:
assert Type.from_str(ty).generic_ref
def test_generic_ref_not() -> None:
assert not Type.from_str("f").generic_ref
@mark.parametrize(
("ty", "generic"),
[
("foo.bar", None),
("foo<bar>", "bar"),
("foo<bar.Baz>", "bar.Baz"),
("foo<!bar.Baz>", "!bar.Baz"),
("foo<bar<baz>>", "bar<baz>"),
],
)
def test_generic_arg(ty: str, generic: Optional[str]) -> None:
if generic is None:
assert Type.from_str(ty).generic_arg is None
else:
assert Type.from_str(ty).generic_arg == Type.from_str(generic)

View File

@ -0,0 +1,47 @@
from typing import Iterator, Type
from .tl.definition import Definition
from .utils import remove_tl_comments
DEFINITION_SEP = ";"
CATEGORY_MARKER = "---"
FUNCTIONS_SEP = f"{CATEGORY_MARKER}functions---"
TYPES_SEP = f"{CATEGORY_MARKER}types---"
class TypeDef(Definition):
pass
class FunctionDef(Definition):
pass
def iterate(contents: str) -> Iterator[TypeDef | FunctionDef | Exception]:
contents = remove_tl_comments(contents)
index = 0
cls: Type[TypeDef] | Type[FunctionDef] = TypeDef
while index < len(contents):
if (end := contents.find(DEFINITION_SEP, index)) == -1:
end = len(contents)
definition = contents[index:end].strip()
index = end + len(DEFINITION_SEP)
if not definition:
continue
if definition.startswith(CATEGORY_MARKER):
if definition.startswith(FUNCTIONS_SEP):
cls = FunctionDef
definition = definition[len(FUNCTIONS_SEP) :].strip()
elif definition.startswith(TYPES_SEP):
cls = TypeDef
definition = definition[len(FUNCTIONS_SEP) :].strip()
else:
raise ValueError("bad separator")
try:
yield cls.from_str(definition)
except Exception as e:
yield e

View File

@ -0,0 +1,29 @@
from pytest import raises
from .tl_iterator import FunctionDef, TypeDef, iterate
def test_parse_bad_separator() -> None:
with raises(ValueError) as e:
for _ in iterate("---foo---"):
pass
e.match("bad separator")
def test_parse_file() -> None:
items = list(
iterate(
"""
// leading; comment
first#1 = t; // inline comment
---functions---
second and bad;
third#3 = t;
// trailing comment
"""
)
)
assert len(items) == 3
assert isinstance(items[0], TypeDef) and items[0].id == 1
assert isinstance(items[1], ValueError)
assert isinstance(items[2], FunctionDef) and items[2].id == 3

View File

@ -0,0 +1,20 @@
import re
import zlib
def remove_tl_comments(contents: str) -> str:
return re.sub(r"//[^\n]*(?=\n)", "", contents)
def infer_id(definition: str) -> int:
representation = (
definition.replace(":bytes ", ": string")
.replace("?bytes ", "? string")
.replace("<", " ")
.replace(">", "")
.replace("{", "")
.replace("}", "")
)
representation = re.sub(r" \w+:flags\.\d+\?true", "", representation)
return zlib.crc32(representation.encode("ascii"))

View File

@ -0,0 +1,41 @@
from .utils import infer_id, remove_tl_comments
def test_remove_comments_noop() -> None:
data = "hello\nworld"
assert remove_tl_comments(data) == data
data = " \nhello\nworld\n "
assert remove_tl_comments(data) == data
def test_remove_comments_leading() -> None:
input = " // hello\n world "
expected = " \n world "
assert remove_tl_comments(input) == expected
def test_remove_comments_trailing() -> None:
input = " \nhello \n // world \n \n "
expected = " \nhello \n \n \n "
assert remove_tl_comments(input) == expected
def test_remove_comments_many() -> None:
input = "no\n//yes\nno\n//yes\nno\n"
expected = "no\n\nno\n\nno\n"
assert remove_tl_comments(input) == expected
def test_check_infer_id() -> None:
defn = "rpc_answer_dropped msg_id:long seq_no:int bytes:int = RpcDropAnswer"
assert infer_id(defn) == 0xA43AD8B7
defn = "msgs_ack msg_ids:Vector<long> = MsgsAck"
assert infer_id(defn) == 0x62D6B459
defn = "invokeAfterMsg {X:Type} msg_id:long query:!X = X"
assert infer_id(defn) == 0xCB9F372D
defn = "inputMessagesFilterPhoneCalls flags:# missed:flags.0?true = MessagesFilter"
assert infer_id(defn) == 0x80C99768