Use Sequence as input in generated functions

This commit is contained in:
Lonami Exo 2023-11-07 19:40:09 +01:00
parent 6047c689ca
commit 4cc6ecc39b
2 changed files with 16 additions and 7 deletions

View File

@ -94,7 +94,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
if type_path not in fs: if type_path not in fs:
writer.write("import struct") writer.write("import struct")
writer.write("from typing import List, Optional, Self") writer.write("from typing import List, Optional, Self, Sequence")
writer.write("from .. import abcs") writer.write("from .. import abcs")
writer.write("from ..core import Reader, Serializable, serialize_bytes_to") writer.write("from ..core import Reader, Serializable, serialize_bytes_to")
@ -118,7 +118,8 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def __init__() # def __init__()
if property_params: if property_params:
params = "".join( params = "".join(
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params f", {p.name}: {param_type_fmt(p.ty, immutable=False)}"
for p in property_params
) )
writer.write(f" def __init__(_s, *{params}) -> None:") writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params: for p in property_params:
@ -158,15 +159,20 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
if function_path not in fs: if function_path not in fs:
writer.write("import struct") writer.write("import struct")
writer.write("from typing import List, Optional, Self") writer.write("from typing import List, Optional, Self, Sequence")
writer.write("from .. import abcs") writer.write("from .. import abcs")
writer.write("from ..core import Request, serialize_bytes_to") writer.write("from ..core import Request, serialize_bytes_to")
# def name(params, ...) # def name(params, ...)
required_params = [p for p in functiondef.params if not is_computed(p.ty)] required_params = [p for p in functiondef.params if not is_computed(p.ty)]
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params) params = "".join(
f", {p.name}: {param_type_fmt(p.ty, immutable=True)}"
for p in required_params
)
star = "*" if params else "" star = "*" if params else ""
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None)) return_ty = param_type_fmt(
NormalParameter(ty=functiondef.ty, flag=None), immutable=False
)
writer.write( writer.write(
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:" f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
) )

View File

@ -86,7 +86,7 @@ def inner_type_fmt(ty: Type) -> str:
return f"abcs.{ns}{to_class_name(ty.name)}" return f"abcs.{ns}{to_class_name(ty.name)}"
def param_type_fmt(ty: BaseParameter) -> str: def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
if isinstance(ty, FlagsParameter): if isinstance(ty, FlagsParameter):
return "int" return "int"
elif not isinstance(ty, NormalParameter): elif not isinstance(ty, NormalParameter):
@ -104,7 +104,10 @@ def param_type_fmt(ty: BaseParameter) -> str:
res = "bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty) res = "bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
if ty.ty.generic_arg: if ty.ty.generic_arg:
res = f"List[{res}]" if immutable:
res = f"Sequence[{res}]"
else:
res = f"List[{res}]"
if ty.flag and ty.ty.name != "true": if ty.flag and ty.ty.name != "true":
res = f"Optional[{res}]" res = f"Optional[{res}]"