diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index d93cdcf8..2942a3a9 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -1,4 +1,5 @@ import six +import functools from graphql_relay.node.node import from_global_id @@ -9,6 +10,10 @@ from ..utils.wrap_resolver_function import has_context, with_context from .connection import Connection, Edge +def _is_thenable(obj): + return callable(getattr(obj, "then", None)) + + class ConnectionField(Field): def __init__(self, type, resolver=None, description='', @@ -27,6 +32,11 @@ class ConnectionField(Field): self.connection_type = connection_type or Connection self.edge_type = edge_type or Edge + def _get_connection_type(self, connection_type, args, context, info, resolved): + if isinstance(resolved, self.connection_type): + return resolved + return self.from_list(connection_type, resolved, args, context, info) + @with_context def resolver(self, instance, args, context, info): schema = info.schema.graphene_schema @@ -38,9 +48,12 @@ class ConnectionField(Field): else: resolved = super(ConnectionField, self).resolver(instance, args, info) - if isinstance(resolved, self.connection_type): - return resolved - return self.from_list(connection_type, resolved, args, context, info) + get_connection_type = functools.partial(self._get_connection_type, connection_type, args, context, info) + + if _is_thenable(resolved): + return resolved.then(get_connection_type) + + return get_connection_type(resolved) def from_list(self, connection_type, resolved, args, context, info): return connection_type.from_list(resolved, args, context, info) diff --git a/graphene/relay/tests/test_query.py b/graphene/relay/tests/test_query.py index 1603709e..af6e188a 100644 --- a/graphene/relay/tests/test_query.py +++ b/graphene/relay/tests/test_query.py @@ -4,6 +4,8 @@ from graphql.type import GraphQLID, GraphQLNonNull import graphene from graphene import relay, with_context +from promise import Promise + schema = graphene.Schema() @@ -52,6 +54,9 @@ class Query(graphene.ObjectType): connection_type_nodes = relay.ConnectionField( MyNode, connection_type=MyConnection) + promise_connection_type = relay.ConnectionField( + MyNode, connection_type=MyConnection) + all_my_objects = relay.ConnectionField( MyObject, connection_type=MyConnection) @@ -76,6 +81,9 @@ class Query(graphene.ObjectType): def resolve_all_my_objects(self, args, info): return [MyObject(name='my_object')] + def resolve_promise_connection_type(self, args, info): + return Promise.resolve('async name').then(lambda name: [MyNode(id='1', name=name)]) + schema.query = Query @@ -228,6 +236,32 @@ def test_connectionfield_resolve_returning_objects(): assert result.data == expected +def test_connectionfield_resolve_returning_promise(): + query = ''' + query RebelsShipsQuery { + promiseConnectionType { + edges { + node { + name + } + } + } + } + ''' + expected = { + 'promiseConnectionType': { + 'edges': [{ + 'node': { + 'name': 'async name' + } + }] + } + } + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + @pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')]) def test_get_node_info(specialness, value): query = ''' diff --git a/setup.py b/setup.py index 27246df5..8f1016a4 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,7 @@ setup( 'sqlalchemy', 'sqlalchemy_utils', 'mock', + 'promse', # Required for Django postgres fields testing 'psycopg2', ], diff --git a/tox.ini b/tox.ini index 05b37dfd..255e3959 100644 --- a/tox.ini +++ b/tox.ini @@ -14,6 +14,7 @@ deps= blinker singledispatch mock + promise setenv = PYTHONPATH = .:{envdir} commands=