Extend DjangoListField to always pass queryset to get_queryset method

This commit is contained in:
Jonathan Kim 2020-04-12 14:32:07 +01:00
parent e1cfc0a80b
commit 76460c83a2
2 changed files with 134 additions and 11 deletions

View File

@ -38,16 +38,21 @@ class DjangoListField(Field):
def model(self): def model(self):
return self._underlying_type._meta.model return self._underlying_type._meta.model
def get_default_queryset(self):
return self.model._default_manager.get_queryset()
@staticmethod @staticmethod
def list_resolver(django_object_type, resolver, root, info, **args): def list_resolver(
django_object_type, resolver, default_queryset, root, info, **args
):
queryset = maybe_queryset(resolver(root, info, **args)) queryset = maybe_queryset(resolver(root, info, **args))
if queryset is None: if queryset is None:
# Default to Django Model queryset queryset = default_queryset
# N.B. This happens if DjangoListField is used in the top level Query object
model_manager = django_object_type._meta.model.objects if isinstance(queryset, QuerySet):
queryset = maybe_queryset( # Pass queryset to the DjangoObjectType get_queryset method
django_object_type.get_queryset(model_manager, info) queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
)
return queryset return queryset
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
@ -55,7 +60,12 @@ class DjangoListField(Field):
if isinstance(_type, NonNull): if isinstance(_type, NonNull):
_type = _type.of_type _type = _type.of_type
django_object_type = _type.of_type.of_type django_object_type = _type.of_type.of_type
return partial(self.list_resolver, django_object_type, parent_resolver) return partial(
self.list_resolver,
django_object_type,
parent_resolver,
self.get_default_queryset(),
)
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):

View File

@ -1,4 +1,5 @@
import datetime import datetime
from django.db.models import Count
import pytest import pytest
@ -142,13 +143,26 @@ class TestDjangoListField:
pub_date_time=datetime.datetime.now(), pub_date_time=datetime.datetime.now(),
editor=r1, editor=r1,
) )
ArticleModel.objects.create(
headline="Not so good news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
"reporters": [ "reporters": [
{"firstName": "Tara", "articles": [{"headline": "Amazing news"}]}, {
"firstName": "Tara",
"articles": [
{"headline": "Amazing news"},
{"headline": "Not so good news"},
],
},
{"firstName": "Debra", "articles": []}, {"firstName": "Debra", "articles": []},
] ]
} }
@ -164,8 +178,8 @@ class TestDjangoListField:
model = ReporterModel model = ReporterModel
fields = ("first_name", "articles") fields = ("first_name", "articles")
def resolve_reporters(reporter, info): def resolve_articles(reporter, info):
return reporter.articles.all() return reporter.articles.filter(headline__contains="Amazing")
class Query(ObjectType): class Query(ObjectType):
reporters = DjangoListField(Reporter) reporters = DjangoListField(Reporter)
@ -193,6 +207,13 @@ class TestDjangoListField:
pub_date_time=datetime.datetime.now(), pub_date_time=datetime.datetime.now(),
editor=r1, editor=r1,
) )
ArticleModel.objects.create(
headline="Not so good news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query) result = schema.execute(query)
@ -203,3 +224,95 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []}, {"firstName": "Debra", "articles": []},
] ]
} }
def test_get_queryset_filter(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
def resolve_reporters(_, info):
return ReporterModel.objects.all()
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"},]}
def test_resolve_list(self):
"""Resolving a plain list should work (and not call get_queryset)"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"},]}