diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py index 055c3ac..8b04d46 100644 --- a/graphene_django/rest_framework/serializer_converter.py +++ b/graphene_django/rest_framework/serializer_converter.py @@ -3,6 +3,7 @@ from rest_framework import serializers import graphene +from ..registry import get_global_registry from ..utils import import_single_dispatch from .types import DictType @@ -41,6 +42,7 @@ def convert_serializer_field(field, is_input=True): graphql_type = get_graphene_type_from_serializer_field(field) + args = [] kwargs = { 'description': field.help_text, 'required': is_input and field.required, @@ -52,7 +54,15 @@ def convert_serializer_field(field, is_input=True): kwargs['of_type'] = graphql_type[1] graphql_type = graphql_type[0] - return graphql_type(**kwargs) + if isinstance(field, serializers.ModelSerializer): + if is_input: + graphql_type = convert_serializer_to_input_type(field.__class__) + else: + global_registry = get_global_registry() + field_model = field.Meta.model + args = [global_registry.get_type_for_model(field_model)] + + return graphql_type(*args, **kwargs) @get_graphene_type_from_serializer_field.register(serializers.Field) @@ -60,6 +70,11 @@ def convert_serializer_field_to_string(field): return graphene.String +@get_graphene_type_from_serializer_field.register(serializers.ModelSerializer) +def convert_serializer_to_field(field): + return graphene.Field + + @get_graphene_type_from_serializer_field.register(serializers.IntegerField) def convert_serializer_field_to_int(field): return graphene.Int @@ -76,6 +91,17 @@ def convert_serializer_field_to_float(field): return graphene.Float +@get_graphene_type_from_serializer_field.register(serializers.DateTimeField) +@get_graphene_type_from_serializer_field.register(serializers.DateField) +def convert_serializer_field_to_date_time(field): + return graphene.types.datetime.DateTime + + +@get_graphene_type_from_serializer_field.register(serializers.TimeField) +def convert_serializer_field_to_time(field): + return graphene.types.datetime.Time + + @get_graphene_type_from_serializer_field.register(serializers.ListField) def convert_serializer_field_to_list(field, is_input=True): child_type = get_graphene_type_from_serializer_field(field.child) diff --git a/graphene_django/rest_framework/tests/test_field_converter.py b/graphene_django/rest_framework/tests/test_field_converter.py index 2248b6f..623cf58 100644 --- a/graphene_django/rest_framework/tests/test_field_converter.py +++ b/graphene_django/rest_framework/tests/test_field_converter.py @@ -8,7 +8,7 @@ from ..serializer_converter import convert_serializer_field from ..types import DictType -def _get_type(rest_framework_field, **kwargs): +def _get_type(rest_framework_field, is_input=True, **kwargs): # prevents the following error: # AssertionError: The `source` argument is not meaningful when applied to a `child=` field. # Remove `source=` from the field declaration. @@ -19,7 +19,7 @@ def _get_type(rest_framework_field, **kwargs): field = rest_framework_field(**kwargs) - return convert_serializer_field(field) + return convert_serializer_field(field, is_input=is_input) def assert_conversion(rest_framework_field, graphene_field, **kwargs): @@ -40,18 +40,6 @@ def test_should_unknown_rest_framework_field_raise_exception(): assert 'Don\'t know how to convert the serializer field' in str(excinfo.value) -def test_should_date_convert_string(): - assert_conversion(serializers.DateField, graphene.String) - - -def test_should_time_convert_string(): - assert_conversion(serializers.TimeField, graphene.String) - - -def test_should_date_time_convert_string(): - assert_conversion(serializers.DateTimeField, graphene.String) - - def test_should_char_convert_string(): assert_conversion(serializers.CharField, graphene.String) @@ -85,6 +73,28 @@ def test_should_uuid_convert_string(): assert_conversion(serializers.UUIDField, graphene.String) +def test_should_model_convert_field(): + + class MyModelSerializer(serializers.ModelSerializer): + class Meta: + model = None + fields = '__all__' + + assert_conversion(MyModelSerializer, graphene.Field, is_input=False) + + +def test_should_date_time_convert_datetime(): + assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime) + + +def test_should_date_convert_datetime(): + assert_conversion(serializers.DateField, graphene.types.datetime.DateTime) + + +def test_should_time_convert_time(): + assert_conversion(serializers.TimeField, graphene.types.datetime.Time) + + def test_should_integer_convert_int(): assert_conversion(serializers.IntegerField, graphene.Int) diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 30ac477..5143f76 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -1,11 +1,26 @@ +from django.db import models +from graphene import Field +from graphene.types.inputobjecttype import InputObjectType from py.test import raises from rest_framework import serializers +from ...types import DjangoObjectType from ..mutation import SerializerMutation +class MyFakeModel(models.Model): + cool_name = models.CharField(max_length=50) + + +class MyModelSerializer(serializers.ModelSerializer): + class Meta: + model = MyFakeModel + fields = '__all__' + + class MySerializer(serializers.Serializer): text = serializers.CharField() + model = MyModelSerializer() def test_needs_serializer_class(): @@ -22,6 +37,7 @@ def test_has_fields(): serializer_class = MySerializer assert 'text' in MyMutation._meta.fields + assert 'model' in MyMutation._meta.fields assert 'errors' in MyMutation._meta.fields @@ -31,5 +47,24 @@ def test_has_input_fields(): serializer_class = MySerializer assert 'text' in MyMutation.Input._meta.fields + assert 'model' in MyMutation.Input._meta.fields +def test_nested_model(): + + class MyFakeModelGrapheneType(DjangoObjectType): + class Meta: + model = MyFakeModel + + class MyMutation(SerializerMutation): + class Meta: + serializer_class = MySerializer + + model_field = MyMutation._meta.fields['model'] + assert isinstance(model_field, Field) + assert model_field.type == MyFakeModelGrapheneType + + model_input = MyMutation.Input._meta.fields['model'] + model_input_type = model_input._type.of_type + assert issubclass(model_input_type, InputObjectType) + assert 'cool_name' in model_input_type._meta.fields