diff --git a/rest_framework/fields.py b/rest_framework/fields.py index bedc02b94..0d144da23 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1412,13 +1412,12 @@ class ChoiceField(Field): html_cutoff = None html_cutoff_text = _('More than {count} items...') - def __init__(self, choices, **kwargs): + def __init__(self, choices, allow_blank=False, **kwargs): self.choices = choices + self.allow_blank = allow_blank self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) - self.allow_blank = kwargs.pop('allow_blank', False) - super().__init__(**kwargs) def to_internal_value(self, data): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 49eec8259..918ae26d8 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1212,20 +1212,19 @@ class ModelSerializer(Serializer): field_class = self.serializer_related_field field_kwargs['queryset'] = model_field.related_model.objects - if 'choices' in field_kwargs: + if 'choices' in field_kwargs and not issubclass(field_class, self.serializer_choice_field): # Fields with choices get coerced into `ChoiceField` # instead of using their regular typed field. field_class = self.serializer_choice_field # Some model fields may introduce kwargs that would not be valid # for the choice field. We need to strip these out. # Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES) - valid_kwargs = { - 'read_only', 'write_only', - 'required', 'default', 'initial', 'source', - 'label', 'help_text', 'style', - 'error_messages', 'validators', 'allow_null', 'allow_blank', - 'choices' - } + valid_kwargs = set() + for c in inspect.getmro(field_class): + sig = inspect.signature(c.__init__) + for param in sig.parameters.values(): + if (param.kind == param.POSITIONAL_OR_KEYWORD): + valid_kwargs.add(param.name) for key in list(field_kwargs): if key not in valid_kwargs: field_kwargs.pop(key) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 7da1b41ae..619ce7b35 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -217,6 +217,25 @@ class TestRegularFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) + def test_override_choice_field_mapping(self): + class CustomChoiceField(models.CharField): + """ + A custom choice model field simply for testing purposes. + """ + max_length = 100 + + class CostomizedChoiceModel(models.Model): + choices_field = CustomChoiceField(choices=COLOR_CHOICES) + + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = CostomizedChoiceModel + fields = '__all__' + + self.assertTrue(isinstance(TestSerializer().fields["choices_field"], serializers.ChoiceField)) + TestSerializer.serializer_field_mapping[CustomChoiceField] = serializers.MultipleChoiceField + self.assertTrue(isinstance(TestSerializer().fields["choices_field"], serializers.MultipleChoiceField)) + def test_nullable_boolean_field_choices(self): class NullableBooleanChoicesModel(models.Model): CHECKLIST_OPTIONS = (