diff --git a/rest_framework/fields.py b/rest_framework/fields.py index e41b56fb0..ccc6bb6c7 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -9,6 +9,7 @@ import uuid from collections import OrderedDict from collections.abc import Mapping +from django.db.models import IntegerChoices from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError as DjangoValidationError @@ -1398,6 +1399,9 @@ class ChoiceField(Field): if data == '' and self.allow_blank: return '' + if isinstance(data, IntegerChoices) and str(data) != str(data.value): + data = data.value + try: return self.choice_strings_to_values[str(data)] except KeyError: @@ -1406,6 +1410,10 @@ class ChoiceField(Field): def to_representation(self, value): if value in ('', None): return value + + if isinstance(value, IntegerChoices) and str(value) != str(value.value): + value = value.value + return self.choice_strings_to_values.get(str(value), value) def iter_options(self): @@ -1429,7 +1437,8 @@ class ChoiceField(Field): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = { - str(key): key for key in self.choices + str(key.value) if isinstance(key, IntegerChoices) and str(key) != str( + key.value) else str(key): key for key in self.choices } choices = property(_get_choices, _set_choices) diff --git a/tests/test_fields.py b/tests/test_fields.py index 5804d7b3b..03ae4a704 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1824,6 +1824,34 @@ class TestChoiceField(FieldValues): field.run_validation(2) assert exc_info.value.detail == ['"2" is not a valid choice.'] + def test_enum_choices(self): + from enum import auto + from django.db.models import IntegerChoices + + class ChoiceCase(IntegerChoices): + first = auto() + second = auto() + # Enum validate + choices = [ + (ChoiceCase.first, "1"), + (ChoiceCase.second, "2") + ] + field = serializers.ChoiceField(choices=choices) + + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + + choices = [ + (ChoiceCase.first.value, "1"), + (ChoiceCase.second.value, "2") + ] + + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + class TestChoiceFieldWithType(FieldValues): """