diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index e63478e5..20f3cceb 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -5,6 +5,7 @@ from functools import partial import six from graphql_relay import connection_from_list +from promise import Promise from ..types import (AbstractType, Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, Union) @@ -118,26 +119,34 @@ class IterableConnectionField(Field): return connection_type @classmethod - def connection_resolver(cls, resolver, connection, root, args, context, info): - resolved = resolver(root, args, context, info) - - if isinstance(resolved, connection): + def resolve_connection(cls, connection_type, args, resolved): + if isinstance(resolved, connection_type): return resolved assert isinstance(resolved, Iterable), ( 'Resolved value from the connection field have to be iterable or instance of {}. ' 'Received "{}"' - ).format(connection, resolved) + ).format(connection_type, resolved) connection = connection_from_list( resolved, args, - connection_type=connection, - edge_type=connection.Edge, + connection_type=connection_type, + edge_type=connection_type.Edge, pageinfo_type=PageInfo ) connection.iterable = resolved return connection + @classmethod + def connection_resolver(cls, resolver, connection_type, root, args, context, info): + resolved = resolver(root, args, context, info) + + on_resolve = partial(cls.resolve_connection, connection_type, args) + if isinstance(resolved, Promise): + return resolved.then(on_resolve) + + return on_resolve(resolved) + def get_resolver(self, parent_resolver): resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) return partial(self.connection_resolver, resolver, self.type)