diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 7ea30ee..ea07298 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -68,18 +68,8 @@ class DjangoConnectionField(ConnectionField): return self.model._default_manager @classmethod - def resolve_queryset(cls, connection, queryset, info): - default_queryset = maybe_queryset(connection._meta.node.get_queryset(info)) - if queryset: - return cls.merge_querysets(default_queryset, queryset) - return default_queryset - - def get_queryset_resolver(self): - return partial( - self.resolve_queryset, - self.type, - maybe_queryset(getattr(self.model, self.on)) if self.on else None, - ) + def resolve_queryset(cls, connection, queryset, info, args): + return connection._meta.node.get_queryset(queryset, info) @classmethod def merge_querysets(cls, default_queryset, queryset): @@ -120,7 +110,7 @@ class DjangoConnectionField(ConnectionField): cls, resolver, connection, - queryset_resolver, + default_manager, max_limit, enforce_first_or_last, root, @@ -149,13 +139,7 @@ class DjangoConnectionField(ConnectionField): args["last"] = min(last, max_limit) iterable = resolver(root, info, **args) - if callable(queryset_resolver): - queryset = queryset_resolver(info) - else: - assert isinstance(queryset_resolver, QuerySet), "The type {} is not a QuerySet".format( - type(queryset_resolver) - ) - queryset = queryset_resolver + queryset = cls.resolve_queryset(connection, default_manager, info, args) on_resolve = partial(cls.resolve_connection, connection, queryset, args) if Promise.is_thenable(iterable): @@ -168,7 +152,7 @@ class DjangoConnectionField(ConnectionField): self.connection_resolver, parent_resolver, self.type, - self.get_queryset_resolver(), + self.get_manager(), self.max_limit, self.enforce_first_or_last, ) diff --git a/graphene_django/types.py b/graphene_django/types.py index a111114..8e35608 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -128,12 +128,13 @@ class DjangoObjectType(ObjectType): return model == cls._meta.model @classmethod - def get_queryset(cls, info): - return cls._meta.model.objects + def get_queryset(cls, queryset, info): + return queryset @classmethod def get_node(cls, info, id): + queryset = cls.get_queryset(cls._meta.model.objects, info) try: - return cls.get_queryset(info).get(pk=id) + return queryset.get(pk=id) except cls._meta.model.DoesNotExist: return None