This commit is contained in:
Markus Padourek 2016-07-28 11:00:38 +00:00 committed by GitHub
commit 4df88073fa
3 changed files with 64 additions and 11 deletions

View File

@ -29,7 +29,7 @@ class ConnectionField(Field):
@with_context
def resolver(self, instance, args, context, info):
schema = info.schema.graphene_schema
connection_type = self.get_type(schema)
connection_type_for_node = self.get_type(schema)
resolver = super(ConnectionField, self).resolver
if has_context(resolver):
@ -37,17 +37,18 @@ class ConnectionField(Field):
else:
resolved = super(ConnectionField, self).resolver(instance, args, info)
if isinstance(resolved, connection_type):
if isinstance(resolved, self.connection_type):
return resolved
return self.from_list(connection_type, resolved, args, context, info)
return self.from_list(connection_type_for_node, resolved, args, context, info)
def from_list(self, connection_type, resolved, args, context, info):
return connection_type.from_list(resolved, args, context, info)
def get_connection_type(self, node):
connection_type = self.connection_type or node.get_connection_type()
edge_type = self.get_edge_type(node)
return connection_type.for_node(node, edge_type=edge_type)
self.connection_type = connection_type
return connection_type.for_node(node)
def get_edge_type(self, node):
edge_type = self.edge_type or node.get_edge_type()
@ -59,9 +60,9 @@ class ConnectionField(Field):
node = schema.objecttype(type)
assert is_node(node), 'Only nodes have connections.'
schema.register(node)
connection_type = self.get_connection_type(node)
connection_type_for_node = self.get_connection_type(node)
return connection_type
return connection_type_for_node
class NodeField(Field):

View File

@ -40,6 +40,8 @@ class Query(graphene.ObjectType):
context_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.String())
sliced_nodes = relay.ConnectionField(MyNode)
def resolve_all_my_nodes(self, args, info):
custom_arg = args.get('customArg')
assert custom_arg == "1"
@ -51,6 +53,11 @@ class Query(graphene.ObjectType):
assert custom_arg == "1"
return [MyNode(name='my')]
def resolve_sliced_nodes(self, args, info):
sliced_list = [MyNode(name='my1'), MyNode(name='my2'), MyNode(name='my3')]
total_count = 10
return relay.Connection.for_node(MyNode).from_list(sliced_list, args, None, info, total_count)
schema.query = Query
@ -135,6 +142,46 @@ def test_connectionfield_context_query():
assert result.data == expected
def test_slice_connectionfield_query():
query = '''
query RebelsShipsQuery {
slicedNodes (first: 3) {
edges {
node {
name
}
},
pageInfo {
hasNextPage
}
}
}
'''
expected = {
'slicedNodes': {
'edges': [{
'node': {
'name': 'my1'
}
}, {
'node': {
'name': 'my2'
}
}, {
'node': {
'name': 'my3'
}
}],
'pageInfo': {
'hasNextPage': True,
}
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
@pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')])
def test_get_node_info(specialness, value):
query = '''

View File

@ -5,7 +5,7 @@ from functools import wraps
import six
from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from graphql_relay.node.node import to_global_id
from ..core.classtypes import InputObjectType, Interface, Mutation, ObjectType
@ -87,12 +87,17 @@ class Connection(ObjectType):
{'edge_type': edge_type, 'edges': edges})
@classmethod
def from_list(cls, iterable, args, context, info):
def from_list(cls, iterable, args, context, info, total_count=None):
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list(
list_slice_length = len(iterable)
list_length = total_count if total_count else list_slice_length
connection = connection_from_list_slice(
iterable, args, connection_type=cls,
edge_type=cls.edge_type, pageinfo_type=PageInfo)
edge_type=cls.edge_type, pageinfo_type=PageInfo,
list_length=list_length, list_slice_length=list_slice_length)
connection.set_connection_data(iterable)
return connection