diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index e79b9592..7f556575 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -29,7 +29,7 @@ class ConnectionField(Field): @with_context def resolver(self, instance, args, context, info): schema = info.schema.graphene_schema - connection_type = self.get_type(schema) + connection_type_for_node = self.get_type(schema) resolver = super(ConnectionField, self).resolver if has_context(resolver): @@ -37,17 +37,18 @@ class ConnectionField(Field): else: resolved = super(ConnectionField, self).resolver(instance, args, info) - if isinstance(resolved, connection_type): + if isinstance(resolved, self.connection_type): return resolved - return self.from_list(connection_type, resolved, args, context, info) + + return self.from_list(connection_type_for_node, resolved, args, context, info) def from_list(self, connection_type, resolved, args, context, info): return connection_type.from_list(resolved, args, context, info) def get_connection_type(self, node): connection_type = self.connection_type or node.get_connection_type() - edge_type = self.get_edge_type(node) - return connection_type.for_node(node, edge_type=edge_type) + self.connection_type = connection_type + return connection_type.for_node(node) def get_edge_type(self, node): edge_type = self.edge_type or node.get_edge_type() @@ -59,9 +60,9 @@ class ConnectionField(Field): node = schema.objecttype(type) assert is_node(node), 'Only nodes have connections.' schema.register(node) - connection_type = self.get_connection_type(node) + connection_type_for_node = self.get_connection_type(node) - return connection_type + return connection_type_for_node class NodeField(Field): diff --git a/graphene/relay/tests/test_query.py b/graphene/relay/tests/test_query.py index 08e57226..d982aa8d 100644 --- a/graphene/relay/tests/test_query.py +++ b/graphene/relay/tests/test_query.py @@ -40,6 +40,8 @@ class Query(graphene.ObjectType): context_nodes = relay.ConnectionField( MyNode, connection_type=MyConnection, customArg=graphene.String()) + sliced_nodes = relay.ConnectionField(MyNode) + def resolve_all_my_nodes(self, args, info): custom_arg = args.get('customArg') assert custom_arg == "1" @@ -51,6 +53,11 @@ class Query(graphene.ObjectType): assert custom_arg == "1" return [MyNode(name='my')] + def resolve_sliced_nodes(self, args, info): + sliced_list = [MyNode(name='my1'), MyNode(name='my2'), MyNode(name='my3')] + total_count = 10 + return relay.Connection.for_node(MyNode).from_list(sliced_list, args, None, info, total_count) + schema.query = Query @@ -135,6 +142,46 @@ def test_connectionfield_context_query(): assert result.data == expected +def test_slice_connectionfield_query(): + query = ''' + query RebelsShipsQuery { + slicedNodes (first: 3) { + edges { + node { + name + } + }, + pageInfo { + hasNextPage + } + } + } + ''' + expected = { + 'slicedNodes': { + 'edges': [{ + 'node': { + 'name': 'my1' + } + }, { + 'node': { + 'name': 'my2' + } + }, { + 'node': { + 'name': 'my3' + } + }], + 'pageInfo': { + 'hasNextPage': True, + } + } + } + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + @pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')]) def test_get_node_info(specialness, value): query = ''' diff --git a/graphene/relay/types.py b/graphene/relay/types.py index 3ab55770..ced0140c 100644 --- a/graphene/relay/types.py +++ b/graphene/relay/types.py @@ -5,7 +5,7 @@ from functools import wraps import six -from graphql_relay.connection.arrayconnection import connection_from_list +from graphql_relay.connection.arrayconnection import connection_from_list_slice from graphql_relay.node.node import to_global_id from ..core.classtypes import InputObjectType, Interface, Mutation, ObjectType @@ -87,12 +87,17 @@ class Connection(ObjectType): {'edge_type': edge_type, 'edges': edges}) @classmethod - def from_list(cls, iterable, args, context, info): + def from_list(cls, iterable, args, context, info, total_count=None): assert isinstance( iterable, Iterable), 'Resolved value from the connection field have to be iterable' - connection = connection_from_list( + + list_slice_length = len(iterable) + list_length = total_count if total_count else list_slice_length + + connection = connection_from_list_slice( iterable, args, connection_type=cls, - edge_type=cls.edge_type, pageinfo_type=PageInfo) + edge_type=cls.edge_type, pageinfo_type=PageInfo, + list_length=list_length, list_slice_length=list_slice_length) connection.set_connection_data(iterable) return connection