Ensure that html forms (multipart form data) respect optional fields (#5927)

This commit is contained in:
Christian Kreuzberger 2018-04-20 15:11:52 +02:00 committed by Carlton Gibson
parent 7e705246ca
commit f148e4e259
6 changed files with 129 additions and 6 deletions

View File

@ -1614,7 +1614,8 @@ class ListField(Field):
if len(val) > 0: if len(val) > 0:
# Support QueryDict lists in HTML input. # Support QueryDict lists in HTML input.
return val return val
return html.parse_html_list(dictionary, prefix=self.field_name) return html.parse_html_list(dictionary, prefix=self.field_name, default=empty)
return dictionary.get(self.field_name, empty) return dictionary.get(self.field_name, empty)
def to_internal_value(self, data): def to_internal_value(self, data):
@ -1622,7 +1623,7 @@ class ListField(Field):
List of dicts of native values <- List of dicts of primitive datatypes. List of dicts of native values <- List of dicts of primitive datatypes.
""" """
if html.is_html_input(data): if html.is_html_input(data):
data = html.parse_html_list(data) data = html.parse_html_list(data, default=[])
if isinstance(data, type('')) or isinstance(data, collections.Mapping) or not hasattr(data, '__iter__'): if isinstance(data, type('')) or isinstance(data, collections.Mapping) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__) self.fail('not_a_list', input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0: if not self.allow_empty and len(data) == 0:

View File

@ -607,7 +607,7 @@ class ListSerializer(BaseSerializer):
# We override the default field access in order to support # We override the default field access in order to support
# lists in HTML forms. # lists in HTML forms.
if html.is_html_input(dictionary): if html.is_html_input(dictionary):
return html.parse_html_list(dictionary, prefix=self.field_name) return html.parse_html_list(dictionary, prefix=self.field_name, default=empty)
return dictionary.get(self.field_name, empty) return dictionary.get(self.field_name, empty)
def run_validation(self, data=empty): def run_validation(self, data=empty):
@ -635,7 +635,7 @@ class ListSerializer(BaseSerializer):
List of dicts of native values <- List of dicts of primitive datatypes. List of dicts of native values <- List of dicts of primitive datatypes.
""" """
if html.is_html_input(data): if html.is_html_input(data):
data = html.parse_html_list(data) data = html.parse_html_list(data, default=[])
if not isinstance(data, list): if not isinstance(data, list):
message = self.error_messages['not_a_list'].format( message = self.error_messages['not_a_list'].format(

View File

@ -12,7 +12,7 @@ def is_html_input(dictionary):
return hasattr(dictionary, 'getlist') return hasattr(dictionary, 'getlist')
def parse_html_list(dictionary, prefix=''): def parse_html_list(dictionary, prefix='', default=None):
""" """
Used to support list values in HTML forms. Used to support list values in HTML forms.
Supports lists of primitives and/or dictionaries. Supports lists of primitives and/or dictionaries.
@ -44,6 +44,8 @@ def parse_html_list(dictionary, prefix=''):
{'foo': 'abc', 'bar': 'def'}, {'foo': 'abc', 'bar': 'def'},
{'foo': 'hij', 'bar': 'klm'} {'foo': 'hij', 'bar': 'klm'}
] ]
:returns a list of objects, or the value specified in ``default`` if the list is empty
""" """
ret = {} ret = {}
regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
@ -59,7 +61,9 @@ def parse_html_list(dictionary, prefix=''):
ret[index][key] = value ret[index][key] = value
else: else:
ret[index] = MultiValueDict({key: [value]}) ret[index] = MultiValueDict({key: [value]})
return [ret[item] for item in sorted(ret)]
# return the items of the ``ret`` dict, sorted by key, or ``default`` if the dict is empty
return [ret[item] for item in sorted(ret)] if ret else default
def parse_html_dict(dictionary, prefix=''): def parse_html_dict(dictionary, prefix=''):

View File

@ -466,6 +466,55 @@ class TestHTMLInput:
assert serializer.is_valid() assert serializer.is_valid()
assert serializer.validated_data == {'scores': [1]} assert serializer.validated_data == {'scores': [1]}
def test_querydict_list_input_no_values_uses_default(self):
"""
When there are no values passed in, and default is set
The field should return the default value
"""
class TestSerializer(serializers.Serializer):
a = serializers.IntegerField(required=True)
scores = serializers.ListField(default=lambda: [1, 3])
serializer = TestSerializer(data=QueryDict('a=1&'))
assert serializer.is_valid()
assert serializer.validated_data == {'a': 1, 'scores': [1, 3]}
def test_querydict_list_input_supports_indexed_keys(self):
"""
When data is passed in the format `scores[0]=1&scores[1]=3`
The field should return the correct list, ignoring the default
"""
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(default=lambda: [1, 3])
serializer = TestSerializer(data=QueryDict("scores[0]=5&scores[1]=6"))
assert serializer.is_valid()
assert serializer.validated_data == {'scores': ['5', '6']}
def test_querydict_list_input_no_values_no_default_and_not_required(self):
"""
When there are no keys passed, there is no default, and required=False
The field should be skipped
"""
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(required=False)
serializer = TestSerializer(data=QueryDict(''))
assert serializer.is_valid()
assert serializer.validated_data == {}
def test_querydict_list_input_posts_key_but_no_values(self):
"""
When there are no keys passed, there is no default, and required=False
The field should return an array of 1 item, blank
"""
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(required=False)
serializer = TestSerializer(data=QueryDict('scores=&'))
assert serializer.is_valid()
assert serializer.validated_data == {'scores': ['']}
class TestCreateOnlyDefault: class TestCreateOnlyDefault:
def setup(self): def setup(self):

View File

@ -1,3 +1,4 @@
from django.http import QueryDict
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from rest_framework import serializers from rest_framework import serializers
@ -532,3 +533,32 @@ class TestSerializerPartialUsage:
assert value == updated_data_list[index][key] assert value == updated_data_list[index][key]
assert serializer.errors == {} assert serializer.errors == {}
class TestEmptyListSerializer:
"""
Tests the behaviour of ListSerializers when there is no data passed to it
"""
def setup(self):
class ExampleListSerializer(serializers.ListSerializer):
child = serializers.IntegerField()
self.Serializer = ExampleListSerializer
def test_nested_serializer_with_list_json(self):
# pass an empty array to the serializer
input_data = []
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert serializer.validated_data == []
def test_nested_serializer_with_list_multipart(self):
# pass an "empty" QueryDict to the serializer (should be the same as an empty array)
input_data = QueryDict('')
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert serializer.validated_data == []

View File

@ -202,3 +202,42 @@ class TestNestedSerializerWithList:
assert serializer.is_valid() assert serializer.is_valid()
assert serializer.validated_data['nested']['example'] == {1, 2} assert serializer.validated_data['nested']['example'] == {1, 2}
class TestNotRequiredNestedSerializerWithMany:
def setup(self):
class NestedSerializer(serializers.Serializer):
one = serializers.IntegerField(max_value=10)
class TestSerializer(serializers.Serializer):
nested = NestedSerializer(required=False, many=True)
self.Serializer = TestSerializer
def test_json_validate(self):
input_data = {}
serializer = self.Serializer(data=input_data)
# request is empty, therefor 'nested' should not be in serializer.data
assert serializer.is_valid()
assert 'nested' not in serializer.validated_data
input_data = {'nested': [{'one': '1'}, {'one': 2}]}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert 'nested' in serializer.validated_data
def test_multipart_validate(self):
# leave querydict empty
input_data = QueryDict('')
serializer = self.Serializer(data=input_data)
# the querydict is empty, therefor 'nested' should not be in serializer.data
assert serializer.is_valid()
assert 'nested' not in serializer.validated_data
input_data = QueryDict('nested[0]one=1&nested[1]one=2')
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert 'nested' in serializer.validated_data