From d33e38a391ee99ae48a1f13d26915634a79b3447 Mon Sep 17 00:00:00 2001
From: Erik Wrede <erikwrede@users.noreply.github.com>
Date: Mon, 13 Mar 2023 21:23:28 +0100
Subject: [PATCH] chore: make relay type fields extendable (#1499)

---
 graphene/relay/connection.py            |  71 +++++++++------
 graphene/relay/tests/test_connection.py | 115 +++++++++++++++++++++++-
 2 files changed, 156 insertions(+), 30 deletions(-)

diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py
index 1a4684e5..ea497367 100644
--- a/graphene/relay/connection.py
+++ b/graphene/relay/connection.py
@@ -1,6 +1,7 @@
 import re
 from collections.abc import Iterable
 from functools import partial
+from typing import Type
 
 from graphql_relay import connection_from_array
 
@@ -8,7 +9,28 @@ from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String
 from ..types.field import Field
 from ..types.objecttype import ObjectType, ObjectTypeOptions
 from ..utils.thenables import maybe_thenable
-from .node import is_node
+from .node import is_node, AbstractNode
+
+
+def get_edge_class(
+    connection_class: Type["Connection"], _node: Type[AbstractNode], base_name: str
+):
+    edge_class = getattr(connection_class, "Edge", None)
+
+    class EdgeBase:
+        node = Field(_node, description="The item at the end of the edge")
+        cursor = String(required=True, description="A cursor for use in pagination")
+
+    class EdgeMeta:
+        description = f"A Relay edge containing a `{base_name}` and its cursor."
+
+    edge_name = f"{base_name}Edge"
+
+    edge_bases = [edge_class, EdgeBase] if edge_class else [EdgeBase]
+    if not isinstance(edge_class, ObjectType):
+        edge_bases = [*edge_bases, ObjectType]
+
+    return type(edge_name, tuple(edge_bases), {"Meta": EdgeMeta})
 
 
 class PageInfo(ObjectType):
@@ -61,8 +83,9 @@ class Connection(ObjectType):
         abstract = True
 
     @classmethod
-    def __init_subclass_with_meta__(cls, node=None, name=None, **options):
-        _meta = ConnectionOptions(cls)
+    def __init_subclass_with_meta__(cls, node=None, name=None, _meta=None, **options):
+        if not _meta:
+            _meta = ConnectionOptions(cls)
         assert node, f"You have to provide a node in {cls.__name__}.Meta"
         assert isinstance(node, NonNull) or issubclass(
             node, (Scalar, Enum, ObjectType, Interface, Union, NonNull)
@@ -72,39 +95,29 @@ class Connection(ObjectType):
         if not name:
             name = f"{base_name}Connection"
 
-        edge_class = getattr(cls, "Edge", None)
-        _node = node
-
-        class EdgeBase:
-            node = Field(_node, description="The item at the end of the edge")
-            cursor = String(required=True, description="A cursor for use in pagination")
-
-        class EdgeMeta:
-            description = f"A Relay edge containing a `{base_name}` and its cursor."
-
-        edge_name = f"{base_name}Edge"
-        if edge_class:
-            edge_bases = (edge_class, EdgeBase, ObjectType)
-        else:
-            edge_bases = (EdgeBase, ObjectType)
-
-        edge = type(edge_name, edge_bases, {"Meta": EdgeMeta})
-        cls.Edge = edge
-
         options["name"] = name
+
         _meta.node = node
-        _meta.fields = {
-            "page_info": Field(
+
+        if not _meta.fields:
+            _meta.fields = {}
+
+        if "page_info" not in _meta.fields:
+            _meta.fields["page_info"] = Field(
                 PageInfo,
                 name="pageInfo",
                 required=True,
                 description="Pagination data for this connection.",
-            ),
-            "edges": Field(
-                NonNull(List(edge)),
+            )
+
+        if "edges" not in _meta.fields:
+            edge_class = get_edge_class(cls, node, base_name)  # type: ignore
+            cls.Edge = edge_class
+            _meta.fields["edges"] = Field(
+                NonNull(List(edge_class)),
                 description="Contains the nodes in this connection.",
-            ),
-        }
+            )
+
         return super(Connection, cls).__init_subclass_with_meta__(
             _meta=_meta, **options
         )
diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py
index 4015f4b4..d45eea96 100644
--- a/graphene/relay/tests/test_connection.py
+++ b/graphene/relay/tests/test_connection.py
@@ -1,7 +1,15 @@
+import re
+
 from pytest import raises
 
 from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String
-from ..connection import Connection, ConnectionField, PageInfo
+from ..connection import (
+    Connection,
+    ConnectionField,
+    PageInfo,
+    ConnectionOptions,
+    get_edge_class,
+)
 from ..node import Node
 
 
@@ -51,6 +59,111 @@ def test_connection_inherit_abstracttype():
     assert list(fields) == ["page_info", "edges", "extra"]
 
 
+def test_connection_extra_abstract_fields():
+    class ConnectionWithNodes(Connection):
+        class Meta:
+            abstract = True
+
+        @classmethod
+        def __init_subclass_with_meta__(cls, node=None, name=None, **options):
+            _meta = ConnectionOptions(cls)
+
+            _meta.fields = {
+                "nodes": Field(
+                    NonNull(List(node)),
+                    description="Contains all the nodes in this connection.",
+                ),
+            }
+
+            return super(ConnectionWithNodes, cls).__init_subclass_with_meta__(
+                node=node, name=name, _meta=_meta, **options
+            )
+
+    class MyObjectConnection(ConnectionWithNodes):
+        class Meta:
+            node = MyObject
+
+        class Edge:
+            other = String()
+
+    assert MyObjectConnection._meta.name == "MyObjectConnection"
+    fields = MyObjectConnection._meta.fields
+    assert list(fields) == ["nodes", "page_info", "edges"]
+    edge_field = fields["edges"]
+    pageinfo_field = fields["page_info"]
+    nodes_field = fields["nodes"]
+
+    assert isinstance(edge_field, Field)
+    assert isinstance(edge_field.type, NonNull)
+    assert isinstance(edge_field.type.of_type, List)
+    assert edge_field.type.of_type.of_type == MyObjectConnection.Edge
+
+    assert isinstance(pageinfo_field, Field)
+    assert isinstance(pageinfo_field.type, NonNull)
+    assert pageinfo_field.type.of_type == PageInfo
+
+    assert isinstance(nodes_field, Field)
+    assert isinstance(nodes_field.type, NonNull)
+    assert isinstance(nodes_field.type.of_type, List)
+    assert nodes_field.type.of_type.of_type == MyObject
+
+
+def test_connection_override_fields():
+    class ConnectionWithNodes(Connection):
+        class Meta:
+            abstract = True
+
+        @classmethod
+        def __init_subclass_with_meta__(cls, node=None, name=None, **options):
+            _meta = ConnectionOptions(cls)
+            base_name = (
+                re.sub("Connection$", "", name or cls.__name__) or node._meta.name
+            )
+
+            edge_class = get_edge_class(cls, node, base_name)
+
+            _meta.fields = {
+                "page_info": Field(
+                    NonNull(
+                        PageInfo,
+                        name="pageInfo",
+                        required=True,
+                        description="Pagination data for this connection.",
+                    )
+                ),
+                "edges": Field(
+                    NonNull(List(NonNull(edge_class))),
+                    description="Contains the nodes in this connection.",
+                ),
+            }
+
+            return super(ConnectionWithNodes, cls).__init_subclass_with_meta__(
+                node=node, name=name, _meta=_meta, **options
+            )
+
+    class MyObjectConnection(ConnectionWithNodes):
+        class Meta:
+            node = MyObject
+
+    assert MyObjectConnection._meta.name == "MyObjectConnection"
+    fields = MyObjectConnection._meta.fields
+    assert list(fields) == ["page_info", "edges"]
+    edge_field = fields["edges"]
+    pageinfo_field = fields["page_info"]
+
+    assert isinstance(edge_field, Field)
+    assert isinstance(edge_field.type, NonNull)
+    assert isinstance(edge_field.type.of_type, List)
+    assert isinstance(edge_field.type.of_type.of_type, NonNull)
+
+    assert edge_field.type.of_type.of_type.of_type.__name__ == "MyObjectEdge"
+
+    # This page info is NonNull
+    assert isinstance(pageinfo_field, Field)
+    assert isinstance(edge_field.type, NonNull)
+    assert pageinfo_field.type.of_type == PageInfo
+
+
 def test_connection_name():
     custom_name = "MyObjectCustomNameConnection"