diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 1a4684e5..7e03376d 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -75,12 +75,23 @@ class Connection(ObjectType): edge_class = getattr(cls, "Edge", None) _node = node - class EdgeBase: - node = Field(_node, description="The item at the end of the edge") - cursor = String(required=True, description="A cursor for use in pagination") - class EdgeMeta: description = f"A Relay edge containing a `{base_name}` and its cursor." + required = False + node_required = False + + if edge_class and hasattr(edge_class, "Meta"): + EdgeMeta = type( + f"{base_name}EdgeMeta", (getattr(edge_class, "Meta"), EdgeMeta), {} + ) + + class EdgeBase: + node = Field( + _node, + required=EdgeMeta.node_required, + description="The item at the end of the edge", + ) + cursor = String(required=True, description="A cursor for use in pagination") edge_name = f"{base_name}Edge" if edge_class: @@ -90,6 +101,7 @@ class Connection(ObjectType): edge = type(edge_name, edge_bases, {"Meta": EdgeMeta}) cls.Edge = edge + edge_field = NonNull(edge) if EdgeMeta.required else edge options["name"] = name _meta.node = node @@ -101,7 +113,7 @@ class Connection(ObjectType): description="Pagination data for this connection.", ), "edges": Field( - NonNull(List(edge)), + NonNull(List(edge_field)), description="Contains the nodes in this connection.", ), } diff --git a/graphene/tests/issues/test_968.py b/graphene/tests/issues/test_968.py new file mode 100644 index 00000000..0d399d44 --- /dev/null +++ b/graphene/tests/issues/test_968.py @@ -0,0 +1,90 @@ +from ...types import ObjectType, Field, Schema, String, NonNull +from ...relay import Connection, ConnectionField + + +def _get_query(_edge_required=False, _node_required=False): + class UserPhoto(ObjectType): + uri = String() + + class UserPhotosConnection(Connection): + class Meta: + node = UserPhoto + + class Edge: + class Meta: + required = _edge_required + node_required = _node_required + + class User(ObjectType): + name = String(required=True) + user_photos = ConnectionField(UserPhotosConnection) + + def resolve_user_photos(self, info): + return [UserPhoto(uri="user-1-uri")] + + class Query(ObjectType): + user = Field(User) + + def resolve_user(self, info): + return User(name="user-1") + + return Query + + +def _get_edge_field(schema): + return schema.query.user._type.user_photos._type._meta.fields[ + "edges" + ]._type.of_type.of_type + + +def test_required_nonnull_edge_only(): + """ + Test that elements in the edge are required + """ + schema = Schema(query=_get_query(True)) + + # Edge + assert isinstance( + _get_edge_field(schema), + NonNull, + ) + # Node + assert not isinstance( + _get_edge_field(schema).of_type.node._type, + NonNull, + ) + + +def test_required_nonnull_node_only(): + """ + Test that elements in the edge are required + """ + schema = Schema(query=_get_query(False, True)) + + # Edge + assert not isinstance( + _get_edge_field(schema), + NonNull, + ) + # Node + assert isinstance( + _get_edge_field(schema).node._type, + NonNull, + ) + + +def test_support_null_nodes(): + """ + Test that elements in the edge are required + """ + schema = Schema(query=_get_query()) + + assert not isinstance( + _get_edge_field(schema), + NonNull, + ) + + assert not isinstance( + _get_edge_field(schema).node._type, + NonNull, + )