chore: make pageinfo and edge types overridable

This commit is contained in:
Erik Wrede 2023-03-02 22:58:34 +01:00
parent 828e951c0e
commit a512071862
2 changed files with 104 additions and 35 deletions

View File

@ -1,6 +1,7 @@
import re import re
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial from functools import partial
from typing import Type
from graphql_relay import connection_from_array from graphql_relay import connection_from_array
@ -8,7 +9,28 @@ from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String
from ..types.field import Field from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeOptions from ..types.objecttype import ObjectType, ObjectTypeOptions
from ..utils.thenables import maybe_thenable from ..utils.thenables import maybe_thenable
from .node import is_node from .node import is_node, AbstractNode
def get_edge_class(
connection_class: Type["Connection"], _node: Type[AbstractNode], base_name: str
):
edge_class = getattr(connection_class, "Edge", None)
class EdgeBase:
node = Field(_node, description="The item at the end of the edge")
cursor = String(required=True, description="A cursor for use in pagination")
class EdgeMeta:
description = f"A Relay edge containing a `{base_name}` and its cursor."
edge_name = f"{base_name}Edge"
if edge_class:
edge_bases = (edge_class, EdgeBase, ObjectType)
else:
edge_bases = (EdgeBase, ObjectType)
return type(edge_name, edge_bases, {"Meta": EdgeMeta})
class PageInfo(ObjectType): class PageInfo(ObjectType):
@ -73,25 +95,6 @@ class Connection(ObjectType):
if not name: if not name:
name = f"{base_name}Connection" name = f"{base_name}Connection"
edge_class = getattr(cls, "Edge", None)
_node = node
class EdgeBase:
node = Field(_node, description="The item at the end of the edge")
cursor = String(required=True, description="A cursor for use in pagination")
class EdgeMeta:
description = f"A Relay edge containing a `{base_name}` and its cursor."
edge_name = f"{base_name}Edge"
if edge_class:
edge_bases = (edge_class, EdgeBase, ObjectType)
else:
edge_bases = (EdgeBase, ObjectType)
edge = type(edge_name, edge_bases, {"Meta": EdgeMeta})
cls.Edge = edge
options["name"] = name options["name"] = name
_meta.node = node _meta.node = node
@ -99,20 +102,22 @@ class Connection(ObjectType):
if not _meta.fields: if not _meta.fields:
_meta.fields = {} _meta.fields = {}
_meta.fields.update( if "page_info" not in _meta.fields:
{ _meta.fields["page_info"] = Field(
"page_info": Field(
PageInfo, PageInfo,
name="pageInfo", name="pageInfo",
required=True, required=True,
description="Pagination data for this connection.", description="Pagination data for this connection.",
),
"edges": Field(
NonNull(List(edge)),
description="Contains the nodes in this connection.",
),
}
) )
if "edges" not in _meta.fields:
edge_class = get_edge_class(cls, node, base_name) # type: ignore
cls.Edge = edge_class
_meta.fields["edges"] = Field(
NonNull(List(edge_class)),
description="Contains the nodes in this connection.",
)
return super(Connection, cls).__init_subclass_with_meta__( return super(Connection, cls).__init_subclass_with_meta__(
_meta=_meta, **options _meta=_meta, **options
) )

View File

@ -1,7 +1,15 @@
import re
from pytest import raises from pytest import raises
from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String
from ..connection import Connection, ConnectionField, PageInfo, ConnectionOptions from ..connection import (
Connection,
ConnectionField,
PageInfo,
ConnectionOptions,
get_edge_class,
)
from ..node import Node from ..node import Node
@ -100,6 +108,62 @@ def test_connection_extra_abstract_fields():
assert nodes_field.type.of_type.of_type == MyObject assert nodes_field.type.of_type.of_type == MyObject
def test_connection_override_fields():
class ConnectionWithNodes(Connection):
class Meta:
abstract = True
@classmethod
def __init_subclass_with_meta__(cls, node=None, name=None, **options):
_meta = ConnectionOptions(cls)
base_name = (
re.sub("Connection$", "", name or cls.__name__) or node._meta.name
)
edge_class = get_edge_class(cls, node, base_name)
_meta.fields = {
"page_info": Field(
NonNull(
PageInfo,
name="pageInfo",
required=True,
description="Pagination data for this connection.",
)
),
"edges": Field(
NonNull(List(NonNull(edge_class))),
description="Contains the nodes in this connection.",
),
}
return super(ConnectionWithNodes, cls).__init_subclass_with_meta__(
node=node, name=name, _meta=_meta, **options
)
class MyObjectConnection(ConnectionWithNodes):
class Meta:
node = MyObject
assert MyObjectConnection._meta.name == "MyObjectConnection"
fields = MyObjectConnection._meta.fields
assert list(fields) == ["page_info", "edges"]
edge_field = fields["edges"]
pageinfo_field = fields["page_info"]
assert isinstance(edge_field, Field)
assert isinstance(edge_field.type, NonNull)
assert isinstance(edge_field.type.of_type, List)
assert isinstance(edge_field.type.of_type.of_type, NonNull)
assert edge_field.type.of_type.of_type.of_type.__name__ == "MyObjectEdge"
# This page info is NonNull
assert isinstance(pageinfo_field, Field)
assert isinstance(edge_field.type, NonNull)
assert pageinfo_field.type.of_type == PageInfo
def test_connection_name(): def test_connection_name():
custom_name = "MyObjectCustomNameConnection" custom_name = "MyObjectCustomNameConnection"