From 47ca1ad607af21545a896e1acd1c169decfc748e Mon Sep 17 00:00:00 2001 From: Jonathan Kim Date: Sun, 4 Aug 2019 08:38:31 +0100 Subject: [PATCH] Only allow DjangoObjectTypes to DjangoListField --- graphene_django/fields.py | 16 ++++++++++++---- graphene_django/tests/test_fields.py | 22 +++++++++++++++++++++- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 17086b3..c52df8e 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,13 +1,12 @@ from functools import partial from django.db.models.query import QuerySet -from graphene import NonNull - +from graphql_relay.connection.arrayconnection import connection_from_list_slice from promise import Promise -from graphene.types import Field, List +from graphene import NonNull from graphene.relay import ConnectionField, PageInfo -from graphql_relay.connection.arrayconnection import connection_from_list_slice +from graphene.types import Field, List from .settings import graphene_settings from .utils import maybe_queryset @@ -15,6 +14,15 @@ from .utils import maybe_queryset class DjangoListField(Field): def __init__(self, _type, *args, **kwargs): + from .types import DjangoObjectType + + if isinstance(_type, NonNull): + _type = _type.of_type + + assert issubclass( + _type, DjangoObjectType + ), "DjangoListField only accepts DjangoObjectType types" + # Django would never return a Set of None vvvvvvv super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs) diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py index 75f27da..2415d22 100644 --- a/graphene_django/tests/test_fields.py +++ b/graphene_django/tests/test_fields.py @@ -1,5 +1,6 @@ import pytest -from graphene import ObjectType, Schema + +from graphene import List, NonNull, ObjectType, Schema, String from ..fields import DjangoListField from ..types import DjangoObjectType @@ -8,6 +9,25 @@ from .models import Reporter as ReporterModel @pytest.mark.django_db class TestDjangoListField: + def test_only_django_object_types(self): + class TestType(ObjectType): + foo = String() + + with pytest.raises(AssertionError): + list_field = DjangoListField(TestType) + + def test_non_null_type(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name",) + + list_field = DjangoListField(NonNull(Reporter)) + + assert isinstance(list_field.type, List) + assert isinstance(list_field.type.of_type, NonNull) + assert list_field.type.of_type.of_type is Reporter + def test_get_django_model(self): class Reporter(DjangoObjectType): class Meta: