diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index beaaa49..6eef870 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -55,6 +55,7 @@ class SerializerMutation(ClientIDMutation): output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) _meta = SerializerMutationOptions(cls) + _meta.serializer_class = serializer_class _meta.fields = yank_fields_from_attrs( output_fields, _as=Field, diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py index e115e82..c472cee 100644 --- a/graphene_django/rest_framework/serializer_converter.py +++ b/graphene_django/rest_framework/serializer_converter.py @@ -42,8 +42,7 @@ def convert_serializer_field(field, is_input=True): if isinstance(field, serializers.ModelSerializer): if is_input: - return Dynamic(lambda: None) - # graphql_type = convert_serializer_to_input_type(field.__class__) + graphql_type = convert_serializer_to_input_type(field.__class__) else: global_registry = get_global_registry() field_model = field.Meta.model @@ -52,6 +51,21 @@ def convert_serializer_field(field, is_input=True): return graphql_type(*args, **kwargs) +def convert_serializer_to_input_type(serializer_class): + serializer = serializer_class() + + items = { + name: convert_serializer_field(field) + for name, field in serializer.fields.items() + } + + return type( + '{}Input'.format(serializer.__class__.__name__), + (graphene.InputObjectType,), + items + ) + + @get_graphene_type_from_serializer_field.register(serializers.Field) def convert_serializer_field_to_string(field): return graphene.String diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 836f3fe..5374a66 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -65,7 +65,6 @@ def test_nested_model(): assert model_field.type == MyFakeModelGrapheneType model_input = MyMutation.Input._meta.fields['model'] - model_input_type = model_input.get_type() - assert not model_input_type - # assert issubclass(model_input_type, InputObjectType) - # assert 'cool_name' in model_input_type._meta.fields + model_input_type = model_input._type.of_type + assert issubclass(model_input_type, InputObjectType) + assert 'cool_name' in model_input_type._meta.fields