Raise error for dotted sources on creation.

This commit is contained in:
Xavier Ordoquy 2014-02-03 21:16:51 +01:00
parent d5dd3772e9
commit 214ce0777d
3 changed files with 61 additions and 7 deletions

View File

@ -251,6 +251,7 @@ class WritableField(Field):
default_error_messages = { default_error_messages = {
'required': _('This field is required.'), 'required': _('This field is required.'),
'invalid': _('Invalid value.'), 'invalid': _('Invalid value.'),
'missing': _('Related object does not exist.'),
} }
widget = widgets.TextInput widget = widgets.TextInput
default = None default = None

View File

@ -20,6 +20,7 @@ from django.core.paginator import Page
from django.db import models from django.db import models
from django.forms import widgets from django.forms import widgets
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import get_concrete_model, six from rest_framework.compat import get_concrete_model, six
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -919,6 +920,7 @@ class ModelSerializer(Serializer):
""" """
m2m_data = {} m2m_data = {}
related_data = {} related_data = {}
related_models_to_save = {}
nested_forward_relations = {} nested_forward_relations = {}
meta = self.opts.model._meta meta = self.opts.model._meta
@ -948,11 +950,22 @@ class ModelSerializer(Serializer):
# Update an existing instance... # Update an existing instance...
if instance is not None: if instance is not None:
for key, val in attrs.items(): for key, val in attrs.items():
# TODO: check whether we need to do something about dotted paths # Follow relations if needed
dest = instance dest = instance
keys = key.split('.') keys = key.split('.')
try:
for related_instance in keys[:-1]: for related_instance in keys[:-1]:
dest = getattr(dest, related_instance) dest = getattr(dest, related_instance)
except AttributeError:
self._errors[key] = self.error_messages['missing']
continue
# If there's a relation, mark the object to save
if len(keys) > 1:
related_models_to_save[key] = dest
# Assign the value
try: try:
setattr(dest, keys[-1], val) setattr(dest, keys[-1], val)
except ValueError: except ValueError:
@ -960,8 +973,11 @@ class ModelSerializer(Serializer):
# ...or create a new instance # ...or create a new instance
else: else:
for key, value in ((k, v) for k, v in attrs.items() if '.' in k):
self._errors[key] = 'You can not set dotted sources during creation.'
del attrs[key]
print value
instance = self.opts.model(**attrs) instance = self.opts.model(**attrs)
# TODO: check whether we need to do something about dotted paths
# Any relations that cannot be set until we've # Any relations that cannot be set until we've
# saved the model get hidden away on these # saved the model get hidden away on these
@ -970,6 +986,7 @@ class ModelSerializer(Serializer):
instance._related_data = related_data instance._related_data = related_data
instance._m2m_data = m2m_data instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations instance._nested_forward_relations = nested_forward_relations
instance._related_models_to_save = related_models_to_save
return instance return instance
@ -1028,6 +1045,10 @@ class ModelSerializer(Serializer):
setattr(obj, accessor_name, related) setattr(obj, accessor_name, related)
del(obj._related_data) del(obj._related_data)
if getattr(obj, '_related_models_to_save', None):
for related in obj._related_models_to_save.values():
self.save_object(related)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
""" """

View File

@ -1843,7 +1843,8 @@ class BoolenFieldTypeTest(TestCase):
class RelationSpanningSerializerTest(TestCase): class RelationSpanningSerializerTest(TestCase):
def test_regular_field_can_span_a_relation(self): def test_model_traversal_creation(self):
"""Update a field through a foreign key during a creation."""
class TicketSerializer(serializers.ModelSerializer): class TicketSerializer(serializers.ModelSerializer):
name = fields.CharField(source='assigned.name') name = fields.CharField(source='assigned.name')
@ -1851,9 +1852,40 @@ class RelationSpanningSerializerTest(TestCase):
model = Ticket model = Ticket
fields = ('name',) fields = ('name',)
serializer = TicketSerializer(data={'name': 'doe'})
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'name': 'You can not set dotted sources during creation.'})
def test_model_traversal_update(self):
"""Update a field through a foreign key during an update."""
class TicketSerializer(serializers.ModelSerializer):
name = fields.CharField(source='assigned.name')
class Meta:
model = Ticket
fields = ('name',)
owner = Person.objects.create(name='john')
reviewer = Person.objects.create(name='reviewer')
ticket = Ticket.objects.create(assigned=owner, reviewer=reviewer)
serializer = TicketSerializer(ticket, data={'name': 'doe'})
self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.assigned.name, 'doe')
serializer.save()
self.assertEqual(Person.objects.get(id=owner.id).name, 'doe')
def test_failing_model_traversal(self):
"""Update a field through an unknown relation."""
class TicketSerializer(serializers.ModelSerializer):
name = fields.CharField(source='demo.name')
class Meta:
model = Ticket
fields = ('name',)
owner = Person(name='john') owner = Person(name='john')
reviewer = Person(name='reviewer') reviewer = Person(name='reviewer')
ticket = Ticket(assigned=owner, reviewer=reviewer) ticket = Ticket(assigned=owner, reviewer=reviewer)
serializer = TicketSerializer(ticket, data={'name': 'doe'}) serializer = TicketSerializer(ticket, data={'name': 'doe'})
self.assertTrue(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.object.assigned.name, 'doe') self.assertEqual(serializer.errors, {'name': 'Related object does not exist.'})