diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 237cd07..7ea30ee 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -66,6 +66,20 @@ class DjangoConnectionField(ConnectionField): return getattr(self.model, self.on) else: return self.model._default_manager + + @classmethod + def resolve_queryset(cls, connection, queryset, info): + default_queryset = maybe_queryset(connection._meta.node.get_queryset(info)) + if queryset: + return cls.merge_querysets(default_queryset, queryset) + return default_queryset + + def get_queryset_resolver(self): + return partial( + self.resolve_queryset, + self.type, + maybe_queryset(getattr(self.model, self.on)) if self.on else None, + ) @classmethod def merge_querysets(cls, default_queryset, queryset): @@ -106,7 +120,7 @@ class DjangoConnectionField(ConnectionField): cls, resolver, connection, - default_manager, + queryset_resolver, max_limit, enforce_first_or_last, root, @@ -135,7 +149,13 @@ class DjangoConnectionField(ConnectionField): args["last"] = min(last, max_limit) iterable = resolver(root, info, **args) - queryset = connection._meta.node.get_queryset(info) + if callable(queryset_resolver): + queryset = queryset_resolver(info) + else: + assert isinstance(queryset_resolver, QuerySet), "The type {} is not a QuerySet".format( + type(queryset_resolver) + ) + queryset = queryset_resolver on_resolve = partial(cls.resolve_connection, connection, queryset, args) if Promise.is_thenable(iterable): @@ -148,7 +168,7 @@ class DjangoConnectionField(ConnectionField): self.connection_resolver, parent_resolver, self.type, - self.get_manager(), + self.get_queryset_resolver(), self.max_limit, self.enforce_first_or_last, )