Add max_depth parameter to prevent DoS from deeply nested data

This commit is contained in:
Mahdi 2025-12-30 19:21:47 +03:30
parent acc3fd726a
commit 6ad1096656
3 changed files with 201 additions and 91 deletions

View File

@ -1692,22 +1692,14 @@ class ListField(Field):
if hasattr(self.child, '_propagate_depth_to_child'):
self.child._propagate_depth_to_child()
def _check_data_depth(self, data, current=0):
if self._root_max_depth is not None:
if isinstance(data, (list, tuple)):
for item in data:
if isinstance(item, (list, tuple, dict)):
next_depth = current + 1
if next_depth > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(item, next_depth)
elif isinstance(data, dict):
for value in data.values():
if isinstance(value, (list, tuple, dict)):
next_depth = current + 1
if next_depth > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(value, next_depth)
def _check_data_depth(self, data, current_level):
items = data.values() if isinstance(data, dict) else data
for item in items:
if isinstance(item, (list, tuple, dict)):
next_level = current_level + 1
if next_level > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(item, next_level)
def get_value(self, dictionary):
if self.field_name not in dictionary:
@ -1734,9 +1726,12 @@ class ListField(Field):
self.fail('not_a_list', input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail('empty')
if self._root_max_depth is not None and self._current_depth > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(data, self._current_depth)
if self._root_max_depth is not None:
start_level = self._current_depth if self._current_depth > 0 else 1
if start_level > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
if self.max_depth is not None:
self._check_data_depth(data, start_level)
return self.run_child_validation(data)
def to_representation(self, data):
@ -1789,7 +1784,7 @@ class DictField(Field):
def bind(self, field_name, parent):
super().bind(field_name, parent)
if hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
if self.max_depth is None and hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
self._root_max_depth = parent._root_max_depth
self._current_depth = parent._current_depth + 1
self._propagate_depth_to_child()
@ -1801,22 +1796,14 @@ class DictField(Field):
if hasattr(self.child, '_propagate_depth_to_child'):
self.child._propagate_depth_to_child()
def _check_data_depth(self, data, current=0):
if self._root_max_depth is not None:
if isinstance(data, dict):
for value in data.values():
if isinstance(value, (list, tuple, dict)):
next_depth = current + 1
if next_depth > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(value, next_depth)
elif isinstance(data, (list, tuple)):
for item in data:
if isinstance(item, (list, tuple, dict)):
next_depth = current + 1
if next_depth > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(item, next_depth)
def _check_data_depth(self, data, current_level):
items = data.values() if isinstance(data, dict) else data
for item in items:
if isinstance(item, (list, tuple, dict)):
next_level = current_level + 1
if next_level > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(item, next_level)
def get_value(self, dictionary):
# We override the default field access in order to support
@ -1835,9 +1822,12 @@ class DictField(Field):
self.fail('not_a_dict', input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail('empty')
if self._root_max_depth is not None and self._current_depth > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
self._check_data_depth(data, self._current_depth)
if self._root_max_depth is not None:
start_level = self._current_depth if self._current_depth > 0 else 1
if start_level > self._root_max_depth:
self.fail('max_depth', max_depth=self._root_max_depth)
if self.max_depth is not None:
self._check_data_depth(data, start_level)
return self.run_child_validation(data)
def to_representation(self, value):

View File

@ -76,9 +76,9 @@ LIST_SERIALIZER_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source',
'label', 'help_text', 'style', 'error_messages', 'allow_empty',
'instance', 'data', 'partial', 'context', 'allow_null',
'max_length', 'min_length'
'max_length', 'min_length', 'max_depth'
)
LIST_SERIALIZER_KWARGS_REMOVE = ('allow_empty', 'min_length', 'max_length')
LIST_SERIALIZER_KWARGS_REMOVE = ('allow_empty', 'min_length', 'max_length', 'max_depth')
ALL_FIELDS = '__all__'
@ -111,6 +111,10 @@ class BaseSerializer(Field):
.data - Available.
"""
default_error_messages = {
'max_depth': _('Nesting depth exceeds maximum allowed depth of {max_depth}.')
}
def __init__(self, instance=None, data=empty, **kwargs):
self.instance = instance
if data is not empty:
@ -118,13 +122,14 @@ class BaseSerializer(Field):
self.partial = kwargs.pop('partial', False)
self._context = kwargs.pop('context', {})
kwargs.pop('many', None)
self.max_depth = kwargs.pop('max_depth', None)
super().__init__(**kwargs)
self._current_depth = 0
self._root_max_depth = None
self._root_max_depth = self.max_depth
def bind(self, field_name, parent):
super().bind(field_name, parent)
if hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
if self.max_depth is None and hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
self._root_max_depth = parent._root_max_depth
self._current_depth = parent._current_depth + 1
@ -137,6 +142,32 @@ class BaseSerializer(Field):
if hasattr(field, '_propagate_depth_to_child'):
field._propagate_depth_to_child()
def _check_data_depth(self, data, current_level):
if isinstance(data, dict):
for value in data.values():
if isinstance(value, (list, tuple, dict)):
next_level = current_level + 1
if next_level > self._root_max_depth:
message = self.error_messages['max_depth'].format(
max_depth=self._root_max_depth
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='max_depth')
self._check_data_depth(value, next_level)
elif isinstance(data, (list, tuple)):
for item in data:
if isinstance(item, (list, tuple, dict)):
next_level = current_level + 1
if next_level > self._root_max_depth:
message = self.error_messages['max_depth'].format(
max_depth=self._root_max_depth
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='max_depth')
self._check_data_depth(item, next_level)
def __new__(cls, *args, **kwargs):
# We override this method in order to automatically create
# `ListSerializer` classes instead when `many=True` is set.
@ -390,7 +421,8 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
fields = BindingDict(self)
for key, value in self.get_fields().items():
fields[key] = value
self._propagate_depth_to_child()
if self._root_max_depth is not None:
self._propagate_depth_to_child()
return fields
@property
@ -507,6 +539,9 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='invalid')
if self._root_max_depth is not None and self.max_depth is not None:
start_level = self._current_depth
self._check_data_depth(data, start_level)
ret = {}
errors = {}
@ -672,6 +707,32 @@ class ListSerializer(BaseSerializer):
"""
return self.child.run_validation(data)
def _check_data_depth(self, data, current_level):
if isinstance(data, (list, tuple)):
for item in data:
if isinstance(item, (list, tuple, dict)):
next_level = current_level + 1
if next_level > self._root_max_depth:
message = self.error_messages['max_depth'].format(
max_depth=self._root_max_depth
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='max_depth')
self._check_data_depth(item, next_level)
elif isinstance(data, dict):
for value in data.values():
if isinstance(value, (list, tuple, dict)):
next_level = current_level + 1
if next_level > self._root_max_depth:
message = self.error_messages['max_depth'].format(
max_depth=self._root_max_depth
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='max_depth')
self._check_data_depth(value, next_level)
def to_internal_value(self, data):
"""
List of dicts of native values <- List of dicts of primitive datatypes.
@ -687,6 +748,10 @@ class ListSerializer(BaseSerializer):
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='not_a_list')
if self._root_max_depth is not None and self.max_depth is not None:
start_level = self._current_depth
self._check_data_depth(data, start_level)
if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
raise ValidationError({

View File

@ -2559,32 +2559,35 @@ class TestListFieldMaxDepth:
output = field.run_validation([[[1, 2]], [[3]]])
assert output == [[[1, 2]], [[3]]]
def test_max_depth_zero_allows_field_itself(self):
def test_max_depth_zero_rejects_everything(self):
field = serializers.ListField(child=serializers.IntegerField(), max_depth=0)
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation([1, 2, 3])
assert 'max_depth' in str(exc_info.value.detail)
def test_max_depth_one_allows_flat_list(self):
field = serializers.ListField(child=serializers.IntegerField(), max_depth=1)
output = field.run_validation([1, 2, 3])
assert output == [1, 2, 3]
def test_max_depth_zero_rejects_nested_list(self):
field = serializers.ListField(
child=serializers.ListField(child=serializers.IntegerField()),
max_depth=0
)
def test_max_depth_one_rejects_nested_list(self):
field = serializers.ListField(child=serializers.IntegerField(), max_depth=1)
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation([[1, 2], [3]])
field.run_validation([[1, 2]])
assert 'max_depth' in str(exc_info.value.detail)
def test_max_depth_one_allows_one_level_nesting(self):
def test_max_depth_two_allows_one_level_nesting(self):
field = serializers.ListField(
child=serializers.ListField(child=serializers.IntegerField()),
max_depth=1
max_depth=2
)
output = field.run_validation([[1, 2], [3, 4]])
assert output == [[1, 2], [3, 4]]
def test_max_depth_one_rejects_two_levels_nesting(self):
def test_max_depth_two_rejects_two_levels_nesting(self):
field = serializers.ListField(
child=serializers.ListField(child=serializers.ListField(child=serializers.IntegerField())),
max_depth=1
child=serializers.ListField(child=serializers.IntegerField()),
max_depth=2
)
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation([[[1, 2]], [[3]]])
@ -2607,7 +2610,7 @@ class TestListFieldMaxDepth:
def test_max_depth_with_mixed_nesting_list_and_dict(self):
field = serializers.ListField(
child=serializers.DictField(child=serializers.ListField(child=serializers.IntegerField())),
max_depth=2
max_depth=3
)
output = field.run_validation([{'a': [1, 2], 'b': [3]}])
assert output == [{'a': [1, 2], 'b': [3]}]
@ -2647,27 +2650,16 @@ class TestDictFieldMaxDepth:
output = field.run_validation({'a': {'b': {'c': 1}}})
assert output == {'a': {'b': {'c': 1}}}
def test_max_depth_zero_allows_field_itself(self):
def test_max_depth_zero_rejects_everything(self):
field = serializers.DictField(child=serializers.IntegerField(), max_depth=0)
output = field.run_validation({'a': 1, 'b': 2})
assert output == {'a': 1, 'b': 2}
def test_max_depth_zero_rejects_nested_dict(self):
field = serializers.DictField(
child=serializers.DictField(child=serializers.IntegerField()),
max_depth=0
)
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation({'a': {'b': 1}})
field.run_validation({'a': 1, 'b': 2})
assert 'max_depth' in str(exc_info.value.detail)
def test_max_depth_one_allows_one_level_nesting(self):
field = serializers.DictField(
child=serializers.DictField(child=serializers.IntegerField()),
max_depth=1
)
output = field.run_validation({'a': {'b': 1}, 'c': {'d': 2}})
assert output == {'a': {'b': 1}, 'c': {'d': 2}}
def test_max_depth_one_allows_flat_dict(self):
field = serializers.DictField(child=serializers.IntegerField(), max_depth=1)
output = field.run_validation({'a': 1, 'b': 2})
assert output == {'a': 1, 'b': 2}
def test_max_depth_one_rejects_two_levels_nesting(self):
field = serializers.DictField(
@ -2681,7 +2673,7 @@ class TestDictFieldMaxDepth:
def test_max_depth_with_mixed_nesting_dict_and_list(self):
field = serializers.DictField(
child=serializers.ListField(child=serializers.DictField(child=serializers.IntegerField())),
max_depth=2
max_depth=3
)
output = field.run_validation({'a': [{'b': 1, 'c': 2}]})
assert output == {'a': [{'b': 1, 'c': 2}]}
@ -2710,7 +2702,7 @@ class TestDictFieldMaxDepth:
class TestMaxDepthEdgeCases:
def test_field_reuse_does_not_leak_depth_state(self):
child_field = serializers.ListField(child=serializers.IntegerField())
field = serializers.ListField(child=child_field, max_depth=1)
field = serializers.ListField(child=child_field, max_depth=2)
output1 = field.run_validation([[1, 2], [3, 4]])
assert output1 == [[1, 2], [3, 4]]
output2 = field.run_validation([[5, 6], [7, 8]])
@ -2719,7 +2711,7 @@ class TestMaxDepthEdgeCases:
def test_max_depth_with_empty_nested_structures(self):
field = serializers.ListField(
child=serializers.ListField(child=serializers.IntegerField()),
max_depth=1
max_depth=2
)
output = field.run_validation([[], []])
assert output == [[], []]
@ -2739,35 +2731,34 @@ class TestMaxDepthEdgeCases:
class TestMaxDepthDataInspection:
def test_flat_field_rejects_deeply_nested_data(self):
field = serializers.ListField(max_depth=1)
field.run_validation([[1, 2]])
field.run_validation([1, 2, 3])
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation([[[1]]])
field.run_validation([[1, 2]])
assert 'max_depth' in str(exc_info.value.detail)
def test_flat_dict_field_rejects_deeply_nested_data(self):
field = serializers.DictField(max_depth=1)
field.run_validation({'a': {'b': 1}})
field.run_validation({'a': 1, 'b': 2})
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation({'a': {'b': {'c': 1}}})
field.run_validation({'a': {'b': 1}})
assert 'max_depth' in str(exc_info.value.detail)
def test_max_depth_zero_rejects_any_nesting_in_data(self):
field = serializers.ListField(max_depth=0)
field.run_validation([1, 2, 3])
with pytest.raises(serializers.ValidationError):
field.run_validation([[1]])
field.run_validation([1, 2, 3])
def test_data_depth_check_with_mixed_structures(self):
field = serializers.ListField(max_depth=1)
field.run_validation([{'a': 1}, [2], 3])
field.run_validation([1, 2, 3])
with pytest.raises(serializers.ValidationError):
field.run_validation([{'a': {'b': 1}}])
field.run_validation([{'a': 1}, [2], 3])
def test_dict_field_data_depth_with_nested_lists(self):
field = serializers.DictField(max_depth=1)
field.run_validation({'a': [1, 2], 'b': 'text'})
field.run_validation({'a': 1, 'b': 2})
with pytest.raises(serializers.ValidationError):
field.run_validation({'a': [[1, 2]]})
field.run_validation({'a': [1, 2]})
def test_data_depth_respects_current_depth(self):
inner = serializers.ListField(child=serializers.IntegerField())
@ -2776,13 +2767,77 @@ class TestMaxDepthDataInspection:
with pytest.raises(serializers.ValidationError):
outer.run_validation([[[1]]])
def test_max_depth_one_means_flat_only(self):
field = serializers.ListField(child=serializers.IntegerField(), max_depth=1)
field.run_validation([1, 2, 3])
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation([[1, 2]])
assert 'max_depth' in str(exc_info.value.detail)
def test_serializer_with_list_field_respects_depth(self):
class TestSerializer(serializers.Serializer):
data = serializers.ListField(child=serializers.IntegerField(), max_depth=1)
serializer = TestSerializer(data={'data': [1, 2, 3]})
assert serializer.is_valid()
serializer = TestSerializer(data={'data': [[1, 2]]})
assert not serializer.is_valid()
assert 'max_depth' in str(serializer.errors)
def test_serializer_with_max_depth_inspects_raw_data(self):
class TestSerializer(serializers.Serializer):
name = serializers.CharField()
value = serializers.IntegerField()
serializer = TestSerializer(data={'name': 'test', 'value': 1}, max_depth=1)
assert serializer.is_valid()
serializer = TestSerializer(data={'name': 'test', 'value': {'nested': 1}}, max_depth=1)
assert not serializer.is_valid()
assert 'max_depth' in str(serializer.errors) or 'invalid' in str(serializer.errors)
def test_standalone_serializer_protects_against_deep_json(self):
class SimpleSerializer(serializers.Serializer):
data = serializers.CharField()
serializer = SimpleSerializer(data={'data': 'value'}, max_depth=1)
assert serializer.is_valid()
deep_data = {'data': {'nested': {'deep': 'value'}}}
serializer = SimpleSerializer(data=deep_data, max_depth=1)
assert not serializer.is_valid()
assert 'max_depth' in str(serializer.errors)
def test_list_serializer_many_true_respects_max_depth(self):
class MySerializer(serializers.Serializer):
name = serializers.CharField()
serializer = MySerializer(data=[{'name': 'test'}], many=True, max_depth=1)
assert serializer.is_valid()
deep_list_data = [{'name': 'test'}, [1, 2, 3]]
serializer = MySerializer(data=deep_list_data, many=True, max_depth=1)
assert not serializer.is_valid()
def test_list_serializer_protects_against_deeply_nested_lists(self):
class ItemSerializer(serializers.Serializer):
value = serializers.IntegerField()
serializer = ItemSerializer(data=[{'value': 1}, {'value': 2}], many=True, max_depth=1)
assert serializer.is_valid()
deep_data = [{'value': {'nested': 1}}]
serializer = ItemSerializer(data=deep_data, many=True, max_depth=1)
assert not serializer.is_valid()
class TestMaxDepthWithSerializers:
def test_list_field_containing_serializer_with_nested_list(self):
class InnerSerializer(serializers.Serializer):
numbers = serializers.ListField(child=serializers.IntegerField())
field = serializers.ListField(child=InnerSerializer(), max_depth=2)
field = serializers.ListField(child=InnerSerializer(), max_depth=3)
valid_data = [{'numbers': [1, 2]}, {'numbers': [3, 4]}]
output = field.run_validation(valid_data)
assert output == [{'numbers': [1, 2]}, {'numbers': [3, 4]}]
@ -2793,7 +2848,7 @@ class TestMaxDepthWithSerializers:
child=serializers.ListField(child=serializers.IntegerField())
)
field = serializers.ListField(child=InnerSerializer(), max_depth=2)
field = serializers.ListField(child=InnerSerializer(), max_depth=3)
with pytest.raises(serializers.ValidationError):
field.run_validation([{'nested_list': [[1, 2]]}])
@ -2801,7 +2856,7 @@ class TestMaxDepthWithSerializers:
class ValueSerializer(serializers.Serializer):
data = serializers.ListField(child=serializers.IntegerField())
field = serializers.DictField(child=ValueSerializer(), max_depth=2)
field = serializers.DictField(child=ValueSerializer(), max_depth=3)
valid_data = {'key1': {'data': [1, 2]}, 'key2': {'data': [3, 4]}}
output = field.run_validation(valid_data)
assert output == {'key1': {'data': [1, 2]}, 'key2': {'data': [3, 4]}}
@ -2816,7 +2871,7 @@ class TestMaxDepthWithSerializers:
class Level1Serializer(serializers.Serializer):
level2 = Level2Serializer()
field = serializers.ListField(child=Level1Serializer(), max_depth=3)
field = serializers.ListField(child=Level1Serializer(), max_depth=4)
with pytest.raises(serializers.ValidationError):
field.run_validation([{'level2': {'level3': {'values': [1, 2]}}}])