diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3a2cdd088..483c4b0d0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1692,22 +1692,14 @@ class ListField(Field): if hasattr(self.child, '_propagate_depth_to_child'): self.child._propagate_depth_to_child() - def _check_data_depth(self, data, current=0): - if self._root_max_depth is not None: - if isinstance(data, (list, tuple)): - for item in data: - if isinstance(item, (list, tuple, dict)): - next_depth = current + 1 - if next_depth > self._root_max_depth: - self.fail('max_depth', max_depth=self._root_max_depth) - self._check_data_depth(item, next_depth) - elif isinstance(data, dict): - for value in data.values(): - if isinstance(value, (list, tuple, dict)): - next_depth = current + 1 - if next_depth > self._root_max_depth: - self.fail('max_depth', max_depth=self._root_max_depth) - self._check_data_depth(value, next_depth) + def _check_data_depth(self, data, current_level): + items = data.values() if isinstance(data, dict) else data + for item in items: + if isinstance(item, (list, tuple, dict)): + next_level = current_level + 1 + if next_level > self._root_max_depth: + self.fail('max_depth', max_depth=self._root_max_depth) + self._check_data_depth(item, next_level) def get_value(self, dictionary): if self.field_name not in dictionary: @@ -1734,9 +1726,12 @@ class ListField(Field): self.fail('not_a_list', input_type=type(data).__name__) if not self.allow_empty and len(data) == 0: self.fail('empty') - if self._root_max_depth is not None and self._current_depth > self._root_max_depth: - self.fail('max_depth', max_depth=self._root_max_depth) - self._check_data_depth(data, self._current_depth) + if self._root_max_depth is not None: + start_level = self._current_depth if self._current_depth > 0 else 1 + if start_level > self._root_max_depth: + self.fail('max_depth', max_depth=self._root_max_depth) + if self.max_depth is not None: + self._check_data_depth(data, start_level) return self.run_child_validation(data) def to_representation(self, data): @@ -1789,7 +1784,7 @@ class DictField(Field): def bind(self, field_name, parent): super().bind(field_name, parent) - if hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None: + if self.max_depth is None and hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None: self._root_max_depth = parent._root_max_depth self._current_depth = parent._current_depth + 1 self._propagate_depth_to_child() @@ -1801,22 +1796,14 @@ class DictField(Field): if hasattr(self.child, '_propagate_depth_to_child'): self.child._propagate_depth_to_child() - def _check_data_depth(self, data, current=0): - if self._root_max_depth is not None: - if isinstance(data, dict): - for value in data.values(): - if isinstance(value, (list, tuple, dict)): - next_depth = current + 1 - if next_depth > self._root_max_depth: - self.fail('max_depth', max_depth=self._root_max_depth) - self._check_data_depth(value, next_depth) - elif isinstance(data, (list, tuple)): - for item in data: - if isinstance(item, (list, tuple, dict)): - next_depth = current + 1 - if next_depth > self._root_max_depth: - self.fail('max_depth', max_depth=self._root_max_depth) - self._check_data_depth(item, next_depth) + def _check_data_depth(self, data, current_level): + items = data.values() if isinstance(data, dict) else data + for item in items: + if isinstance(item, (list, tuple, dict)): + next_level = current_level + 1 + if next_level > self._root_max_depth: + self.fail('max_depth', max_depth=self._root_max_depth) + self._check_data_depth(item, next_level) def get_value(self, dictionary): # We override the default field access in order to support @@ -1835,9 +1822,12 @@ class DictField(Field): self.fail('not_a_dict', input_type=type(data).__name__) if not self.allow_empty and len(data) == 0: self.fail('empty') - if self._root_max_depth is not None and self._current_depth > self._root_max_depth: - self.fail('max_depth', max_depth=self._root_max_depth) - self._check_data_depth(data, self._current_depth) + if self._root_max_depth is not None: + start_level = self._current_depth if self._current_depth > 0 else 1 + if start_level > self._root_max_depth: + self.fail('max_depth', max_depth=self._root_max_depth) + if self.max_depth is not None: + self._check_data_depth(data, start_level) return self.run_child_validation(data) def to_representation(self, value): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 289237649..f2d24174c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -76,9 +76,9 @@ LIST_SERIALIZER_KWARGS = ( 'read_only', 'write_only', 'required', 'default', 'initial', 'source', 'label', 'help_text', 'style', 'error_messages', 'allow_empty', 'instance', 'data', 'partial', 'context', 'allow_null', - 'max_length', 'min_length' + 'max_length', 'min_length', 'max_depth' ) -LIST_SERIALIZER_KWARGS_REMOVE = ('allow_empty', 'min_length', 'max_length') +LIST_SERIALIZER_KWARGS_REMOVE = ('allow_empty', 'min_length', 'max_length', 'max_depth') ALL_FIELDS = '__all__' @@ -111,6 +111,10 @@ class BaseSerializer(Field): .data - Available. """ + default_error_messages = { + 'max_depth': _('Nesting depth exceeds maximum allowed depth of {max_depth}.') + } + def __init__(self, instance=None, data=empty, **kwargs): self.instance = instance if data is not empty: @@ -118,13 +122,14 @@ class BaseSerializer(Field): self.partial = kwargs.pop('partial', False) self._context = kwargs.pop('context', {}) kwargs.pop('many', None) + self.max_depth = kwargs.pop('max_depth', None) super().__init__(**kwargs) self._current_depth = 0 - self._root_max_depth = None + self._root_max_depth = self.max_depth def bind(self, field_name, parent): super().bind(field_name, parent) - if hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None: + if self.max_depth is None and hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None: self._root_max_depth = parent._root_max_depth self._current_depth = parent._current_depth + 1 @@ -137,6 +142,32 @@ class BaseSerializer(Field): if hasattr(field, '_propagate_depth_to_child'): field._propagate_depth_to_child() + def _check_data_depth(self, data, current_level): + if isinstance(data, dict): + for value in data.values(): + if isinstance(value, (list, tuple, dict)): + next_level = current_level + 1 + if next_level > self._root_max_depth: + message = self.error_messages['max_depth'].format( + max_depth=self._root_max_depth + ) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }, code='max_depth') + self._check_data_depth(value, next_level) + elif isinstance(data, (list, tuple)): + for item in data: + if isinstance(item, (list, tuple, dict)): + next_level = current_level + 1 + if next_level > self._root_max_depth: + message = self.error_messages['max_depth'].format( + max_depth=self._root_max_depth + ) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }, code='max_depth') + self._check_data_depth(item, next_level) + def __new__(cls, *args, **kwargs): # We override this method in order to automatically create # `ListSerializer` classes instead when `many=True` is set. @@ -390,7 +421,8 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): fields = BindingDict(self) for key, value in self.get_fields().items(): fields[key] = value - self._propagate_depth_to_child() + if self._root_max_depth is not None: + self._propagate_depth_to_child() return fields @property @@ -507,6 +539,9 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [message] }, code='invalid') + if self._root_max_depth is not None and self.max_depth is not None: + start_level = self._current_depth + self._check_data_depth(data, start_level) ret = {} errors = {} @@ -672,6 +707,32 @@ class ListSerializer(BaseSerializer): """ return self.child.run_validation(data) + def _check_data_depth(self, data, current_level): + if isinstance(data, (list, tuple)): + for item in data: + if isinstance(item, (list, tuple, dict)): + next_level = current_level + 1 + if next_level > self._root_max_depth: + message = self.error_messages['max_depth'].format( + max_depth=self._root_max_depth + ) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }, code='max_depth') + self._check_data_depth(item, next_level) + elif isinstance(data, dict): + for value in data.values(): + if isinstance(value, (list, tuple, dict)): + next_level = current_level + 1 + if next_level > self._root_max_depth: + message = self.error_messages['max_depth'].format( + max_depth=self._root_max_depth + ) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }, code='max_depth') + self._check_data_depth(value, next_level) + def to_internal_value(self, data): """ List of dicts of native values <- List of dicts of primitive datatypes. @@ -687,6 +748,10 @@ class ListSerializer(BaseSerializer): api_settings.NON_FIELD_ERRORS_KEY: [message] }, code='not_a_list') + if self._root_max_depth is not None and self.max_depth is not None: + start_level = self._current_depth + self._check_data_depth(data, start_level) + if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] raise ValidationError({ diff --git a/tests/test_fields.py b/tests/test_fields.py index 162997e43..6864db065 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2559,32 +2559,35 @@ class TestListFieldMaxDepth: output = field.run_validation([[[1, 2]], [[3]]]) assert output == [[[1, 2]], [[3]]] - def test_max_depth_zero_allows_field_itself(self): + def test_max_depth_zero_rejects_everything(self): field = serializers.ListField(child=serializers.IntegerField(), max_depth=0) + with pytest.raises(serializers.ValidationError) as exc_info: + field.run_validation([1, 2, 3]) + assert 'max_depth' in str(exc_info.value.detail) + + def test_max_depth_one_allows_flat_list(self): + field = serializers.ListField(child=serializers.IntegerField(), max_depth=1) output = field.run_validation([1, 2, 3]) assert output == [1, 2, 3] - def test_max_depth_zero_rejects_nested_list(self): - field = serializers.ListField( - child=serializers.ListField(child=serializers.IntegerField()), - max_depth=0 - ) + def test_max_depth_one_rejects_nested_list(self): + field = serializers.ListField(child=serializers.IntegerField(), max_depth=1) with pytest.raises(serializers.ValidationError) as exc_info: - field.run_validation([[1, 2], [3]]) + field.run_validation([[1, 2]]) assert 'max_depth' in str(exc_info.value.detail) - def test_max_depth_one_allows_one_level_nesting(self): + def test_max_depth_two_allows_one_level_nesting(self): field = serializers.ListField( child=serializers.ListField(child=serializers.IntegerField()), - max_depth=1 + max_depth=2 ) output = field.run_validation([[1, 2], [3, 4]]) assert output == [[1, 2], [3, 4]] - def test_max_depth_one_rejects_two_levels_nesting(self): + def test_max_depth_two_rejects_two_levels_nesting(self): field = serializers.ListField( - child=serializers.ListField(child=serializers.ListField(child=serializers.IntegerField())), - max_depth=1 + child=serializers.ListField(child=serializers.IntegerField()), + max_depth=2 ) with pytest.raises(serializers.ValidationError) as exc_info: field.run_validation([[[1, 2]], [[3]]]) @@ -2607,7 +2610,7 @@ class TestListFieldMaxDepth: def test_max_depth_with_mixed_nesting_list_and_dict(self): field = serializers.ListField( child=serializers.DictField(child=serializers.ListField(child=serializers.IntegerField())), - max_depth=2 + max_depth=3 ) output = field.run_validation([{'a': [1, 2], 'b': [3]}]) assert output == [{'a': [1, 2], 'b': [3]}] @@ -2647,27 +2650,16 @@ class TestDictFieldMaxDepth: output = field.run_validation({'a': {'b': {'c': 1}}}) assert output == {'a': {'b': {'c': 1}}} - def test_max_depth_zero_allows_field_itself(self): + def test_max_depth_zero_rejects_everything(self): field = serializers.DictField(child=serializers.IntegerField(), max_depth=0) - output = field.run_validation({'a': 1, 'b': 2}) - assert output == {'a': 1, 'b': 2} - - def test_max_depth_zero_rejects_nested_dict(self): - field = serializers.DictField( - child=serializers.DictField(child=serializers.IntegerField()), - max_depth=0 - ) with pytest.raises(serializers.ValidationError) as exc_info: - field.run_validation({'a': {'b': 1}}) + field.run_validation({'a': 1, 'b': 2}) assert 'max_depth' in str(exc_info.value.detail) - def test_max_depth_one_allows_one_level_nesting(self): - field = serializers.DictField( - child=serializers.DictField(child=serializers.IntegerField()), - max_depth=1 - ) - output = field.run_validation({'a': {'b': 1}, 'c': {'d': 2}}) - assert output == {'a': {'b': 1}, 'c': {'d': 2}} + def test_max_depth_one_allows_flat_dict(self): + field = serializers.DictField(child=serializers.IntegerField(), max_depth=1) + output = field.run_validation({'a': 1, 'b': 2}) + assert output == {'a': 1, 'b': 2} def test_max_depth_one_rejects_two_levels_nesting(self): field = serializers.DictField( @@ -2681,7 +2673,7 @@ class TestDictFieldMaxDepth: def test_max_depth_with_mixed_nesting_dict_and_list(self): field = serializers.DictField( child=serializers.ListField(child=serializers.DictField(child=serializers.IntegerField())), - max_depth=2 + max_depth=3 ) output = field.run_validation({'a': [{'b': 1, 'c': 2}]}) assert output == {'a': [{'b': 1, 'c': 2}]} @@ -2710,7 +2702,7 @@ class TestDictFieldMaxDepth: class TestMaxDepthEdgeCases: def test_field_reuse_does_not_leak_depth_state(self): child_field = serializers.ListField(child=serializers.IntegerField()) - field = serializers.ListField(child=child_field, max_depth=1) + field = serializers.ListField(child=child_field, max_depth=2) output1 = field.run_validation([[1, 2], [3, 4]]) assert output1 == [[1, 2], [3, 4]] output2 = field.run_validation([[5, 6], [7, 8]]) @@ -2719,7 +2711,7 @@ class TestMaxDepthEdgeCases: def test_max_depth_with_empty_nested_structures(self): field = serializers.ListField( child=serializers.ListField(child=serializers.IntegerField()), - max_depth=1 + max_depth=2 ) output = field.run_validation([[], []]) assert output == [[], []] @@ -2739,35 +2731,34 @@ class TestMaxDepthEdgeCases: class TestMaxDepthDataInspection: def test_flat_field_rejects_deeply_nested_data(self): field = serializers.ListField(max_depth=1) - field.run_validation([[1, 2]]) + field.run_validation([1, 2, 3]) with pytest.raises(serializers.ValidationError) as exc_info: - field.run_validation([[[1]]]) + field.run_validation([[1, 2]]) assert 'max_depth' in str(exc_info.value.detail) def test_flat_dict_field_rejects_deeply_nested_data(self): field = serializers.DictField(max_depth=1) - field.run_validation({'a': {'b': 1}}) + field.run_validation({'a': 1, 'b': 2}) with pytest.raises(serializers.ValidationError) as exc_info: - field.run_validation({'a': {'b': {'c': 1}}}) + field.run_validation({'a': {'b': 1}}) assert 'max_depth' in str(exc_info.value.detail) def test_max_depth_zero_rejects_any_nesting_in_data(self): field = serializers.ListField(max_depth=0) - field.run_validation([1, 2, 3]) with pytest.raises(serializers.ValidationError): - field.run_validation([[1]]) + field.run_validation([1, 2, 3]) def test_data_depth_check_with_mixed_structures(self): field = serializers.ListField(max_depth=1) - field.run_validation([{'a': 1}, [2], 3]) + field.run_validation([1, 2, 3]) with pytest.raises(serializers.ValidationError): - field.run_validation([{'a': {'b': 1}}]) + field.run_validation([{'a': 1}, [2], 3]) def test_dict_field_data_depth_with_nested_lists(self): field = serializers.DictField(max_depth=1) - field.run_validation({'a': [1, 2], 'b': 'text'}) + field.run_validation({'a': 1, 'b': 2}) with pytest.raises(serializers.ValidationError): - field.run_validation({'a': [[1, 2]]}) + field.run_validation({'a': [1, 2]}) def test_data_depth_respects_current_depth(self): inner = serializers.ListField(child=serializers.IntegerField()) @@ -2776,13 +2767,77 @@ class TestMaxDepthDataInspection: with pytest.raises(serializers.ValidationError): outer.run_validation([[[1]]]) + def test_max_depth_one_means_flat_only(self): + field = serializers.ListField(child=serializers.IntegerField(), max_depth=1) + field.run_validation([1, 2, 3]) + with pytest.raises(serializers.ValidationError) as exc_info: + field.run_validation([[1, 2]]) + assert 'max_depth' in str(exc_info.value.detail) + + def test_serializer_with_list_field_respects_depth(self): + class TestSerializer(serializers.Serializer): + data = serializers.ListField(child=serializers.IntegerField(), max_depth=1) + + serializer = TestSerializer(data={'data': [1, 2, 3]}) + assert serializer.is_valid() + + serializer = TestSerializer(data={'data': [[1, 2]]}) + assert not serializer.is_valid() + assert 'max_depth' in str(serializer.errors) + + def test_serializer_with_max_depth_inspects_raw_data(self): + class TestSerializer(serializers.Serializer): + name = serializers.CharField() + value = serializers.IntegerField() + + serializer = TestSerializer(data={'name': 'test', 'value': 1}, max_depth=1) + assert serializer.is_valid() + + serializer = TestSerializer(data={'name': 'test', 'value': {'nested': 1}}, max_depth=1) + assert not serializer.is_valid() + assert 'max_depth' in str(serializer.errors) or 'invalid' in str(serializer.errors) + + def test_standalone_serializer_protects_against_deep_json(self): + class SimpleSerializer(serializers.Serializer): + data = serializers.CharField() + + serializer = SimpleSerializer(data={'data': 'value'}, max_depth=1) + assert serializer.is_valid() + + deep_data = {'data': {'nested': {'deep': 'value'}}} + serializer = SimpleSerializer(data=deep_data, max_depth=1) + assert not serializer.is_valid() + assert 'max_depth' in str(serializer.errors) + + def test_list_serializer_many_true_respects_max_depth(self): + class MySerializer(serializers.Serializer): + name = serializers.CharField() + + serializer = MySerializer(data=[{'name': 'test'}], many=True, max_depth=1) + assert serializer.is_valid() + + deep_list_data = [{'name': 'test'}, [1, 2, 3]] + serializer = MySerializer(data=deep_list_data, many=True, max_depth=1) + assert not serializer.is_valid() + + def test_list_serializer_protects_against_deeply_nested_lists(self): + class ItemSerializer(serializers.Serializer): + value = serializers.IntegerField() + + serializer = ItemSerializer(data=[{'value': 1}, {'value': 2}], many=True, max_depth=1) + assert serializer.is_valid() + + deep_data = [{'value': {'nested': 1}}] + serializer = ItemSerializer(data=deep_data, many=True, max_depth=1) + assert not serializer.is_valid() + class TestMaxDepthWithSerializers: def test_list_field_containing_serializer_with_nested_list(self): class InnerSerializer(serializers.Serializer): numbers = serializers.ListField(child=serializers.IntegerField()) - field = serializers.ListField(child=InnerSerializer(), max_depth=2) + field = serializers.ListField(child=InnerSerializer(), max_depth=3) valid_data = [{'numbers': [1, 2]}, {'numbers': [3, 4]}] output = field.run_validation(valid_data) assert output == [{'numbers': [1, 2]}, {'numbers': [3, 4]}] @@ -2793,7 +2848,7 @@ class TestMaxDepthWithSerializers: child=serializers.ListField(child=serializers.IntegerField()) ) - field = serializers.ListField(child=InnerSerializer(), max_depth=2) + field = serializers.ListField(child=InnerSerializer(), max_depth=3) with pytest.raises(serializers.ValidationError): field.run_validation([{'nested_list': [[1, 2]]}]) @@ -2801,7 +2856,7 @@ class TestMaxDepthWithSerializers: class ValueSerializer(serializers.Serializer): data = serializers.ListField(child=serializers.IntegerField()) - field = serializers.DictField(child=ValueSerializer(), max_depth=2) + field = serializers.DictField(child=ValueSerializer(), max_depth=3) valid_data = {'key1': {'data': [1, 2]}, 'key2': {'data': [3, 4]}} output = field.run_validation(valid_data) assert output == {'key1': {'data': [1, 2]}, 'key2': {'data': [3, 4]}} @@ -2816,7 +2871,7 @@ class TestMaxDepthWithSerializers: class Level1Serializer(serializers.Serializer): level2 = Level2Serializer() - field = serializers.ListField(child=Level1Serializer(), max_depth=3) + field = serializers.ListField(child=Level1Serializer(), max_depth=4) with pytest.raises(serializers.ValidationError): field.run_validation([{'level2': {'level3': {'values': [1, 2]}}}])