diff --git a/graphene/relay/__init__.py b/graphene/relay/__init__.py index afc6bbb0..a886c8ec 100644 --- a/graphene/relay/__init__.py +++ b/graphene/relay/__init__.py @@ -1,2 +1,3 @@ from .node import Node from .mutation import ClientIDMutation +from .connection import Connection diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py new file mode 100644 index 00000000..950a103b --- /dev/null +++ b/graphene/relay/connection.py @@ -0,0 +1,60 @@ +import re +import copy +from functools import partial +import six +from graphql_relay import connection_definitions + +from ..types.field import Field +from ..types.mutation import Mutation, MutationMeta +from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMeta +from ..types.objecttype import ObjectType, ObjectTypeMeta + +from ..utils.props import props + + +class ConnectionMeta(ObjectTypeMeta): + + def get_options(cls, meta): + options = cls.options_class( + meta, + name=None, + description=None, + node=None, + interfaces=[], + abstract=False + ) + options.graphql_type = None + return options + + def construct(cls, bases, attrs): + if not cls._meta.abstract: + Edge = attrs.pop('Edge', None) + edge_fields = props(Edge) if Edge else {} + + edge_fields = {f.name: f for f in ObjectType._extract_local_fields(edge_fields)} + local_fields = cls._extract_local_fields(attrs) + + cls = super(ConnectionMeta, cls).construct(bases, attrs) + if not cls._meta.abstract: + from ..utils.get_graphql_type import get_graphql_type + assert cls._meta.node, 'You have to provide a node in {}.Meta'.format(cls.__name__) + edge, connection = connection_definitions( + name=cls._meta.name or re.sub('Connection$', '', cls.__name__), + node_type=get_graphql_type(cls._meta.node), + resolve_node=cls.resolve_node, + resolve_cursor=cls.resolve_cursor, + + edge_fields=edge_fields, + connection_fields=local_fields, + ) + cls._meta.graphql_type = connection + return cls + + + +class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): + class Meta: + abstract = True + + resolve_node = None + resolve_cursor = None diff --git a/graphene/relay/node.py b/graphene/relay/node.py index db959d7a..820c5190 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -30,6 +30,7 @@ class NodeMeta(ObjectTypeMeta): cls.implements(cls) return cls + class Node(six.with_metaclass(NodeMeta, Interface)): @classmethod diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py new file mode 100644 index 00000000..8a25b306 --- /dev/null +++ b/graphene/relay/tests/test_connection.py @@ -0,0 +1,41 @@ +import pytest + +from ..connection import Connection +from ..node import Node +from ...types import ObjectType, Schema +from ...types.scalars import String +from ...types.field import Field + + +class MyObject(Node, ObjectType): + field = String() + + @classmethod + def get_node(cls, id): + pass + + +class MyObjectConnection(Connection): + class Meta: + node = MyObject + + class Edge: + other = String() + + +class RootQuery(ObjectType): + my_connection = Field(MyObjectConnection) + + +schema = Schema(query=RootQuery) + + +def test_node_good(): + graphql_type = MyObjectConnection._meta.graphql_type + fields = graphql_type.get_fields() + assert 'edges' in fields + assert 'pageInfo' in fields + edge_fields = fields['edges'].type.of_type.get_fields() + assert 'node' in edge_fields + assert edge_fields['node'].type == MyObject._meta.graphql_type + assert 'other' in edge_fields