diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 05daaab76..68b956822 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -301,6 +301,11 @@ class WritableField(Field): result.validators = self.validators[:] return result + def get_default_value(self): + if is_simple_callable(self.default): + return self.default() + return self.default + def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -349,10 +354,7 @@ class WritableField(Field): except KeyError: if self.default is not None and not self.partial: # Note: partial updates shouldn't set defaults - if is_simple_callable(self.default): - native = self.default() - else: - native = self.default + native = self.get_default_value() else: if self.required: raise ValidationError(self.error_messages['required']) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 163a89840..308545ce9 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -119,6 +119,14 @@ class RelatedField(WritableField): choices = property(_get_choices, _set_choices) + ### Default value handling + + def get_default_value(self): + default = super(RelatedField, self).get_default_value() + if self.many and default is None: + return [] + return default + ### Regular serializer stuff... def field_to_native(self, obj, field_name): @@ -167,7 +175,7 @@ class RelatedField(WritableField): except KeyError: if self.partial: return - value = [] if self.many else None + value = self.get_default_value() if value in self.null_values: if self.required: diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 47082190f..198c269f0 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -900,6 +900,58 @@ class DefaultValueTests(TestCase): self.assertEqual(instance.text, 'overridden') +class WritableFieldDefaultValueTests(TestCase): + + def setUp(self): + self.expected = {'default': 'value'} + self.create_field = fields.WritableField + + def test_get_default_value_with_noncallable(self): + field = self.create_field(default=self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_with_callable(self): + field = self.create_field(default=lambda : self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_when_not_required(self): + field = self.create_field(default=self.expected, required=False) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_returns_None(self): + field = self.create_field() + got = field.get_default_value() + self.assertIsNone(got) + + def test_get_default_value_returns_non_True_values(self): + values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause + for expected in values: + field = self.create_field(default=expected) + got = field.get_default_value() + self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + + def setUp(self): + self.expected = {'foo': 'bar'} + self.create_field = relations.RelatedField + + def test_get_default_value_returns_empty_list(self): + field = self.create_field(many=True) + got = field.get_default_value() + self.assertListEqual(got, []) + + def test_get_default_value_returns_expected(self): + expected = [1, 2, 3] + field = self.create_field(many=True, default=expected) + got = field.get_default_value() + self.assertListEqual(got, expected) + + class CallableDefaultValueTests(TestCase): def setUp(self): class CallableDefaultValueSerializer(serializers.ModelSerializer):