diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3ca7d682e..812a35ac2 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -5,6 +5,7 @@ import copy import datetime import decimal import inspect +import itertools import re import uuid @@ -1098,17 +1099,8 @@ class ChoiceField(Field): } def __init__(self, choices, **kwargs): - # Allow either single or paired choices style: - # choices = [1, 2, 3] - # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] - pairs = [ - isinstance(item, (list, tuple)) and len(item) == 2 - for item in choices - ] - if all(pairs): - self.choices = OrderedDict([(key, display_value) for key, display_value in choices]) - else: - self.choices = OrderedDict([(item, item) for item in choices]) + flat_choices = [self.flatten_choice(c) for c in choices] + self.choices = OrderedDict(itertools.chain(*flat_choices)) # Map the string representation of choices to the underlying value. # Allows us to deal with eg. integer choices while supporting either @@ -1121,6 +1113,30 @@ class ChoiceField(Field): super(ChoiceField, self).__init__(**kwargs) + def flatten_choice(self, choice): + """ + Convert a choices choice into a flat list of choices. + + Returns a list of choices. + """ + + # Allow single, paired or grouped choices style: + # choices = [1, 2, 3] + # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] + # choices = [('Category', ((1, 'First'), (2, 'Second'))), (3, 'Third')] + if (not isinstance(choice, (list, tuple))): + # single choice + return [(choice, choice)] + else: + key, display_value = choice + if isinstance(display_value, (list, tuple)): + # grouped choices + sub_choices = [self.flatten_choice(c) for c in display_value] + return list(itertools.chain(*sub_choices)) + else: + # paired choice + return [(key, display_value)] + def to_internal_value(self, data): if data == '' and self.allow_blank: return '' diff --git a/tests/test_fields.py b/tests/test_fields.py index 76e6d9d60..3bd4afc6e 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1091,6 +1091,66 @@ class TestChoiceFieldWithListChoices(FieldValues): field = serializers.ChoiceField(choices=('poor', 'medium', 'good')) +class TestChoiceFieldWithGroupedChoices(FieldValues): + """ + Valid and invalid values for a `Choice` field that uses a grouped list for the + choices, rather than a list of pairs of (`value`, `description`). + """ + valid_inputs = { + 'poor': 'poor', + 'medium': 'medium', + 'good': 'good', + } + invalid_inputs = { + 'awful': ['"awful" is not a valid choice.'] + } + outputs = { + 'good': 'good' + } + field = serializers.ChoiceField( + choices=[ + ( + 'Category', + ( + ('poor', 'Poor quality'), + ('medium', 'Medium quality'), + ), + ), + ('good', 'Good quality'), + ] + ) + + +class TestChoiceFieldWithMixedChoices(FieldValues): + """ + Valid and invalid values for a `Choice` field that uses a single paired or + grouped. + """ + valid_inputs = { + 'poor': 'poor', + 'medium': 'medium', + 'good': 'good', + } + invalid_inputs = { + 'awful': ['"awful" is not a valid choice.'] + } + outputs = { + 'good': 'good' + } + field = serializers.ChoiceField( + choices=[ + ( + 'Category', + ( + ('poor', 'Poor quality'), + ), + ), + 'medium', + ('good', 'Good quality'), + ] + ) + + class TestMultipleChoiceField(FieldValues): """ Valid and invalid values for `MultipleChoiceField`. diff --git a/tests/test_validation.py b/tests/test_validation.py index 46e36f5d8..855ff20e0 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -141,6 +141,8 @@ class TestMaxValueValidatorValidation(TestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +# regression tests for issue: 1533 + class TestChoiceFieldChoicesValidate(TestCase): CHOICES = [ (0, 'Small'), @@ -148,6 +150,8 @@ class TestChoiceFieldChoicesValidate(TestCase): (2, 'Large'), ] + SINGLE_CHOICES = [0, 1, 2] + CHOICES_NESTED = [ ('Category', ( (1, 'First'), @@ -157,6 +161,15 @@ class TestChoiceFieldChoicesValidate(TestCase): (4, 'Fourth'), ] + MIXED_CHOICES = [ + ('Category', ( + (1, 'First'), + (2, 'Second'), + )), + 3, + (4, 'Fourth'), + ] + def test_choices(self): """ Make sure a value for choices works as expected. @@ -168,6 +181,39 @@ class TestChoiceFieldChoicesValidate(TestCase): except serializers.ValidationError: self.fail("Value %s does not validate" % str(value)) + def test_single_choices(self): + """ + Make sure a single value for choices works as expected. + """ + f = serializers.ChoiceField(choices=self.SINGLE_CHOICES) + value = self.SINGLE_CHOICES[0] + try: + f.to_internal_value(value) + except serializers.ValidationError: + self.fail("Value %s does not validate" % str(value)) + + def test_nested_choices(self): + """ + Make sure a nested value for choices works as expected. + """ + f = serializers.ChoiceField(choices=self.CHOICES_NESTED) + value = self.CHOICES_NESTED[0][1][0][0] + try: + f.to_internal_value(value) + except serializers.ValidationError: + self.fail("Value %s does not validate" % str(value)) + + def test_mixed_choices(self): + """ + Make sure mixed values for choices works as expected. + """ + f = serializers.ChoiceField(choices=self.MIXED_CHOICES) + value = self.MIXED_CHOICES[1] + try: + f.to_internal_value(value) + except serializers.ValidationError: + self.fail("Value %s does not validate" % str(value)) + class RegexSerializer(serializers.Serializer): pin = serializers.CharField(