mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-02 20:54:42 +03:00
Merge pull request #1442 from Anton-Shutik/master
RelatedField default value handling fixed
This commit is contained in:
commit
4edd39b2e4
|
@ -301,6 +301,11 @@ class WritableField(Field):
|
||||||
result.validators = self.validators[:]
|
result.validators = self.validators[:]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_default_value(self):
|
||||||
|
if is_simple_callable(self.default):
|
||||||
|
return self.default()
|
||||||
|
return self.default
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
if value in validators.EMPTY_VALUES and self.required:
|
if value in validators.EMPTY_VALUES and self.required:
|
||||||
raise ValidationError(self.error_messages['required'])
|
raise ValidationError(self.error_messages['required'])
|
||||||
|
@ -349,10 +354,7 @@ class WritableField(Field):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if self.default is not None and not self.partial:
|
if self.default is not None and not self.partial:
|
||||||
# Note: partial updates shouldn't set defaults
|
# Note: partial updates shouldn't set defaults
|
||||||
if is_simple_callable(self.default):
|
native = self.get_default_value()
|
||||||
native = self.default()
|
|
||||||
else:
|
|
||||||
native = self.default
|
|
||||||
else:
|
else:
|
||||||
if self.required:
|
if self.required:
|
||||||
raise ValidationError(self.error_messages['required'])
|
raise ValidationError(self.error_messages['required'])
|
||||||
|
|
|
@ -119,6 +119,14 @@ class RelatedField(WritableField):
|
||||||
|
|
||||||
choices = property(_get_choices, _set_choices)
|
choices = property(_get_choices, _set_choices)
|
||||||
|
|
||||||
|
### Default value handling
|
||||||
|
|
||||||
|
def get_default_value(self):
|
||||||
|
default = super(RelatedField, self).get_default_value()
|
||||||
|
if self.many and default is None:
|
||||||
|
return []
|
||||||
|
return default
|
||||||
|
|
||||||
### Regular serializer stuff...
|
### Regular serializer stuff...
|
||||||
|
|
||||||
def field_to_native(self, obj, field_name):
|
def field_to_native(self, obj, field_name):
|
||||||
|
@ -167,7 +175,7 @@ class RelatedField(WritableField):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if self.partial:
|
if self.partial:
|
||||||
return
|
return
|
||||||
value = [] if self.many else None
|
value = self.get_default_value()
|
||||||
|
|
||||||
if value in self.null_values:
|
if value in self.null_values:
|
||||||
if self.required:
|
if self.required:
|
||||||
|
|
|
@ -900,6 +900,58 @@ class DefaultValueTests(TestCase):
|
||||||
self.assertEqual(instance.text, 'overridden')
|
self.assertEqual(instance.text, 'overridden')
|
||||||
|
|
||||||
|
|
||||||
|
class WritableFieldDefaultValueTests(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.expected = {'default': 'value'}
|
||||||
|
self.create_field = fields.WritableField
|
||||||
|
|
||||||
|
def test_get_default_value_with_noncallable(self):
|
||||||
|
field = self.create_field(default=self.expected)
|
||||||
|
got = field.get_default_value()
|
||||||
|
self.assertEqual(got, self.expected)
|
||||||
|
|
||||||
|
def test_get_default_value_with_callable(self):
|
||||||
|
field = self.create_field(default=lambda : self.expected)
|
||||||
|
got = field.get_default_value()
|
||||||
|
self.assertEqual(got, self.expected)
|
||||||
|
|
||||||
|
def test_get_default_value_when_not_required(self):
|
||||||
|
field = self.create_field(default=self.expected, required=False)
|
||||||
|
got = field.get_default_value()
|
||||||
|
self.assertEqual(got, self.expected)
|
||||||
|
|
||||||
|
def test_get_default_value_returns_None(self):
|
||||||
|
field = self.create_field()
|
||||||
|
got = field.get_default_value()
|
||||||
|
self.assertIsNone(got)
|
||||||
|
|
||||||
|
def test_get_default_value_returns_non_True_values(self):
|
||||||
|
values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
|
||||||
|
for expected in values:
|
||||||
|
field = self.create_field(default=expected)
|
||||||
|
got = field.get_default_value()
|
||||||
|
self.assertEqual(got, expected)
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.expected = {'foo': 'bar'}
|
||||||
|
self.create_field = relations.RelatedField
|
||||||
|
|
||||||
|
def test_get_default_value_returns_empty_list(self):
|
||||||
|
field = self.create_field(many=True)
|
||||||
|
got = field.get_default_value()
|
||||||
|
self.assertListEqual(got, [])
|
||||||
|
|
||||||
|
def test_get_default_value_returns_expected(self):
|
||||||
|
expected = [1, 2, 3]
|
||||||
|
field = self.create_field(many=True, default=expected)
|
||||||
|
got = field.get_default_value()
|
||||||
|
self.assertListEqual(got, expected)
|
||||||
|
|
||||||
|
|
||||||
class CallableDefaultValueTests(TestCase):
|
class CallableDefaultValueTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
class CallableDefaultValueSerializer(serializers.ModelSerializer):
|
class CallableDefaultValueSerializer(serializers.ModelSerializer):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user