From a512071862bfb771986b3833fea4baaca1327e81 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 2 Mar 2023 22:58:34 +0100 Subject: [PATCH] chore: make pageinfo and edge types overridable --- graphene/relay/connection.py | 73 +++++++++++++------------ graphene/relay/tests/test_connection.py | 66 +++++++++++++++++++++- 2 files changed, 104 insertions(+), 35 deletions(-) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 8216842f..0b9ef61e 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -1,6 +1,7 @@ import re from collections.abc import Iterable from functools import partial +from typing import Type from graphql_relay import connection_from_array @@ -8,7 +9,28 @@ from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String from ..types.field import Field from ..types.objecttype import ObjectType, ObjectTypeOptions from ..utils.thenables import maybe_thenable -from .node import is_node +from .node import is_node, AbstractNode + + +def get_edge_class( + connection_class: Type["Connection"], _node: Type[AbstractNode], base_name: str +): + edge_class = getattr(connection_class, "Edge", None) + + class EdgeBase: + node = Field(_node, description="The item at the end of the edge") + cursor = String(required=True, description="A cursor for use in pagination") + + class EdgeMeta: + description = f"A Relay edge containing a `{base_name}` and its cursor." + + edge_name = f"{base_name}Edge" + if edge_class: + edge_bases = (edge_class, EdgeBase, ObjectType) + else: + edge_bases = (EdgeBase, ObjectType) + + return type(edge_name, edge_bases, {"Meta": EdgeMeta}) class PageInfo(ObjectType): @@ -73,25 +95,6 @@ class Connection(ObjectType): if not name: name = f"{base_name}Connection" - edge_class = getattr(cls, "Edge", None) - _node = node - - class EdgeBase: - node = Field(_node, description="The item at the end of the edge") - cursor = String(required=True, description="A cursor for use in pagination") - - class EdgeMeta: - description = f"A Relay edge containing a `{base_name}` and its cursor." - - edge_name = f"{base_name}Edge" - if edge_class: - edge_bases = (edge_class, EdgeBase, ObjectType) - else: - edge_bases = (EdgeBase, ObjectType) - - edge = type(edge_name, edge_bases, {"Meta": EdgeMeta}) - cls.Edge = edge - options["name"] = name _meta.node = node @@ -99,20 +102,22 @@ class Connection(ObjectType): if not _meta.fields: _meta.fields = {} - _meta.fields.update( - { - "page_info": Field( - PageInfo, - name="pageInfo", - required=True, - description="Pagination data for this connection.", - ), - "edges": Field( - NonNull(List(edge)), - description="Contains the nodes in this connection.", - ), - } - ) + if "page_info" not in _meta.fields: + _meta.fields["page_info"] = Field( + PageInfo, + name="pageInfo", + required=True, + description="Pagination data for this connection.", + ) + + if "edges" not in _meta.fields: + edge_class = get_edge_class(cls, node, base_name) # type: ignore + cls.Edge = edge_class + _meta.fields["edges"] = Field( + NonNull(List(edge_class)), + description="Contains the nodes in this connection.", + ) + return super(Connection, cls).__init_subclass_with_meta__( _meta=_meta, **options ) diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index c7d67e5b..d45eea96 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -1,7 +1,15 @@ +import re + from pytest import raises from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String -from ..connection import Connection, ConnectionField, PageInfo, ConnectionOptions +from ..connection import ( + Connection, + ConnectionField, + PageInfo, + ConnectionOptions, + get_edge_class, +) from ..node import Node @@ -100,6 +108,62 @@ def test_connection_extra_abstract_fields(): assert nodes_field.type.of_type.of_type == MyObject +def test_connection_override_fields(): + class ConnectionWithNodes(Connection): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, node=None, name=None, **options): + _meta = ConnectionOptions(cls) + base_name = ( + re.sub("Connection$", "", name or cls.__name__) or node._meta.name + ) + + edge_class = get_edge_class(cls, node, base_name) + + _meta.fields = { + "page_info": Field( + NonNull( + PageInfo, + name="pageInfo", + required=True, + description="Pagination data for this connection.", + ) + ), + "edges": Field( + NonNull(List(NonNull(edge_class))), + description="Contains the nodes in this connection.", + ), + } + + return super(ConnectionWithNodes, cls).__init_subclass_with_meta__( + node=node, name=name, _meta=_meta, **options + ) + + class MyObjectConnection(ConnectionWithNodes): + class Meta: + node = MyObject + + assert MyObjectConnection._meta.name == "MyObjectConnection" + fields = MyObjectConnection._meta.fields + assert list(fields) == ["page_info", "edges"] + edge_field = fields["edges"] + pageinfo_field = fields["page_info"] + + assert isinstance(edge_field, Field) + assert isinstance(edge_field.type, NonNull) + assert isinstance(edge_field.type.of_type, List) + assert isinstance(edge_field.type.of_type.of_type, NonNull) + + assert edge_field.type.of_type.of_type.of_type.__name__ == "MyObjectEdge" + + # This page info is NonNull + assert isinstance(pageinfo_field, Field) + assert isinstance(edge_field.type, NonNull) + assert pageinfo_field.type.of_type == PageInfo + + def test_connection_name(): custom_name = "MyObjectCustomNameConnection"