types: add option for strict connection types

This commit is contained in:
shrouxm 2023-05-04 13:58:36 -07:00
parent 57cbef6666
commit 3943353e66

View File

@ -13,12 +13,12 @@ from .node import is_node, AbstractNode
def get_edge_class( def get_edge_class(
connection_class: Type["Connection"], _node: Type[AbstractNode], base_name: str connection_class: Type["Connection"], _node: Type[AbstractNode], base_name: str, strict_types: bool = False
): ):
edge_class = getattr(connection_class, "Edge", None) edge_class = getattr(connection_class, "Edge", None)
class EdgeBase: class EdgeBase:
node = Field(_node, description="The item at the end of the edge") node = Field(NonNull(_node) if strict_types else _node, description="The item at the end of the edge")
cursor = String(required=True, description="A cursor for use in pagination") cursor = String(required=True, description="A cursor for use in pagination")
class EdgeMeta: class EdgeMeta:
@ -83,7 +83,7 @@ class Connection(ObjectType):
abstract = True abstract = True
@classmethod @classmethod
def __init_subclass_with_meta__(cls, node=None, name=None, _meta=None, **options): def __init_subclass_with_meta__(cls, node=None, name=None, strict_types=False, _meta=None, **options):
if not _meta: if not _meta:
_meta = ConnectionOptions(cls) _meta = ConnectionOptions(cls)
assert node, f"You have to provide a node in {cls.__name__}.Meta" assert node, f"You have to provide a node in {cls.__name__}.Meta"
@ -111,10 +111,10 @@ class Connection(ObjectType):
) )
if "edges" not in _meta.fields: if "edges" not in _meta.fields:
edge_class = get_edge_class(cls, node, base_name) # type: ignore edge_class = get_edge_class(cls, node, base_name, strict_types) # type: ignore
cls.Edge = edge_class cls.Edge = edge_class
_meta.fields["edges"] = Field( _meta.fields["edges"] = Field(
NonNull(List(edge_class)), NonNull(List(NonNull(edge_class) if strict_types else edge_class)),
description="Contains the nodes in this connection.", description="Contains the nodes in this connection.",
) )