diff --git a/graphene_django/rest_framework/models.py b/graphene_django/rest_framework/models.py index bd84ce5..d31c3eb 100644 --- a/graphene_django/rest_framework/models.py +++ b/graphene_django/rest_framework/models.py @@ -14,3 +14,14 @@ class MyFakeModelWithPassword(models.Model): class MyFakeModelWithDate(models.Model): cool_name = models.CharField(max_length=50) last_edited = models.DateField() + + +class MyFakeModelWithChoiceField(models.Model): + class ChoiceType(models.Choices): + ASDF = "asdf" + HI = "hi" + + choice_type = models.CharField( + max_length=4, + default=ChoiceType.HI.name, + ) diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index b7393da..837db1e 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -1,3 +1,5 @@ +from enum import Enum + from collections import OrderedDict from django.shortcuts import get_object_or_404 @@ -124,8 +126,10 @@ class SerializerMutation(ClientIDMutation): def get_serializer_kwargs(cls, root, info, **input): lookup_field = cls._meta.lookup_field model_class = cls._meta.model_class - if model_class: + for input_dict_key, maybe_enum in input.items(): + if isinstance(maybe_enum, Enum): + input[input_dict_key] = maybe_enum.value if "update" in cls._meta.model_operations and lookup_field in input: instance = get_object_or_404( model_class, **{lookup_field: input[lookup_field]} diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 91d99f0..98cd11d 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -7,7 +7,12 @@ from graphene import Field, ResolveInfo from graphene.types.inputobjecttype import InputObjectType from ...types import DjangoObjectType -from ..models import MyFakeModel, MyFakeModelWithDate, MyFakeModelWithPassword +from ..models import ( + MyFakeModel, + MyFakeModelWithDate, + MyFakeModelWithPassword, + MyFakeModelWithChoiceField, +) from ..mutation import SerializerMutation @@ -268,6 +273,39 @@ def test_perform_mutate_success(): assert result.days_since_last_edit == 4 +def test_perform_mutate_success_with_enum_choice_field(): + class ListViewChoiceFieldSerializer(serializers.ModelSerializer): + choice_type = serializers.ChoiceField( + choices=[(x.name, x.value) for x in MyFakeModelWithChoiceField.ChoiceType], + required=False, + ) + + class Meta: + model = MyFakeModelWithChoiceField + fields = "__all__" + + class SomeCreateSerializerMutation(SerializerMutation): + class Meta: + serializer_class = ListViewChoiceFieldSerializer + + choice_type = { + "choice_type": SomeCreateSerializerMutation.Input.choice_type.type.get("ASDF") + } + name = MyFakeModelWithChoiceField.ChoiceType.ASDF.name + result = SomeCreateSerializerMutation.mutate_and_get_payload( + None, mock_info(), **choice_type + ) + assert result.errors is None + assert result.choice_type == name + kwargs = SomeCreateSerializerMutation.get_serializer_kwargs( + None, mock_info(), **choice_type + ) + assert kwargs["data"]["choice_type"] == name + assert 1 == MyFakeModelWithChoiceField.objects.count() + item = MyFakeModelWithChoiceField.objects.first() + assert item.choice_type == name + + def test_mutate_and_get_payload_error(): class MyMutation(SerializerMutation): class Meta: