diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 2125d16..f82e4b2 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -62,31 +62,7 @@ class DjangoConnectionField(ConnectionField): return default_queryset & queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, args, context, info): - first = args.get('first') - last = args.get('last') - - if enforce_first_or_last: - assert first or last, ( - 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' - ).format(info.field_name) - - if max_limit: - if first: - assert first <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' - ).format(first, info.field_name, max_limit) - args['first'] = min(first, max_limit) - - if last: - assert last <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' - ).format(first, info.field_name, max_limit) - args['last'] = min(last, max_limit) - - iterable = resolver(root, args, context, info) - iterable = Promise.resolve(iterable).get() + def resolve_connection(cls, connection, default_manager, args, iterable): if iterable is None: iterable = default_manager iterable = maybe_queryset(iterable) @@ -111,6 +87,38 @@ class DjangoConnectionField(ConnectionField): connection.length = _len return connection + @classmethod + def connection_resolver(cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, args, context, info): + first = args.get('first') + last = args.get('last') + + if enforce_first_or_last: + assert first or last, ( + 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' + ).format(info.field_name) + + if max_limit: + if first: + assert first <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['first'] = min(first, max_limit) + + if last: + assert last <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['last'] = min(last, max_limit) + + iterable = resolver(root, args, context, info) + on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + + if Promise.is_thenable(iterable): + return Promise.resolve(iterable).then(on_resolve) + + return on_resolve(iterable) + def get_resolver(self, parent_resolver): return partial( self.connection_resolver,