Use Sequence as input in generated functions

This commit is contained in:
Lonami Exo 2023-11-07 19:40:09 +01:00
parent 6047c689ca
commit 4cc6ecc39b
2 changed files with 16 additions and 7 deletions

View File

@ -94,7 +94,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
if type_path not in fs:
writer.write("import struct")
writer.write("from typing import List, Optional, Self")
writer.write("from typing import List, Optional, Self, Sequence")
writer.write("from .. import abcs")
writer.write("from ..core import Reader, Serializable, serialize_bytes_to")
@ -118,7 +118,8 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def __init__()
if property_params:
params = "".join(
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
f", {p.name}: {param_type_fmt(p.ty, immutable=False)}"
for p in property_params
)
writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params:
@ -158,15 +159,20 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
if function_path not in fs:
writer.write("import struct")
writer.write("from typing import List, Optional, Self")
writer.write("from typing import List, Optional, Self, Sequence")
writer.write("from .. import abcs")
writer.write("from ..core import Request, serialize_bytes_to")
# def name(params, ...)
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
params = "".join(
f", {p.name}: {param_type_fmt(p.ty, immutable=True)}"
for p in required_params
)
star = "*" if params else ""
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
return_ty = param_type_fmt(
NormalParameter(ty=functiondef.ty, flag=None), immutable=False
)
writer.write(
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
)

View File

@ -86,7 +86,7 @@ def inner_type_fmt(ty: Type) -> str:
return f"abcs.{ns}{to_class_name(ty.name)}"
def param_type_fmt(ty: BaseParameter) -> str:
def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
if isinstance(ty, FlagsParameter):
return "int"
elif not isinstance(ty, NormalParameter):
@ -104,6 +104,9 @@ def param_type_fmt(ty: BaseParameter) -> str:
res = "bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
if ty.ty.generic_arg:
if immutable:
res = f"Sequence[{res}]"
else:
res = f"List[{res}]"
if ty.flag and ty.ty.name != "true":