Telethon/generator/tests/generator_test.py

115 lines
3.7 KiB
Python
Raw Normal View History

2023-07-03 20:19:20 +03:00
from typing import List, Optional
from telethon_generator.codegen import FakeFs, generate
from telethon_generator.tl_parser import Definition, ParsedTl, parse_tl_file
2023-07-03 20:19:20 +03:00
def get_definitions(contents: str) -> List[Definition]:
return [defn for defn in parse_tl_file(contents) if not isinstance(defn, Exception)]
def gen_py_code(
*,
typedefs: Optional[List[Definition]] = None,
functiondefs: Optional[List[Definition]] = None,
) -> str:
fs = FakeFs()
generate(
fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])
)
generated = bytearray()
for path, data in fs._files.items():
if path.stem not in ("__init__", "layer"):
generated += f"# {path}\n".encode("utf-8")
generated += data
data += b"\n"
return str(generated, "utf-8")
def test_generic_functions_use_bytes_parameters() -> None:
definitions = get_definitions(
"invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;"
)
result = gen_py_code(functiondefs=definitions)
assert "invoke_with_layer" in result
assert "query: bytes" in result
assert "buffer += query" in result
def test_recursive_direct() -> None:
definitions = get_definitions("textBold#6724abc4 text:RichText = RichText;")
result = gen_py_code(typedefs=definitions)
assert "text: abcs.RichText" in result
assert "read_serializable" in result
assert "write_boxed_to" in result
def test_recursive_indirect() -> None:
definitions = get_definitions(
"""
messageExtendedMedia#ee479c64 media:MessageMedia = MessageExtendedMedia;
messageMediaInvoice#f6a548d3 flags:# extended_media:flags.4?MessageExtendedMedia = MessageMedia;
"""
)
result = gen_py_code(typedefs=definitions)
assert "media: abcs.MessageMedia" in result
assert "extended_media: Optional[abcs.MessageExtendedMedia])" in result
assert "write_boxed_to" in result
assert "._write_to" not in result
assert "read_serializable" in result
def test_recursive_no_hang() -> None:
definitions = get_definitions(
"""
inputUserFromMessage#1da448e2 peer:InputPeer msg_id:int user_id:long = InputUser;
inputPeerUserFromMessage#a87b0a1c peer:InputPeer msg_id:int user_id:long = InputPeer;
"""
)
gen_py_code(typedefs=definitions)
def test_recursive_vec() -> None:
definitions = get_definitions(
"""
jsonObjectValue#c0de1bd9 key:string value:JSONValue = JSONObjectValue;
jsonArray#f7444763 value:Vector<JSONValue> = JSONValue;
jsonObject#99c1d49d value:Vector<JSONObjectValue> = JSONValue;
"""
)
result = gen_py_code(typedefs=definitions)
2023-07-09 22:16:55 +03:00
assert "value: List[abcs.JsonObjectValue]" in result
def test_object_blob_special_case() -> None:
definitions = get_definitions(
"""
rpc_result#f35c6d01 req_msg_id:long result:Object = RpcResult;
"""
)
result = gen_py_code(typedefs=definitions)
assert "reader.read_remaining()" in result
def test_object_blob_with_prefix_special_case() -> None:
definitions = get_definitions(
"""
message msg_id:long seqno:int bytes:int body:Object = Message;
"""
)
result = gen_py_code(typedefs=definitions)
assert "reader.read(_bytes)" in result
2023-08-30 14:44:42 +03:00
def test_bool_mapped_from_int() -> None:
definitions = get_definitions(
"""
contact#145ade0b user_id:long mutual:Bool = Contact;
"""
)
result = gen_py_code(typedefs=definitions)
assert "_mutual in (0x997275b5, 0xbc799737)" in result
assert "=_mutual == 0x997275b5" in result
assert "0x997275b5 if self.mutual else 0xbc799737" in result