diff --git a/graphene_django/fields.py b/graphene_django/fields.py index fb6b98a..7539cf2 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -38,16 +38,21 @@ class DjangoListField(Field): def model(self): return self._underlying_type._meta.model + def get_default_queryset(self): + return self.model._default_manager.get_queryset() + @staticmethod - def list_resolver(django_object_type, resolver, root, info, **args): + def list_resolver( + django_object_type, resolver, default_queryset, root, info, **args + ): queryset = maybe_queryset(resolver(root, info, **args)) if queryset is None: - # Default to Django Model queryset - # N.B. This happens if DjangoListField is used in the top level Query object - model_manager = django_object_type._meta.model.objects - queryset = maybe_queryset( - django_object_type.get_queryset(model_manager, info) - ) + queryset = default_queryset + + if isinstance(queryset, QuerySet): + # Pass queryset to the DjangoObjectType get_queryset method + queryset = maybe_queryset(django_object_type.get_queryset(queryset, info)) + return queryset def get_resolver(self, parent_resolver): @@ -55,7 +60,12 @@ class DjangoListField(Field): if isinstance(_type, NonNull): _type = _type.of_type django_object_type = _type.of_type.of_type - return partial(self.list_resolver, django_object_type, parent_resolver) + return partial( + self.list_resolver, + django_object_type, + parent_resolver, + self.get_default_queryset(), + ) class DjangoConnectionField(ConnectionField): diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py index 8ea1901..39b82ab 100644 --- a/graphene_django/tests/test_fields.py +++ b/graphene_django/tests/test_fields.py @@ -1,4 +1,5 @@ import datetime +from django.db.models import Count import pytest @@ -142,13 +143,26 @@ class TestDjangoListField: pub_date_time=datetime.datetime.now(), editor=r1, ) + ArticleModel.objects.create( + headline="Not so good news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) result = schema.execute(query) assert not result.errors assert result.data == { "reporters": [ - {"firstName": "Tara", "articles": [{"headline": "Amazing news"}]}, + { + "firstName": "Tara", + "articles": [ + {"headline": "Amazing news"}, + {"headline": "Not so good news"}, + ], + }, {"firstName": "Debra", "articles": []}, ] } @@ -164,8 +178,8 @@ class TestDjangoListField: model = ReporterModel fields = ("first_name", "articles") - def resolve_reporters(reporter, info): - return reporter.articles.all() + def resolve_articles(reporter, info): + return reporter.articles.filter(headline__contains="Amazing") class Query(ObjectType): reporters = DjangoListField(Reporter) @@ -193,6 +207,13 @@ class TestDjangoListField: pub_date_time=datetime.datetime.now(), editor=r1, ) + ArticleModel.objects.create( + headline="Not so good news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) result = schema.execute(query) @@ -203,3 +224,95 @@ class TestDjangoListField: {"firstName": "Debra", "articles": []}, ] } + + def test_get_queryset_filter(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + def resolve_reporters(_, info): + return ReporterModel.objects.all() + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = schema.execute(query) + + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Tara"},]} + + def test_resolve_list(self): + """Resolving a plain list should work (and not call get_queryset)""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + def resolve_reporters(_, info): + return [ReporterModel.objects.get(first_name="Debra")] + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = schema.execute(query) + + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Debra"},]}