This commit is contained in:
Markus Padourek 2016-08-22 06:35:45 +00:00 committed by GitHub
commit 40256e7aeb
4 changed files with 52 additions and 3 deletions

View File

@ -1,4 +1,5 @@
import six import six
import functools
from graphql_relay.node.node import from_global_id from graphql_relay.node.node import from_global_id
@ -9,6 +10,10 @@ from ..utils.wrap_resolver_function import has_context, with_context
from .connection import Connection, Edge from .connection import Connection, Edge
def _is_thenable(obj):
return callable(getattr(obj, "then", None))
class ConnectionField(Field): class ConnectionField(Field):
def __init__(self, type, resolver=None, description='', def __init__(self, type, resolver=None, description='',
@ -27,6 +32,11 @@ class ConnectionField(Field):
self.connection_type = connection_type or Connection self.connection_type = connection_type or Connection
self.edge_type = edge_type or Edge self.edge_type = edge_type or Edge
def _get_connection_type(self, connection_type, args, context, info, resolved):
if isinstance(resolved, self.connection_type):
return resolved
return self.from_list(connection_type, resolved, args, context, info)
@with_context @with_context
def resolver(self, instance, args, context, info): def resolver(self, instance, args, context, info):
schema = info.schema.graphene_schema schema = info.schema.graphene_schema
@ -38,9 +48,12 @@ class ConnectionField(Field):
else: else:
resolved = super(ConnectionField, self).resolver(instance, args, info) resolved = super(ConnectionField, self).resolver(instance, args, info)
if isinstance(resolved, self.connection_type): get_connection_type = functools.partial(self._get_connection_type, connection_type, args, context, info)
return resolved
return self.from_list(connection_type, resolved, args, context, info) if _is_thenable(resolved):
return resolved.then(get_connection_type)
return get_connection_type(resolved)
def from_list(self, connection_type, resolved, args, context, info): def from_list(self, connection_type, resolved, args, context, info):
return connection_type.from_list(resolved, args, context, info) return connection_type.from_list(resolved, args, context, info)

View File

@ -4,6 +4,8 @@ from graphql.type import GraphQLID, GraphQLNonNull
import graphene import graphene
from graphene import relay, with_context from graphene import relay, with_context
from promise import Promise
schema = graphene.Schema() schema = graphene.Schema()
@ -52,6 +54,9 @@ class Query(graphene.ObjectType):
connection_type_nodes = relay.ConnectionField( connection_type_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection) MyNode, connection_type=MyConnection)
promise_connection_type = relay.ConnectionField(
MyNode, connection_type=MyConnection)
all_my_objects = relay.ConnectionField( all_my_objects = relay.ConnectionField(
MyObject, connection_type=MyConnection) MyObject, connection_type=MyConnection)
@ -76,6 +81,9 @@ class Query(graphene.ObjectType):
def resolve_all_my_objects(self, args, info): def resolve_all_my_objects(self, args, info):
return [MyObject(name='my_object')] return [MyObject(name='my_object')]
def resolve_promise_connection_type(self, args, info):
return Promise.resolve('async name').then(lambda name: [MyNode(id='1', name=name)])
schema.query = Query schema.query = Query
@ -228,6 +236,32 @@ def test_connectionfield_resolve_returning_objects():
assert result.data == expected assert result.data == expected
def test_connectionfield_resolve_returning_promise():
query = '''
query RebelsShipsQuery {
promiseConnectionType {
edges {
node {
name
}
}
}
}
'''
expected = {
'promiseConnectionType': {
'edges': [{
'node': {
'name': 'async name'
}
}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
@pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')]) @pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')])
def test_get_node_info(specialness, value): def test_get_node_info(specialness, value):
query = ''' query = '''

View File

@ -66,6 +66,7 @@ setup(
'sqlalchemy', 'sqlalchemy',
'sqlalchemy_utils', 'sqlalchemy_utils',
'mock', 'mock',
'promse',
# Required for Django postgres fields testing # Required for Django postgres fields testing
'psycopg2', 'psycopg2',
], ],

View File

@ -14,6 +14,7 @@ deps=
blinker blinker
singledispatch singledispatch
mock mock
promise
setenv = setenv =
PYTHONPATH = .:{envdir} PYTHONPATH = .:{envdir}
commands= commands=