diff --git a/examples/starwars/schema.py b/examples/starwars/schema.py index 492918e..2abf140 100644 --- a/examples/starwars/schema.py +++ b/examples/starwars/schema.py @@ -1,12 +1,14 @@ import graphene from graphene import Schema, relay, resolve_only_args from graphene_django import DjangoConnectionField, DjangoObjectType +from graphene_django.rest_framework.mutation import SerializerMutation from .data import (create_ship, get_empire, get_faction, get_rebels, get_ship, get_ships) from .models import Character as CharacterModel from .models import Faction as FactionModel from .models import Ship as ShipModel +from .serializers import CharacterSerializer class Ship(DjangoObjectType): @@ -54,6 +56,11 @@ class IntroduceShip(relay.ClientIDMutation): return IntroduceShip(ship=ship, faction=faction) +class CreateCharacter(SerializerMutation): + class Meta: + serializer_class = CharacterSerializer + + class Query(graphene.ObjectType): rebels = graphene.Field(Faction) empire = graphene.Field(Faction) @@ -75,7 +82,7 @@ class Query(graphene.ObjectType): class Mutation(graphene.ObjectType): introduce_ship = IntroduceShip.Field() - + create_character = CreateCharacter.Field() # We register the Character Model because if not would be # inaccessible for the schema diff --git a/examples/starwars/serializers.py b/examples/starwars/serializers.py new file mode 100644 index 0000000..386289a --- /dev/null +++ b/examples/starwars/serializers.py @@ -0,0 +1,9 @@ +from rest_framework import serializers + +from .models import Character + + +class CharacterSerializer(serializers.ModelSerializer): + class Meta: + model = Character + fields = "__all__" diff --git a/examples/starwars/tests/test_mutation.py b/examples/starwars/tests/test_mutation.py index aa312ff..1e105ec 100644 --- a/examples/starwars/tests/test_mutation.py +++ b/examples/starwars/tests/test_mutation.py @@ -77,3 +77,33 @@ def test_mutations(): result = schema.execute(query) assert not result.errors assert result.data == expected + + +def test_serializer_mutations(): + initialize() + + query = ''' + mutation createCharacter { + createCharacter(input:{clientMutationId:"def", name: "Luke", ship: "1"}) { + id + name + ship { + id + name + } + } + } + ''' + expected = { + 'createCharacter': { + 'id': 3, + 'name': 'Luke', + 'ship': { + 'id': 'U2hpcDox', + 'name': 'X-Wing' + } + } + } + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene_django/rest_framework/models.py b/graphene_django/rest_framework/models.py index 848837b..48a1d65 100644 --- a/graphene_django/rest_framework/models.py +++ b/graphene_django/rest_framework/models.py @@ -4,3 +4,9 @@ from django.db import models class MyFakeModel(models.Model): cool_name = models.CharField(max_length=50) created = models.DateTimeField(auto_now_add=True) + + +class OneToOneModel(models.Model): + name = models.CharField(max_length=50) + fake = models.ForeignKey(MyFakeModel, on_delete=models.DO_NOTHING) + created = models.DateTimeField(auto_now_add=True) diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index 5e343aa..b142d17 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -138,6 +138,9 @@ class SerializerMutation(ClientIDMutation): kwargs = {} for f, field in serializer.fields.items(): - kwargs[f] = field.get_attribute(obj) + if hasattr(field, 'queryset'): + kwargs[f] = field.queryset.get(pk=str(field.get_attribute(obj))) + else: + kwargs[f] = field.get_attribute(obj) return cls(errors=None, **kwargs) diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py index 9f8e516..6c6c9b8 100644 --- a/graphene_django/rest_framework/serializer_converter.py +++ b/graphene_django/rest_framework/serializer_converter.py @@ -11,7 +11,7 @@ singledispatch = import_single_dispatch() @singledispatch -def get_graphene_type_from_serializer_field(field): +def get_graphene_type_from_serializer_field(field, **kwargs): raise ImproperlyConfigured( "Don't know how to convert the serializer field %s (%s) " "to Graphene type" % (field, field.__class__) @@ -25,7 +25,7 @@ def convert_serializer_field(field, is_input=True): and the field itself is required """ - graphql_type = get_graphene_type_from_serializer_field(field) + graphql_type = get_graphene_type_from_serializer_field(field, is_input=is_input) args = [] kwargs = {"description": field.help_text, "required": is_input and field.required} @@ -36,6 +36,11 @@ def convert_serializer_field(field, is_input=True): kwargs["of_type"] = graphql_type[1] graphql_type = graphql_type[0] + if isinstance(field, serializers.PrimaryKeyRelatedField) and not is_input: + global_registry = get_global_registry() + field_model = field.queryset.model + args = [global_registry.get_type_for_model(field_model)] + if isinstance(field, serializers.ModelSerializer): if is_input: graphql_type = convert_serializer_to_input_type(field.__class__) @@ -53,6 +58,7 @@ def convert_serializer_field(field, is_input=True): field_model = field.Meta.model args = [global_registry.get_type_for_model(field_model)] + print('graphql_type', graphql_type, args) return graphql_type(*args, **kwargs) @@ -72,49 +78,56 @@ def convert_serializer_to_input_type(serializer_class): @get_graphene_type_from_serializer_field.register(serializers.Field) -def convert_serializer_field_to_string(field): +def convert_serializer_field_to_string(field, **kwargs): return graphene.String @get_graphene_type_from_serializer_field.register(serializers.ModelSerializer) -def convert_serializer_to_field(field): +def convert_serializer_to_field(field, **kwargs): + return graphene.Field + + +@get_graphene_type_from_serializer_field.register(serializers.PrimaryKeyRelatedField) +def convert_serializer_key_to_field(field, is_input=True): + if is_input: + return graphene.String return graphene.Field @get_graphene_type_from_serializer_field.register(serializers.ListSerializer) -def convert_list_serializer_to_field(field): +def convert_list_serializer_to_field(field, **kwargs): child_type = get_graphene_type_from_serializer_field(field.child) return (graphene.List, child_type) @get_graphene_type_from_serializer_field.register(serializers.IntegerField) -def convert_serializer_field_to_int(field): +def convert_serializer_field_to_int(field, **kwargs): return graphene.Int @get_graphene_type_from_serializer_field.register(serializers.BooleanField) -def convert_serializer_field_to_bool(field): +def convert_serializer_field_to_bool(field, **kwargs): return graphene.Boolean @get_graphene_type_from_serializer_field.register(serializers.FloatField) @get_graphene_type_from_serializer_field.register(serializers.DecimalField) -def convert_serializer_field_to_float(field): +def convert_serializer_field_to_float(field, **kwargs): return graphene.Float @get_graphene_type_from_serializer_field.register(serializers.DateTimeField) -def convert_serializer_field_to_datetime_time(field): +def convert_serializer_field_to_datetime_time(field, **kwargs): return graphene.types.datetime.DateTime @get_graphene_type_from_serializer_field.register(serializers.DateField) -def convert_serializer_field_to_date_time(field): +def convert_serializer_field_to_date_time(field, **kwargs): return graphene.types.datetime.Date @get_graphene_type_from_serializer_field.register(serializers.TimeField) -def convert_serializer_field_to_time(field): +def convert_serializer_field_to_time(field, **kwargs): return graphene.types.datetime.Time @@ -126,15 +139,15 @@ def convert_serializer_field_to_list(field, is_input=True): @get_graphene_type_from_serializer_field.register(serializers.DictField) -def convert_serializer_field_to_dict(field): +def convert_serializer_field_to_dict(field, **kwargs): return DictType @get_graphene_type_from_serializer_field.register(serializers.JSONField) -def convert_serializer_field_to_jsonstring(field): +def convert_serializer_field_to_jsonstring(field, **kwargs): return graphene.types.json.JSONString @get_graphene_type_from_serializer_field.register(serializers.MultipleChoiceField) -def convert_serializer_field_to_list_of_string(field): +def convert_serializer_field_to_list_of_string(field, **kwargs): return (graphene.List, graphene.String) diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 4dccc18..ef1149a 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -8,6 +8,7 @@ from rest_framework import serializers from ...types import DjangoObjectType from ..models import MyFakeModel +from ..models import OneToOneModel from ..mutation import SerializerMutation @@ -32,6 +33,17 @@ class MyModelSerializer(serializers.ModelSerializer): fields = "__all__" +class OneToOneModelSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneModel + fields = "__all__" + + +class OneToOneModelMutation(SerializerMutation): + class Meta: + serializer_class = OneToOneModelSerializer + + class MyModelMutation(SerializerMutation): class Meta: serializer_class = MyModelSerializer @@ -64,6 +76,23 @@ def test_has_fields(): assert "errors" in MyMutation._meta.fields +def test_has_nested_fields(): + class MyFakeModelGrapheneType(DjangoObjectType): + class Meta: + model = MyFakeModel + + class OneToOneModelMutation(SerializerMutation): + class Meta: + serializer_class = OneToOneModelSerializer + + assert "name" in OneToOneModelMutation._meta.fields + assert "fake" in OneToOneModelMutation._meta.fields + model_field = OneToOneModelMutation._meta.fields['fake'] + assert isinstance(model_field, Field) + assert model_field.type == MyFakeModelGrapheneType + assert "errors" in OneToOneModelMutation._meta.fields + + def test_has_input_fields(): class MyMutation(SerializerMutation): class Meta: @@ -127,6 +156,21 @@ def test_model_add_mutate_and_get_payload_success(): assert isinstance(result.created, datetime.datetime) +@mark.django_db +def test_one_to_one_model_with_add_mutate_and_get_payload_success(): + fake = MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) + + result = OneToOneModelMutation.mutate_and_get_payload( + None, mock_info(), **{"name": "Jinkies", "fake": fake.id} + ) + assert result.errors is None + assert result.name == "Jinkies" + assert result.fake.pk == fake.id + assert isinstance(result.created, datetime.datetime) + + @mark.django_db def test_model_update_mutate_and_get_payload_success(): instance = MyFakeModel.objects.create(cool_name="Narf") @@ -168,6 +212,18 @@ def test_model_mutate_and_get_payload_error(): assert len(result.errors) > 0 +@mark.django_db +def test_one_to_one_model_with_add_mutate_and_get_payload_error(): + MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) + + result = OneToOneModelMutation.mutate_and_get_payload( + None, mock_info(), **{"name": "Jinkies", "fake": "invalid"} + ) + assert len(result.errors) > 0 + + def test_invalid_serializer_operations(): with raises(Exception) as exc: