From 1addd09e2b0e26507aada864123f610ead62d8da Mon Sep 17 00:00:00 2001 From: Anton Shutik Date: Thu, 27 Feb 2014 18:34:36 +0300 Subject: [PATCH 1/3] RelatedField default value handling fixed --- rest_framework/fields.py | 10 ++++++---- rest_framework/relations.py | 9 ++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 05daaab76..68b956822 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -301,6 +301,11 @@ class WritableField(Field): result.validators = self.validators[:] return result + def get_default_value(self): + if is_simple_callable(self.default): + return self.default() + return self.default + def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -349,10 +354,7 @@ class WritableField(Field): except KeyError: if self.default is not None and not self.partial: # Note: partial updates shouldn't set defaults - if is_simple_callable(self.default): - native = self.default() - else: - native = self.default + native = self.get_default_value() else: if self.required: raise ValidationError(self.error_messages['required']) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 02185c2ff..626454aca 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -118,6 +118,13 @@ class RelatedField(WritableField): choices = property(_get_choices, _set_choices) + ### Default value handling + + def get_default_value(self): + default = super(RelatedField, self).get_default_value() + return default or \ + [] if self.many else None + ### Regular serializer stuff... def field_to_native(self, obj, field_name): @@ -166,7 +173,7 @@ class RelatedField(WritableField): except KeyError: if self.partial: return - value = [] if self.many else None + value = self.get_default_value() if value in (None, '') and self.required: raise ValidationError(self.error_messages['required']) From 3c62f0efc3cff7c1d7da9f13e0b0629d963069cb Mon Sep 17 00:00:00 2001 From: Anton Shutik Date: Fri, 28 Feb 2014 13:59:21 +0300 Subject: [PATCH 2/3] RelatedField.get_default_value: return empty list if self.many==True --- rest_framework/relations.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 626454aca..19dc9d6e5 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -122,8 +122,9 @@ class RelatedField(WritableField): def get_default_value(self): default = super(RelatedField, self).get_default_value() - return default or \ - [] if self.many else None + if self.many and default is None: + return [] + return default ### Regular serializer stuff... From dea2766abac5ef55fa226f413711cfd49af2a745 Mon Sep 17 00:00:00 2001 From: Anton Shutik Date: Tue, 4 Mar 2014 13:11:54 +0300 Subject: [PATCH 3/3] Added tests for "get_default_value" function --- rest_framework/tests/test_serializer.py | 52 +++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 6b1e333e4..a20137494 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -880,6 +880,58 @@ class DefaultValueTests(TestCase): self.assertEqual(instance.text, 'overridden') +class WritableFieldDefaultValueTests(TestCase): + + def setUp(self): + self.expected = {'default': 'value'} + self.create_field = fields.WritableField + + def test_get_default_value_with_noncallable(self): + field = self.create_field(default=self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_with_callable(self): + field = self.create_field(default=lambda : self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_when_not_required(self): + field = self.create_field(default=self.expected, required=False) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_returns_None(self): + field = self.create_field() + got = field.get_default_value() + self.assertIsNone(got) + + def test_get_default_value_returns_non_True_values(self): + values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause + for expected in values: + field = self.create_field(default=expected) + got = field.get_default_value() + self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + + def setUp(self): + self.expected = {'foo': 'bar'} + self.create_field = relations.RelatedField + + def test_get_default_value_returns_empty_list(self): + field = self.create_field(many=True) + got = field.get_default_value() + self.assertListEqual(got, []) + + def test_get_default_value_returns_expected(self): + expected = [1, 2, 3] + field = self.create_field(many=True, default=expected) + got = field.get_default_value() + self.assertListEqual(got, expected) + + class CallableDefaultValueTests(TestCase): def setUp(self): class CallableDefaultValueSerializer(serializers.ModelSerializer):