diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8e15345da..c0253f86b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -563,7 +563,7 @@ class ChoiceField(WritableField): if isinstance(v, (list, tuple)): # This is an optgroup, so look inside the group for options for k2, v2 in v: - if value == smart_text(k2): + if value == smart_text(k2) or value == k2: return True else: if value == smart_text(k) or value == k: diff --git a/tests/test_validation.py b/tests/test_validation.py index e13e4078c..a46e38ac5 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals from django.core.validators import MaxValueValidator +from django.core.exceptions import ValidationError from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status @@ -146,3 +147,42 @@ class TestMaxValueValidatorValidation(TestCase): response = view(request, pk=obj.pk).render() self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + +class TestChoiceFieldChoicesValidate(TestCase): + CHOICES = [ + (0, 'Small'), + (1, 'Medium'), + (2, 'Large'), + ] + + CHOICES_NESTED = [ + ('Category', ( + (1, 'First'), + (2, 'Second'), + (3, 'Third'), + )), + (4, 'Fourth'), + ] + + def test_choices(self): + """ + Make sure a value for choices works as expected. + """ + f = serializers.ChoiceField(choices=self.CHOICES) + value = self.CHOICES[0][0] + try: + f.validate(value) + except ValidationError: + self.fail("Value %s does not validate" % str(value)) + + def test_nested_choices(self): + """ + Make sure a nested value for choices works as expected. + """ + f = serializers.ChoiceField(choices=self.CHOICES_NESTED) + value = self.CHOICES_NESTED[0][1][0][0] + try: + f.validate(value) + except ValidationError: + self.fail("Value %s does not validate" % str(value))