Implicitly mark Django fields as NonNull where appropriate

This commit is contained in:
Conrad Kramer 2016-07-06 19:14:30 -07:00
parent 8711dd20b6
commit 0c63b0c5e2
3 changed files with 40 additions and 9 deletions

View File

@ -3,7 +3,7 @@ from django.utils.encoding import force_text
from ...core.classtypes.enum import Enum from ...core.classtypes.enum import Enum
from ...core.types.custom_scalars import DateTime, JSONString from ...core.types.custom_scalars import DateTime, JSONString
from ...core.types.definitions import List from ...core.types.definitions import NonNull, List
from ...core.types.scalars import ID, Boolean, Float, Int, String from ...core.types.scalars import ID, Boolean, Float, Int, String
from ...utils import to_const from ...utils import to_const
from .compat import (ArrayField, HStoreField, JSONField, RangeField, from .compat import (ArrayField, HStoreField, JSONField, RangeField,
@ -32,6 +32,17 @@ def convert_django_field_with_choices(field):
return convert_django_field(field) return convert_django_field(field)
def add_nonnull_to_field(convert_field):
def convert_django_nonnull_field(field):
graphene_type = convert_field(field)
if isinstance(field, models.ManyToOneRel):
is_null = field.field.null
else:
is_null = field.null
return graphene_type if is_null else NonNull(graphene_type)
return convert_django_nonnull_field
@singledispatch @singledispatch
def convert_django_field(field): def convert_django_field(field):
raise Exception( raise Exception(
@ -47,11 +58,13 @@ def convert_django_field(field):
@convert_django_field.register(models.GenericIPAddressField) @convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(models.FileField) @convert_django_field.register(models.FileField)
@convert_django_field.register(UUIDField) @convert_django_field.register(UUIDField)
@add_nonnull_to_field
def convert_field_to_string(field): def convert_field_to_string(field):
return String(description=field.help_text) return String(description=field.help_text)
@convert_django_field.register(models.AutoField) @convert_django_field.register(models.AutoField)
@add_nonnull_to_field
def convert_field_to_id(field): def convert_field_to_id(field):
return ID(description=field.help_text) return ID(description=field.help_text)
@ -61,11 +74,13 @@ def convert_field_to_id(field):
@convert_django_field.register(models.SmallIntegerField) @convert_django_field.register(models.SmallIntegerField)
@convert_django_field.register(models.BigIntegerField) @convert_django_field.register(models.BigIntegerField)
@convert_django_field.register(models.IntegerField) @convert_django_field.register(models.IntegerField)
@add_nonnull_to_field
def convert_field_to_int(field): def convert_field_to_int(field):
return Int(description=field.help_text) return Int(description=field.help_text)
@convert_django_field.register(models.BooleanField) @convert_django_field.register(models.BooleanField)
@add_nonnull_to_field
def convert_field_to_boolean(field): def convert_field_to_boolean(field):
return Boolean(description=field.help_text, required=True) return Boolean(description=field.help_text, required=True)
@ -77,16 +92,19 @@ def convert_field_to_nullboolean(field):
@convert_django_field.register(models.DecimalField) @convert_django_field.register(models.DecimalField)
@convert_django_field.register(models.FloatField) @convert_django_field.register(models.FloatField)
@add_nonnull_to_field
def convert_field_to_float(field): def convert_field_to_float(field):
return Float(description=field.help_text) return Float(description=field.help_text)
@convert_django_field.register(models.DateField) @convert_django_field.register(models.DateField)
@add_nonnull_to_field
def convert_date_to_string(field): def convert_date_to_string(field):
return DateTime(description=field.help_text) return DateTime(description=field.help_text)
@convert_django_field.register(models.OneToOneRel) @convert_django_field.register(models.OneToOneRel)
@add_nonnull_to_field
def convert_onetoone_field_to_djangomodel(field): def convert_onetoone_field_to_djangomodel(field):
from .fields import DjangoModelField from .fields import DjangoModelField
return DjangoModelField(get_related_model(field)) return DjangoModelField(get_related_model(field))
@ -107,12 +125,13 @@ def convert_relatedfield_to_djangomodel(field):
from .fields import DjangoModelField, ConnectionOrListField from .fields import DjangoModelField, ConnectionOrListField
model_field = DjangoModelField(field.model) model_field = DjangoModelField(field.model)
if isinstance(field.field, models.OneToOneField): if isinstance(field.field, models.OneToOneField):
return model_field return model_field if field.field.null else NonNull(model_field)
return ConnectionOrListField(model_field) return ConnectionOrListField(model_field)
@convert_django_field.register(models.OneToOneField) @convert_django_field.register(models.OneToOneField)
@convert_django_field.register(models.ForeignKey) @convert_django_field.register(models.ForeignKey)
@add_nonnull_to_field
def convert_field_to_djangomodel(field): def convert_field_to_djangomodel(field):
from .fields import DjangoModelField from .fields import DjangoModelField
return DjangoModelField(get_related_model(field), description=field.help_text) return DjangoModelField(get_related_model(field), description=field.help_text)
@ -126,11 +145,13 @@ def convert_postgres_array_to_list(field):
@convert_django_field.register(HStoreField) @convert_django_field.register(HStoreField)
@convert_django_field.register(JSONField) @convert_django_field.register(JSONField)
@add_nonnull_to_field
def convert_posgres_field_to_string(field): def convert_posgres_field_to_string(field):
return JSONString(description=field.help_text) return JSONString(description=field.help_text)
@convert_django_field.register(RangeField) @convert_django_field.register(RangeField)
@add_nonnull_to_field
def convert_posgres_range_to_string(field): def convert_posgres_range_to_string(field):
inner_type = convert_django_field(field.base_field) inner_type = convert_django_field(field.base_field)
return List(inner_type, description=field.help_text) return List(inner_type, description=field.help_text)

View File

@ -35,7 +35,7 @@ class Reporter(models.Model):
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100, null=True)
pub_date = models.DateField() pub_date = models.DateField()
reporter = models.ForeignKey(Reporter, related_name='articles') reporter = models.ForeignKey(Reporter, related_name='articles')
lang = models.CharField(max_length=2, help_text='Language', choices=[ lang = models.CharField(max_length=2, help_text='Language', choices=[

View File

@ -5,6 +5,7 @@ from py.test import raises
import graphene import graphene
from graphene.core.types.custom_scalars import DateTime, JSONString from graphene.core.types.custom_scalars import DateTime, JSONString
from graphene.core.types.definitions import OfType
from ..compat import (ArrayField, HStoreField, JSONField, MissingType, from ..compat import (ArrayField, HStoreField, JSONField, MissingType,
RangeField) RangeField)
@ -14,11 +15,16 @@ from .models import Article, Reporter, Film, FilmDetails
def assert_conversion(django_field, graphene_field, *args, **kwargs): def assert_conversion(django_field, graphene_field, *args, **kwargs):
field = django_field(help_text='Custom Help Text', *args, **kwargs) field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs)
graphene_type = convert_django_field(field) graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.as_field() field = graphene_type.as_field()
assert field.description == 'Custom Help Text' assert field.description == 'Custom Help Text'
if not isinstance(graphene_type, OfType):
nonnull_field = django_field(null=False, *args, **kwargs)
if not nonnull_field.null:
nonnull_graphene_type = convert_django_field(nonnull_field)
assert isinstance(nonnull_graphene_type, graphene.NonNull)
return field return field
@ -176,8 +182,9 @@ def test_should_onetoone_reverse_convert_model():
related = getattr(Film.details, 'rel', None) or \ related = getattr(Film.details, 'rel', None) or \
getattr(Film.details, 'related') getattr(Film.details, 'related')
graphene_type = convert_django_field(related) graphene_type = convert_django_field(related)
assert isinstance(graphene_type, DjangoModelField) assert isinstance(graphene_type, graphene.NonNull)
assert graphene_type.model == FilmDetails assert isinstance(graphene_type.of_type, DjangoModelField)
assert graphene_type.of_type.model == FilmDetails
def test_should_onetoone_convert_model(): def test_should_onetoone_convert_model():
@ -195,7 +202,8 @@ def test_should_foreignkey_convert_model():
def test_should_postgres_array_convert_list(): def test_should_postgres_array_convert_list():
field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100)) field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100))
assert isinstance(field.type, graphene.List) assert isinstance(field.type, graphene.List)
assert isinstance(field.type.of_type, graphene.String) assert isinstance(field.type.of_type, graphene.NonNull)
assert isinstance(field.type.of_type.of_type, graphene.String)
@pytest.mark.skipif(ArrayField is MissingType, @pytest.mark.skipif(ArrayField is MissingType,
@ -204,7 +212,8 @@ def test_should_postgres_array_multiple_convert_list():
field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))) field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)))
assert isinstance(field.type, graphene.List) assert isinstance(field.type, graphene.List)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.String) assert isinstance(field.type.of_type.of_type, graphene.NonNull)
assert isinstance(field.type.of_type.of_type.of_type, graphene.String)
@pytest.mark.skipif(HStoreField is MissingType, @pytest.mark.skipif(HStoreField is MissingType,
@ -224,4 +233,5 @@ def test_should_postgres_json_convert_string():
def test_should_postgres_range_convert_list(): def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List) field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type.of_type, graphene.Int) assert isinstance(field.type.of_type, graphene.NonNull)
assert isinstance(field.type.of_type.of_type, graphene.Int)