From c7929234294a181993158b4439c24006e9ed1ebd Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sat, 1 Oct 2016 10:42:06 -0700 Subject: [PATCH] Added ability to return a Connection instance in the connection resolver --- graphene/relay/connection.py | 20 ++++--- graphene/relay/tests/test_connection_query.py | 56 ++++++++++++++++++- 2 files changed, 66 insertions(+), 10 deletions(-) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index a1039e43..67ccea50 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -117,21 +117,25 @@ class IterableConnectionField(Field): ).format(str(self), connection_type) return connection_type - @staticmethod - def connection_resolver(resolver, connection, root, args, context, info): - iterable = resolver(root, args, context, info) - assert isinstance(iterable, Iterable), ( - 'Resolved value from the connection field have to be iterable. ' + @classmethod + def connection_resolver(cls, resolver, connection, root, args, context, info): + resolved = resolver(root, args, context, info) + + if isinstance(resolved, connection): + return resolved + + assert isinstance(resolved, Iterable), ( + 'Resolved value from the connection field have to be iterable or instance of {}. ' 'Received "{}"' - ).format(iterable) + ).format(connection, resolved) connection = connection_from_list( - iterable, + resolved, args, connection_type=connection, edge_type=connection.Edge, pageinfo_type=PageInfo ) - connection.iterable = iterable + connection.iterable = resolved return connection def get_resolver(self, parent_resolver): diff --git a/graphene/relay/tests/test_connection_query.py b/graphene/relay/tests/test_connection_query.py index 7a197f27..cc1f12ce 100644 --- a/graphene/relay/tests/test_connection_query.py +++ b/graphene/relay/tests/test_connection_query.py @@ -3,7 +3,7 @@ from collections import OrderedDict from graphql_relay.utils import base64 from ...types import ObjectType, Schema, String -from ..connection import ConnectionField +from ..connection import ConnectionField, PageInfo from ..node import Node letter_chars = ['A', 'B', 'C', 'D', 'E'] @@ -19,11 +19,26 @@ class Letter(ObjectType): class Query(ObjectType): letters = ConnectionField(Letter) + connection_letters = ConnectionField(Letter) + + node = Node.Field() def resolve_letters(self, args, context, info): return list(letters.values()) - node = Node.Field() + def resolve_connection_letters(self, args, context, info): + return Letter.Connection( + page_info=PageInfo( + has_next_page=True, + has_previous_page=False + ), + edges=[ + Letter.Connection.Edge( + node=Letter(id=0, letter='A'), + cursor='a-cursor' + ), + ] + ) schema = Schema(Query) @@ -176,3 +191,40 @@ def test_returns_all_elements_if_cursors_are_on_the_outside(): def test_returns_no_elements_if_cursors_cross(): check('before: "{}" after: "{}"'.format(base64('arrayconnection:%s' % 2), base64('arrayconnection:%s' % 4)), '') + + +def test_connection_type_nodes(): + result = schema.execute(''' + { + connectionLetters { + edges { + node { + id + letter + } + cursor + } + pageInfo { + hasPreviousPage + hasNextPage + } + } + } + ''') + + assert not result.errors + assert result.data == { + 'connectionLetters': { + 'edges': [{ + 'node': { + 'id': 'TGV0dGVyOjA=', + 'letter': 'A', + }, + 'cursor': 'a-cursor', + }], + 'pageInfo': { + 'hasPreviousPage': False, + 'hasNextPage': True, + } + } + }