diff --git a/rest_framework/fields.py b/rest_framework/fields.py index eea693442..736acf3db 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1141,10 +1141,15 @@ class ChoiceField(Field): class MultipleChoiceField(ChoiceField): default_error_messages = { 'invalid_choice': _('"{input}" is not a valid choice.'), - 'not_a_list': _('Expected a list of items but got type "{input_type}".') + 'not_a_list': _('Expected a list of items but got type "{input_type}".'), + 'empty': _('This selection may not be empty.') } default_empty_html = [] + def __init__(self, *args, **kwargs): + self.allow_empty = kwargs.pop('allow_empty', True) + super(MultipleChoiceField, self).__init__(*args, **kwargs) + def get_value(self, dictionary): # We override the default field access in order to support # lists in HTML forms. @@ -1159,6 +1164,8 @@ class MultipleChoiceField(ChoiceField): def to_internal_value(self, data): if isinstance(data, type('')) or not hasattr(data, '__iter__'): self.fail('not_a_list', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + self.fail('empty') return set([ super(MultipleChoiceField, self).to_internal_value(item) @@ -1263,11 +1270,13 @@ class ListField(Field): child = _UnvalidatedField() initial = [] default_error_messages = { - 'not_a_list': _('Expected a list of items but got type "{input_type}".') + 'not_a_list': _('Expected a list of items but got type "{input_type}".'), + 'empty': _('This list may not be empty.') } def __init__(self, *args, **kwargs): self.child = kwargs.pop('child', copy.deepcopy(self.child)) + self.allow_empty = kwargs.pop('allow_empty', True) assert not inspect.isclass(self.child), '`child` has not been instantiated.' super(ListField, self).__init__(*args, **kwargs) self.child.bind(field_name='', parent=self) @@ -1287,6 +1296,8 @@ class ListField(Field): data = html.parse_html_list(data) if isinstance(data, type('')) or not hasattr(data, '__iter__'): self.fail('not_a_list', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + self.fail('empty') return [self.child.run_validation(item) for item in data] def to_representation(self, data): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 8bf7d628e..97a6417ea 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -32,7 +32,7 @@ class PKOnlyObject(object): # rather than the parent serializer. MANY_RELATION_KWARGS = ( 'read_only', 'write_only', 'required', 'default', 'initial', 'source', - 'label', 'help_text', 'style', 'error_messages' + 'label', 'help_text', 'style', 'error_messages', 'allow_empty' ) @@ -366,9 +366,14 @@ class ManyRelatedField(Field): """ initial = [] default_empty_html = [] + default_error_messages = { + 'not_a_list': _('Expected a list of items but got type "{input_type}".'), + 'empty': _('This list may not be empty.') + } def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation + self.allow_empty = kwargs.pop('allow_empty', True) assert child_relation is not None, '`child_relation` is a required argument.' super(ManyRelatedField, self).__init__(*args, **kwargs) self.child_relation.bind(field_name='', parent=self) @@ -386,6 +391,11 @@ class ManyRelatedField(Field): return dictionary.get(self.field_name, empty) def to_internal_value(self, data): + if isinstance(data, type('')) or not hasattr(data, '__iter__'): + self.fail('not_a_list', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + self.fail('empty') + return [ self.child_relation.to_internal_value(item) for item in data diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index beaae1e12..3c345c00d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -49,7 +49,7 @@ from rest_framework.relations import * # NOQA # isort:skip # rather than the parent serializer. LIST_SERIALIZER_KWARGS = ( 'read_only', 'write_only', 'required', 'default', 'initial', 'source', - 'label', 'help_text', 'style', 'error_messages', + 'label', 'help_text', 'style', 'error_messages', 'allow_empty', 'instance', 'data', 'partial', 'context' ) @@ -493,11 +493,13 @@ class ListSerializer(BaseSerializer): many = True default_error_messages = { - 'not_a_list': _('Expected a list of items but got type "{input_type}".') + 'not_a_list': _('Expected a list of items but got type "{input_type}".'), + 'empty': _('This list may not be empty.') } def __init__(self, *args, **kwargs): self.child = kwargs.pop('child', copy.deepcopy(self.child)) + self.allow_empty = kwargs.pop('allow_empty', True) assert self.child is not None, '`child` is a required argument.' assert not inspect.isclass(self.child), '`child` has not been instantiated.' super(ListSerializer, self).__init__(*args, **kwargs) @@ -553,6 +555,12 @@ class ListSerializer(BaseSerializer): api_settings.NON_FIELD_ERRORS_KEY: [message] }) + if not self.allow_empty and len(data) == 0: + message = self.error_messages['empty'] + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }) + ret = [] errors = [] diff --git a/tests/test_fields.py b/tests/test_fields.py index cea440c82..897003df1 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1140,6 +1140,27 @@ class TestMultipleChoiceField(FieldValues): assert field.get_value(QueryDict({})) == rest_framework.fields.empty +class TestEmptyMultipleChoiceField(FieldValues): + """ + Invalid values for `MultipleChoiceField(allow_empty=False)`. + """ + valid_inputs = { + } + invalid_inputs = ( + ([], ['This selection may not be empty.']), + ) + outputs = [ + ] + field = serializers.MultipleChoiceField( + choices=[ + ('consistency', 'Consistency'), + ('availability', 'Availability'), + ('partition', 'Partition tolerance'), + ], + allow_empty=False + ) + + # File serializers... class MockFile: @@ -1233,7 +1254,8 @@ class TestListField(FieldValues): """ valid_inputs = [ ([1, 2, 3], [1, 2, 3]), - (['1', '2', '3'], [1, 2, 3]) + (['1', '2', '3'], [1, 2, 3]), + ([], []) ] invalid_inputs = [ ('not a list', ['Expected a list of items but got type "str".']), @@ -1246,6 +1268,18 @@ class TestListField(FieldValues): field = serializers.ListField(child=serializers.IntegerField()) +class TestEmptyListField(FieldValues): + """ + Values for `ListField` with allow_empty=False flag. + """ + valid_inputs = {} + invalid_inputs = [ + ([], ['This list may not be empty.']) + ] + outputs = {} + field = serializers.ListField(child=serializers.IntegerField(), allow_empty=False) + + class TestUnvalidatedListField(FieldValues): """ Values for `ListField` with no `child` argument.