From 5cb5d9d65abff7a5a5df6347203de87ca2ec7d77 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 15 Aug 2016 23:24:03 -0700 Subject: [PATCH] Improved Relay Connection --- graphene/relay/connection.py | 38 ++++++------- graphene/relay/tests/test_connection.py | 71 ++++++++++++++++++------- 2 files changed, 70 insertions(+), 39 deletions(-) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 88148b00..21c3ea71 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -6,7 +6,7 @@ import six from graphql_relay import connection_from_list -from ..types import Boolean, Int, List, String +from ..types import Boolean, Int, List, String, AbstractType from ..types.field import Field from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.options import Options @@ -55,6 +55,7 @@ class ConnectionMeta(ObjectTypeMeta): node=None, ) options.interfaces = () + options.local_fields = OrderedDict() assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__) assert issubclass(options.node, (Node, ObjectType)), ( @@ -66,30 +67,25 @@ class ConnectionMeta(ObjectTypeMeta): options.name = '{}Connection'.format(base_name) edge_class = attrs.pop('Edge', None) - edge_fields = OrderedDict([ - ('node', Field(options.node, description='The item at the end of the edge')), - ('cursor', Field(String, required=True, description='A cursor for use in pagination')) - ]) - edge_attrs = props(edge_class) if edge_class else OrderedDict() - extended_edge_fields = get_fields_in_type(ObjectType, edge_attrs) - edge_fields.update(extended_edge_fields) - edge_meta = type('Meta', (object, ), { - 'fields': edge_fields, - 'name': '{}Edge'.format(base_name) - }) - yank_fields_from_attrs(edge_attrs, extended_edge_fields) - edge = type('Edge', (ObjectType,), dict(edge_attrs, Meta=edge_meta)) - options.local_fields = OrderedDict([ + class EdgeBase(AbstractType): + node = Field(options.node, description='The item at the end of the edge') + cursor = String(required=True, description='A cursor for use in pagination') + + edge_name = '{}Edge'.format(base_name) + if edge_class and issubclass(edge_class, AbstractType): + edge = type(edge_name, (EdgeBase, edge_class, ObjectType, ), {}) + else: + edge = type(edge_name, (EdgeBase, ObjectType, ), props(edge_class) if edge_class else {}) + + cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options, Edge=edge)) + base_fields = OrderedDict([ ('page_info', Field(PageInfo, name='pageInfo', required=True)), ('edges', Field(List(edge))) ]) - typed_fields = get_fields_in_type(ObjectType, attrs) - options.local_fields.update(typed_fields) - options.fields = options.local_fields - yank_fields_from_attrs(attrs, typed_fields) - - return type.__new__(cls, name, bases, dict(attrs, _meta=options, Edge=edge)) + base_fields.update(cls._meta.fields) + cls._meta.fields = base_fields + return cls class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index ece16a48..a1355c4a 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -1,5 +1,5 @@ -from ...types import Field, List, NonNull, ObjectType, Schema, String +from ...types import Field, List, NonNull, ObjectType, String, AbstractType from ..connection import Connection, PageInfo from ..node import Node @@ -15,24 +15,16 @@ class MyObject(ObjectType): pass -class MyObjectConnection(Connection): - extra = String() - - class Meta: - node = MyObject - - class Edge: - other = String() - - -class RootQuery(ObjectType): - my_connection = Field(MyObjectConnection) - - -schema = Schema(query=RootQuery) - - def test_connection(): + class MyObjectConnection(Connection): + extra = String() + + class Meta: + node = MyObject + + class Edge: + other = String() + assert MyObjectConnection._meta.name == 'MyObjectConnection' fields = MyObjectConnection._meta.fields assert list(fields.keys()) == ['page_info', 'edges', 'extra'] @@ -48,7 +40,27 @@ def test_connection(): assert pageinfo_field.type.of_type == PageInfo +def test_connection_inherit_abstracttype(): + class BaseConnection(AbstractType): + extra = String() + + class MyObjectConnection(BaseConnection, Connection): + class Meta: + node = MyObject + + assert MyObjectConnection._meta.name == 'MyObjectConnection' + fields = MyObjectConnection._meta.fields + assert list(fields.keys()) == ['page_info', 'edges', 'extra'] + + def test_edge(): + class MyObjectConnection(Connection): + class Meta: + node = MyObject + + class Edge: + other = String() + Edge = MyObjectConnection.Edge assert Edge._meta.name == 'MyObjectEdge' edge_fields = Edge._meta.fields @@ -61,6 +73,29 @@ def test_edge(): assert edge_fields['other'].type == String +def test_edge_with_bases(): + class BaseEdge(AbstractType): + extra = String() + + class MyObjectConnection(Connection): + class Meta: + node = MyObject + + class Edge(BaseEdge): + other = String() + + Edge = MyObjectConnection.Edge + assert Edge._meta.name == 'MyObjectEdge' + edge_fields = Edge._meta.fields + assert list(edge_fields.keys()) == ['node', 'cursor', 'extra', 'other'] + + assert isinstance(edge_fields['node'], Field) + assert edge_fields['node'].type == MyObject + + assert isinstance(edge_fields['other'], Field) + assert edge_fields['other'].type == String + + def test_pageinfo(): assert PageInfo._meta.name == 'PageInfo' fields = PageInfo._meta.fields