Fix choices in ChoiceField to support IntEnum (#8955)

Python support Enum in version 3.4, but changed __str__ to int.__str__ until version 3.11 to better support the replacement of existing constants use-case.
[https://docs.python.org/3/library/enum.html#enum.IntEnum](https://docs.python.org/3/library/enum.html#enum.IntEnum)

rest_frame work support Python 3.6+, this commit will support the Enum in choices of Field.
This commit is contained in:
Burson 2023-07-13 20:50:49 +08:00 committed by GitHub
parent 4f7e9ed3bb
commit 66d86d0177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 11 deletions

View File

@ -8,6 +8,7 @@ import logging
import re import re
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum
from django.conf import settings from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
@ -17,7 +18,6 @@ from django.core.validators import (
MinValueValidator, ProhibitNullCharactersValidator, RegexValidator, MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
URLValidator, ip_address_validators URLValidator, ip_address_validators
) )
from django.db.models import IntegerChoices, TextChoices
from django.forms import FilePathField as DjangoFilePathField from django.forms import FilePathField as DjangoFilePathField
from django.forms import ImageField as DjangoImageField from django.forms import ImageField as DjangoImageField
from django.utils import timezone from django.utils import timezone
@ -1401,11 +1401,8 @@ class ChoiceField(Field):
def to_internal_value(self, data): def to_internal_value(self, data):
if data == '' and self.allow_blank: if data == '' and self.allow_blank:
return '' return ''
if isinstance(data, Enum) and str(data) != str(data.value):
if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \
str(data.value):
data = data.value data = data.value
try: try:
return self.choice_strings_to_values[str(data)] return self.choice_strings_to_values[str(data)]
except KeyError: except KeyError:
@ -1414,11 +1411,8 @@ class ChoiceField(Field):
def to_representation(self, value): def to_representation(self, value):
if value in ('', None): if value in ('', None):
return value return value
if isinstance(value, Enum) and str(value) != str(value.value):
if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \
str(value.value):
value = value.value value = value.value
return self.choice_strings_to_values.get(str(value), value) return self.choice_strings_to_values.get(str(value), value)
def iter_options(self): def iter_options(self):
@ -1442,8 +1436,7 @@ class ChoiceField(Field):
# Allows us to deal with eg. integer choices while supporting either # Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out. # integer or string input, but still get the correct datatype out.
self.choice_strings_to_values = { self.choice_strings_to_values = {
str(key.value) if isinstance(key, (IntegerChoices, TextChoices)) str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices
and str(key) != str(key.value) else str(key): key for key in self.choices
} }
choices = property(_get_choices, _set_choices) choices = property(_get_choices, _set_choices)
@ -1829,6 +1822,7 @@ class HiddenField(Field):
constraint on a pair of fields, as we need some way to include the date in constraint on a pair of fields, as we need some way to include the date in
the validated data. the validated data.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
assert 'default' in kwargs, 'default is a required argument.' assert 'default' in kwargs, 'default is a required argument.'
kwargs['write_only'] = True kwargs['write_only'] = True
@ -1858,6 +1852,7 @@ class SerializerMethodField(Field):
def get_extra_info(self, obj): def get_extra_info(self, obj):
return ... # Calculate some data to return. return ... # Calculate some data to return.
""" """
def __init__(self, method_name=None, **kwargs): def __init__(self, method_name=None, **kwargs):
self.method_name = method_name self.method_name = method_name
kwargs['source'] = '*' kwargs['source'] = '*'

View File

@ -1875,6 +1875,31 @@ class TestChoiceField(FieldValues):
field.run_validation(2) field.run_validation(2)
assert exc_info.value.detail == ['"2" is not a valid choice.'] assert exc_info.value.detail == ['"2" is not a valid choice.']
def test_enum_integer_choices(self):
from enum import IntEnum
class ChoiceCase(IntEnum):
first = auto()
second = auto()
# Enum validate
choices = [
(ChoiceCase.first, "1"),
(ChoiceCase.second, "2")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1
# Enum.value validate
choices = [
(ChoiceCase.first.value, "1"),
(ChoiceCase.second.value, "2")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1
def test_integer_choices(self): def test_integer_choices(self):
class ChoiceCase(IntegerChoices): class ChoiceCase(IntegerChoices):
first = auto() first = auto()