Only allow DjangoObjectTypes to DjangoListField

This commit is contained in:
Jonathan Kim 2019-08-04 08:38:31 +01:00
parent 609baec3a0
commit 47ca1ad607
2 changed files with 33 additions and 5 deletions

View File

@ -1,13 +1,12 @@
from functools import partial from functools import partial
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from graphene import NonNull from graphql_relay.connection.arrayconnection import connection_from_list_slice
from promise import Promise from promise import Promise
from graphene.types import Field, List from graphene import NonNull
from graphene.relay import ConnectionField, PageInfo from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice from graphene.types import Field, List
from .settings import graphene_settings from .settings import graphene_settings
from .utils import maybe_queryset from .utils import maybe_queryset
@ -15,6 +14,15 @@ from .utils import maybe_queryset
class DjangoListField(Field): class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs): def __init__(self, _type, *args, **kwargs):
from .types import DjangoObjectType
if isinstance(_type, NonNull):
_type = _type.of_type
assert issubclass(
_type, DjangoObjectType
), "DjangoListField only accepts DjangoObjectType types"
# Django would never return a Set of None vvvvvvv # Django would never return a Set of None vvvvvvv
super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs) super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs)

View File

@ -1,5 +1,6 @@
import pytest import pytest
from graphene import ObjectType, Schema
from graphene import List, NonNull, ObjectType, Schema, String
from ..fields import DjangoListField from ..fields import DjangoListField
from ..types import DjangoObjectType from ..types import DjangoObjectType
@ -8,6 +9,25 @@ from .models import Reporter as ReporterModel
@pytest.mark.django_db @pytest.mark.django_db
class TestDjangoListField: class TestDjangoListField:
def test_only_django_object_types(self):
class TestType(ObjectType):
foo = String()
with pytest.raises(AssertionError):
list_field = DjangoListField(TestType)
def test_non_null_type(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
list_field = DjangoListField(NonNull(Reporter))
assert isinstance(list_field.type, List)
assert isinstance(list_field.type.of_type, NonNull)
assert list_field.type.of_type.of_type is Reporter
def test_get_django_model(self): def test_get_django_model(self):
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta: