diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index ea497367..cc7d2da0 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -13,12 +13,18 @@ from .node import is_node, AbstractNode def get_edge_class( - connection_class: Type["Connection"], _node: Type[AbstractNode], base_name: str + connection_class: Type["Connection"], + _node: Type[AbstractNode], + base_name: str, + strict_types: bool = False, ): edge_class = getattr(connection_class, "Edge", None) class EdgeBase: - node = Field(_node, description="The item at the end of the edge") + node = Field( + NonNull(_node) if strict_types else _node, + description="The item at the end of the edge", + ) cursor = String(required=True, description="A cursor for use in pagination") class EdgeMeta: @@ -83,7 +89,9 @@ class Connection(ObjectType): abstract = True @classmethod - def __init_subclass_with_meta__(cls, node=None, name=None, _meta=None, **options): + def __init_subclass_with_meta__( + cls, node=None, name=None, strict_types=False, _meta=None, **options + ): if not _meta: _meta = ConnectionOptions(cls) assert node, f"You have to provide a node in {cls.__name__}.Meta" @@ -111,10 +119,10 @@ class Connection(ObjectType): ) if "edges" not in _meta.fields: - edge_class = get_edge_class(cls, node, base_name) # type: ignore + edge_class = get_edge_class(cls, node, base_name, strict_types) # type: ignore cls.Edge = edge_class _meta.fields["edges"] = Field( - NonNull(List(edge_class)), + NonNull(List(NonNull(edge_class) if strict_types else edge_class)), description="Contains the nodes in this connection.", ) diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index d45eea96..9c8b89d1 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -299,3 +299,20 @@ def test_connectionfield_required(): executed = schema.execute("{ testConnection { edges { cursor } } }") assert not executed.errors assert executed.data == {"testConnection": {"edges": []}} + + +def test_connectionfield_strict_types(): + class MyObjectConnection(Connection): + class Meta: + node = MyObject + strict_types = True + + connection_field = ConnectionField(MyObjectConnection) + edges_field_type = connection_field.type._meta.fields["edges"].type + assert isinstance(edges_field_type, NonNull) + + edges_list_element_type = edges_field_type.of_type.of_type + assert isinstance(edges_list_element_type, NonNull) + + node_field = edges_list_element_type.of_type._meta.fields["node"] + assert isinstance(node_field.type, NonNull)