From d7c26e37bc6b79e9afa6ef69d0523de43a4b80a9 Mon Sep 17 00:00:00 2001 From: Markus Padourek Date: Wed, 14 Sep 2016 17:24:33 +0100 Subject: [PATCH] Fixed returning of connection and promise for ConnectionField. Fixed having complete custom edges. --- graphene/relay/connection.py | 96 ++++++++------- graphene/relay/tests/test_connection.py | 72 +++++++++-- graphene/relay/tests/test_connection_query.py | 114 ++++++++++++++++-- 3 files changed, 220 insertions(+), 62 deletions(-) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 629faf77..21b58ce9 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -2,6 +2,7 @@ import re from collections import Iterable, OrderedDict from functools import partial +from promise import Promise import six from fastcache import clru_cache @@ -56,37 +57,44 @@ class ConnectionMeta(ObjectTypeMeta): ) options.interfaces = () options.local_fields = OrderedDict() - - assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__) - assert issubclass(options.node, (Node, ObjectType)), ( - 'Received incompatible node "{}" for Connection {}.' - ).format(options.node, name) - base_name = re.sub('Connection$', '', name) + + if attrs.get('edges'): + edges = attrs.get('edges') + edge = edges.of_type + else: + assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__) + assert issubclass(options.node, (Node, ObjectType)), ( + 'Received incompatible node "{}" for Connection {}.' + ).format(options.node, name) + + base_name = re.sub('Connection$', '', name) + + edge_class = attrs.pop('Edge', None) + + edge_attrs = { + 'node': Field( + options.node, description='The item at the end of the edge'), + 'cursor': Edge._meta.fields['cursor'] + } + + edge_name = '{}Edge'.format(base_name) + if edge_class and issubclass(edge_class, AbstractType): + edge = type(edge_name, (edge_class, ObjectType, ), edge_attrs) + else: + additional_attrs = props(edge_class) if edge_class else {} + edge_attrs.update(additional_attrs) + edge = type(edge_name, (ObjectType, ), edge_attrs) + + edges = List(edge) + if not options.name: options.name = '{}Connection'.format(base_name) - edge_class = attrs.pop('Edge', None) - - edge_attrs = { - 'node': Field( - options.node, description='The item at the end of the edge'), - 'cursor': Edge._meta.fields['cursor'] - } - - edge_name = '{}Edge'.format(base_name) - if edge_class and issubclass(edge_class, AbstractType): - edge = type(edge_name, (edge_class, ObjectType, ), edge_attrs) - else: - additional_attrs = props(edge_class) if edge_class else {} - edge_attrs.update(additional_attrs) - edge = type(edge_name, (ObjectType, ), edge_attrs) - attrs.update({ 'page_info': Field(PageInfo, name='pageInfo', required=True), - 'edges': List(edge), + 'edges': edges, }) - attrs = dict(attrs, _meta=options, Edge=edge) return ObjectTypeMeta.__new__(cls, name, bases, attrs) @@ -111,7 +119,7 @@ class Edge(AbstractType): def is_connection(gql_type): '''Checks if a type is a connection. Taken directly from the spec definition: https://facebook.github.io/relay/graphql/connections.htm#sec-Connection-Types''' - return gql_type._meta.name.endswith('Connection') + return gql_type._meta.name.endswith('Connection') if hasattr(gql_type, '_meta') else False class IterableConnectionField(Field): @@ -134,22 +142,28 @@ class IterableConnectionField(Field): @staticmethod def connection_resolver(resolver, connection, root, args, context, info): - iterable = resolver(root, args, context, info) - assert isinstance(iterable, Iterable), ( - 'Resolved value from the connection field have to be iterable. ' - 'Received "{}"' - ).format(iterable) - # raise Exception('sdsdfsdfsdfsdsdf') - connection = connection_from_list( - iterable, - args, - connection_type=connection, - edge_type=connection.Edge, - pageinfo_type=PageInfo - ) - # print(connection) - connection.iterable = iterable - return connection + resolved = Promise.resolve(resolver(root, args, context, info)) + + def handle_connection_and_list(result): + if is_connection(result): + return result + else: + assert isinstance(result, Iterable), ( + 'Resolved value from the connection field have to be iterable. ' + 'Received "{}"' + ).format(result) + + resolved_connection = connection_from_list( + result, + args, + connection_type=connection, + edge_type=connection.Edge, + pageinfo_type=PageInfo + ) + resolved_connection.iterable = result + return resolved_connection + + return resolved.then(handle_connection_and_list) def get_resolver(self, parent_resolver): resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index 94849311..94f0456b 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -1,5 +1,5 @@ from ...types import Field, List, NonNull, ObjectType, String, AbstractType -from ..connection import Connection, PageInfo +from ..connection import Connection, PageInfo, Edge from ..node import Node @@ -10,7 +10,7 @@ class MyObject(ObjectType): field = String() -def xtest_connection(): +def test_connection(): class MyObjectConnection(Connection): extra = String() @@ -22,7 +22,7 @@ def xtest_connection(): assert MyObjectConnection._meta.name == 'MyObjectConnection' fields = MyObjectConnection._meta.fields - assert list(fields.keys()) == ['extra', 'page_info', 'edges'] + assert list(fields.keys()) == ['extra', 'edges', 'page_info'] edge_field = fields['edges'] pageinfo_field = fields['page_info'] @@ -35,7 +35,52 @@ def xtest_connection(): assert pageinfo_field.type.of_type == PageInfo -def xtest_connection_inherit_abstracttype(): +def test_multiple_connection_edges_are_not_the_same(): + class MyObjectConnection(Connection): + extra = String() + + class Meta: + node = MyObject + + class Edge: + other = String() + + class MyOtherObjectConnection(Connection): + class Meta: + node = MyObject + + class Edge: + other = String() + + assert MyObjectConnection.Edge != MyOtherObjectConnection.Edge + assert MyObjectConnection.Edge._meta.name != MyOtherObjectConnection.Edge._meta.name + + +def test_create_connection_with_custom_edge_type(): + class MyEdge(Edge): + node = Field(MyObject) + + class MyObjectConnection(Connection): + extra = String() + edges = List(MyEdge) + + assert MyObjectConnection.Edge == MyEdge + assert MyObjectConnection._meta.name == 'MyObjectConnection' + fields = MyObjectConnection._meta.fields + assert list(fields.keys()) == ['extra', 'edges', 'page_info'] + edge_field = fields['edges'] + pageinfo_field = fields['page_info'] + + assert isinstance(edge_field, Field) + assert isinstance(edge_field.type, List) + assert edge_field.type.of_type == MyObjectConnection.Edge + + assert isinstance(pageinfo_field, Field) + assert isinstance(pageinfo_field.type, NonNull) + assert pageinfo_field.type.of_type == PageInfo + + +def test_connection_inherit_abstracttype(): class BaseConnection(AbstractType): extra = String() @@ -45,20 +90,21 @@ def xtest_connection_inherit_abstracttype(): assert MyObjectConnection._meta.name == 'MyObjectConnection' fields = MyObjectConnection._meta.fields - assert list(fields.keys()) == ['extra', 'page_info', 'edges'] + assert list(fields.keys()) == ['extra', 'edges', 'page_info'] -def xtest_defaul_connection_for_type(): +def test_defaul_connection_for_type(): MyObjectConnection = Connection.for_type(MyObject) assert MyObjectConnection._meta.name == 'MyObjectConnection' fields = MyObjectConnection._meta.fields - assert list(fields.keys()) == ['page_info', 'edges'] + assert list(fields.keys()) == ['edges', 'page_info'] -def xtest_defaul_connection_for_type_returns_same_Connection(): +def test_defaul_connection_for_type_returns_same_Connection(): assert Connection.for_type(MyObject) == Connection.for_type(MyObject) -def xtest_edge(): + +def test_edge(): class MyObjectConnection(Connection): class Meta: node = MyObject @@ -101,13 +147,13 @@ def test_edge_with_bases(): assert edge_fields['other'].type == String -def xtest_pageinfo(): +def test_pageinfo(): assert PageInfo._meta.name == 'PageInfo' fields = PageInfo._meta.fields assert list(fields.keys()) == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor'] -def xtest_edge_for_node_type(): +def test_edge_for_node_type(): edge = Connection.for_type(MyObject).Edge assert edge._meta.name == 'MyObjectEdge' @@ -118,7 +164,7 @@ def xtest_edge_for_node_type(): assert edge_fields['node'].type == MyObject -def xtest_edge_for_object_type(): +def test_edge_for_object_type(): class MyObject(ObjectType): field = String() @@ -132,5 +178,5 @@ def xtest_edge_for_object_type(): assert edge_fields['node'].type == MyObject -def xtest_edge_for_type_returns_same_edge(): +def test_edge_for_type_returns_same_edge(): assert Connection.for_type(MyObject).Edge == Connection.for_type(MyObject).Edge diff --git a/graphene/relay/tests/test_connection_query.py b/graphene/relay/tests/test_connection_query.py index b2c10c31..cd42f2d8 100644 --- a/graphene/relay/tests/test_connection_query.py +++ b/graphene/relay/tests/test_connection_query.py @@ -1,6 +1,8 @@ from collections import OrderedDict -from ..connection import ConnectionField +from promise import Promise + +from ..connection import ConnectionField, Connection, PageInfo from ..node import Node from graphql_relay.utils import base64 from ...types import ObjectType, String, Schema @@ -15,12 +17,34 @@ class Letter(ObjectType): letter = String() +class MyLetterObjectConnection(Connection): + extra = String() + + class Meta: + node = Letter + + class Edge: + other = String() + + class Query(ObjectType): letters = ConnectionField(Letter) + letters_promise = ConnectionField(Letter) + letters_connection = ConnectionField(MyLetterObjectConnection) - def resolve_letters(self, args, context, info): + def resolve_letters(self, *_): return list(letters.values()) + def resolve_letters_connection(self, *_): + return MyLetterObjectConnection( + extra='1', + page_info=PageInfo(has_next_page=True, has_previous_page=False), + edges=[MyLetterObjectConnection.Edge(cursor='1', node=Letter(letter='hello'))] + ) + + def resolve_letters_promise(self, *_): + return Promise.resolve(list(letters.values())) + node = Node.Field() @@ -75,8 +99,7 @@ def execute(args=''): ''' % args) -def check(args, letters, has_previous_page=False, has_next_page=False): - result = execute(args) +def create_expexted_result(letters, has_previous_page=False, has_next_page=False, field_name='letters'): expected_edges = edges(letters) expected_page_info = { 'hasPreviousPage': has_previous_page, @@ -84,16 +107,91 @@ def check(args, letters, has_previous_page=False, has_next_page=False): 'endCursor': expected_edges[-1]['cursor'] if expected_edges else None, 'startCursor': expected_edges[0]['cursor'] if expected_edges else None } - - assert not result.errors - assert result.data == { - 'letters': { + return { + field_name: { 'edges': expected_edges, 'pageInfo': expected_page_info } } +def check(args, letters, has_previous_page=False, has_next_page=False): + result = execute(args) + assert not result.errors + assert result.data == create_expexted_result(letters, has_previous_page, has_next_page) + + +def test_resolver_handles_returned_connection_field_correctly(): + result = schema.execute(''' + { + lettersConnection { + extra + edges { + node { + id + letter + } + cursor + } + pageInfo { + hasPreviousPage + hasNextPage + startCursor + endCursor + } + } + } + ''') + + assert not result.errors + expected_result = { + 'lettersConnection': { + 'extra': '1', + 'edges': [ + { + 'node': { + 'id': 'TGV0dGVyOk5vbmU=', + 'letter': 'hello', + }, + 'cursor': '1' + } + ], + 'pageInfo': { + 'hasPreviousPage': False, + 'hasNextPage': True, + 'startCursor': None, + 'endCursor': None, + } + } + } + assert result.data == expected_result + + +def test_resolver_handles_returned_promise_correctly(): + result = schema.execute(''' + { + lettersPromise { + edges { + node { + id + letter + } + cursor + } + pageInfo { + hasPreviousPage + hasNextPage + startCursor + endCursor + } + } + } + ''') + + assert not result.errors + assert result.data == create_expexted_result('ABCDE', field_name='lettersPromise') + + def test_returns_all_elements_without_filters(): check('', 'ABCDE')