diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 1a4684e5..ea497367 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -1,6 +1,7 @@ import re from collections.abc import Iterable from functools import partial +from typing import Type from graphql_relay import connection_from_array @@ -8,7 +9,28 @@ from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String from ..types.field import Field from ..types.objecttype import ObjectType, ObjectTypeOptions from ..utils.thenables import maybe_thenable -from .node import is_node +from .node import is_node, AbstractNode + + +def get_edge_class( + connection_class: Type["Connection"], _node: Type[AbstractNode], base_name: str +): + edge_class = getattr(connection_class, "Edge", None) + + class EdgeBase: + node = Field(_node, description="The item at the end of the edge") + cursor = String(required=True, description="A cursor for use in pagination") + + class EdgeMeta: + description = f"A Relay edge containing a `{base_name}` and its cursor." + + edge_name = f"{base_name}Edge" + + edge_bases = [edge_class, EdgeBase] if edge_class else [EdgeBase] + if not isinstance(edge_class, ObjectType): + edge_bases = [*edge_bases, ObjectType] + + return type(edge_name, tuple(edge_bases), {"Meta": EdgeMeta}) class PageInfo(ObjectType): @@ -61,8 +83,9 @@ class Connection(ObjectType): abstract = True @classmethod - def __init_subclass_with_meta__(cls, node=None, name=None, **options): - _meta = ConnectionOptions(cls) + def __init_subclass_with_meta__(cls, node=None, name=None, _meta=None, **options): + if not _meta: + _meta = ConnectionOptions(cls) assert node, f"You have to provide a node in {cls.__name__}.Meta" assert isinstance(node, NonNull) or issubclass( node, (Scalar, Enum, ObjectType, Interface, Union, NonNull) @@ -72,39 +95,29 @@ class Connection(ObjectType): if not name: name = f"{base_name}Connection" - edge_class = getattr(cls, "Edge", None) - _node = node - - class EdgeBase: - node = Field(_node, description="The item at the end of the edge") - cursor = String(required=True, description="A cursor for use in pagination") - - class EdgeMeta: - description = f"A Relay edge containing a `{base_name}` and its cursor." - - edge_name = f"{base_name}Edge" - if edge_class: - edge_bases = (edge_class, EdgeBase, ObjectType) - else: - edge_bases = (EdgeBase, ObjectType) - - edge = type(edge_name, edge_bases, {"Meta": EdgeMeta}) - cls.Edge = edge - options["name"] = name + _meta.node = node - _meta.fields = { - "page_info": Field( + + if not _meta.fields: + _meta.fields = {} + + if "page_info" not in _meta.fields: + _meta.fields["page_info"] = Field( PageInfo, name="pageInfo", required=True, description="Pagination data for this connection.", - ), - "edges": Field( - NonNull(List(edge)), + ) + + if "edges" not in _meta.fields: + edge_class = get_edge_class(cls, node, base_name) # type: ignore + cls.Edge = edge_class + _meta.fields["edges"] = Field( + NonNull(List(edge_class)), description="Contains the nodes in this connection.", - ), - } + ) + return super(Connection, cls).__init_subclass_with_meta__( _meta=_meta, **options ) diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index 4015f4b4..d45eea96 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -1,7 +1,15 @@ +import re + from pytest import raises from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String -from ..connection import Connection, ConnectionField, PageInfo +from ..connection import ( + Connection, + ConnectionField, + PageInfo, + ConnectionOptions, + get_edge_class, +) from ..node import Node @@ -51,6 +59,111 @@ def test_connection_inherit_abstracttype(): assert list(fields) == ["page_info", "edges", "extra"] +def test_connection_extra_abstract_fields(): + class ConnectionWithNodes(Connection): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, node=None, name=None, **options): + _meta = ConnectionOptions(cls) + + _meta.fields = { + "nodes": Field( + NonNull(List(node)), + description="Contains all the nodes in this connection.", + ), + } + + return super(ConnectionWithNodes, cls).__init_subclass_with_meta__( + node=node, name=name, _meta=_meta, **options + ) + + class MyObjectConnection(ConnectionWithNodes): + class Meta: + node = MyObject + + class Edge: + other = String() + + assert MyObjectConnection._meta.name == "MyObjectConnection" + fields = MyObjectConnection._meta.fields + assert list(fields) == ["nodes", "page_info", "edges"] + edge_field = fields["edges"] + pageinfo_field = fields["page_info"] + nodes_field = fields["nodes"] + + assert isinstance(edge_field, Field) + assert isinstance(edge_field.type, NonNull) + assert isinstance(edge_field.type.of_type, List) + assert edge_field.type.of_type.of_type == MyObjectConnection.Edge + + assert isinstance(pageinfo_field, Field) + assert isinstance(pageinfo_field.type, NonNull) + assert pageinfo_field.type.of_type == PageInfo + + assert isinstance(nodes_field, Field) + assert isinstance(nodes_field.type, NonNull) + assert isinstance(nodes_field.type.of_type, List) + assert nodes_field.type.of_type.of_type == MyObject + + +def test_connection_override_fields(): + class ConnectionWithNodes(Connection): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, node=None, name=None, **options): + _meta = ConnectionOptions(cls) + base_name = ( + re.sub("Connection$", "", name or cls.__name__) or node._meta.name + ) + + edge_class = get_edge_class(cls, node, base_name) + + _meta.fields = { + "page_info": Field( + NonNull( + PageInfo, + name="pageInfo", + required=True, + description="Pagination data for this connection.", + ) + ), + "edges": Field( + NonNull(List(NonNull(edge_class))), + description="Contains the nodes in this connection.", + ), + } + + return super(ConnectionWithNodes, cls).__init_subclass_with_meta__( + node=node, name=name, _meta=_meta, **options + ) + + class MyObjectConnection(ConnectionWithNodes): + class Meta: + node = MyObject + + assert MyObjectConnection._meta.name == "MyObjectConnection" + fields = MyObjectConnection._meta.fields + assert list(fields) == ["page_info", "edges"] + edge_field = fields["edges"] + pageinfo_field = fields["page_info"] + + assert isinstance(edge_field, Field) + assert isinstance(edge_field.type, NonNull) + assert isinstance(edge_field.type.of_type, List) + assert isinstance(edge_field.type.of_type.of_type, NonNull) + + assert edge_field.type.of_type.of_type.of_type.__name__ == "MyObjectEdge" + + # This page info is NonNull + assert isinstance(pageinfo_field, Field) + assert isinstance(edge_field.type, NonNull) + assert pageinfo_field.type.of_type == PageInfo + + def test_connection_name(): custom_name = "MyObjectCustomNameConnection"