From 214ce0777d3832e1f07ff45374ce4f4b68379b3a Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Mon, 3 Feb 2014 21:16:51 +0100 Subject: [PATCH] Raise error for dotted sources on creation. --- rest_framework/fields.py | 1 + rest_framework/serializers.py | 29 ++++++++++++++++--- rest_framework/tests/test_serializer.py | 38 +++++++++++++++++++++++-- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 2f475d6ea..7c64af5c1 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -251,6 +251,7 @@ class WritableField(Field): default_error_messages = { 'required': _('This field is required.'), 'invalid': _('Invalid value.'), + 'missing': _('Related object does not exist.'), } widget = widgets.TextInput default = None diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 91436701e..7838efbf8 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,7 @@ from django.core.paginator import Page from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict +from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import get_concrete_model, six from rest_framework.settings import api_settings @@ -919,6 +920,7 @@ class ModelSerializer(Serializer): """ m2m_data = {} related_data = {} + related_models_to_save = {} nested_forward_relations = {} meta = self.opts.model._meta @@ -948,11 +950,22 @@ class ModelSerializer(Serializer): # Update an existing instance... if instance is not None: for key, val in attrs.items(): - # TODO: check whether we need to do something about dotted paths + # Follow relations if needed dest = instance keys = key.split('.') - for related_instance in keys[:-1]: - dest = getattr(dest, related_instance) + try: + for related_instance in keys[:-1]: + dest = getattr(dest, related_instance) + + except AttributeError: + self._errors[key] = self.error_messages['missing'] + continue + + # If there's a relation, mark the object to save + if len(keys) > 1: + related_models_to_save[key] = dest + + # Assign the value try: setattr(dest, keys[-1], val) except ValueError: @@ -960,8 +973,11 @@ class ModelSerializer(Serializer): # ...or create a new instance else: + for key, value in ((k, v) for k, v in attrs.items() if '.' in k): + self._errors[key] = 'You can not set dotted sources during creation.' + del attrs[key] + print value instance = self.opts.model(**attrs) - # TODO: check whether we need to do something about dotted paths # Any relations that cannot be set until we've # saved the model get hidden away on these @@ -970,6 +986,7 @@ class ModelSerializer(Serializer): instance._related_data = related_data instance._m2m_data = m2m_data instance._nested_forward_relations = nested_forward_relations + instance._related_models_to_save = related_models_to_save return instance @@ -1028,6 +1045,10 @@ class ModelSerializer(Serializer): setattr(obj, accessor_name, related) del(obj._related_data) + if getattr(obj, '_related_models_to_save', None): + for related in obj._related_models_to_save.values(): + self.save_object(related) + class HyperlinkedModelSerializerOptions(ModelSerializerOptions): """ diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index ac74baa4c..cb0a6aa7b 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1843,7 +1843,8 @@ class BoolenFieldTypeTest(TestCase): class RelationSpanningSerializerTest(TestCase): - def test_regular_field_can_span_a_relation(self): + def test_model_traversal_creation(self): + """Update a field through a foreign key during a creation.""" class TicketSerializer(serializers.ModelSerializer): name = fields.CharField(source='assigned.name') @@ -1851,9 +1852,40 @@ class RelationSpanningSerializerTest(TestCase): model = Ticket fields = ('name',) + serializer = TicketSerializer(data={'name': 'doe'}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'name': 'You can not set dotted sources during creation.'}) + + def test_model_traversal_update(self): + """Update a field through a foreign key during an update.""" + class TicketSerializer(serializers.ModelSerializer): + name = fields.CharField(source='assigned.name') + + class Meta: + model = Ticket + fields = ('name',) + + owner = Person.objects.create(name='john') + reviewer = Person.objects.create(name='reviewer') + ticket = Ticket.objects.create(assigned=owner, reviewer=reviewer) + serializer = TicketSerializer(ticket, data={'name': 'doe'}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.assigned.name, 'doe') + serializer.save() + self.assertEqual(Person.objects.get(id=owner.id).name, 'doe') + + def test_failing_model_traversal(self): + """Update a field through an unknown relation.""" + class TicketSerializer(serializers.ModelSerializer): + name = fields.CharField(source='demo.name') + + class Meta: + model = Ticket + fields = ('name',) + owner = Person(name='john') reviewer = Person(name='reviewer') ticket = Ticket(assigned=owner, reviewer=reviewer) serializer = TicketSerializer(ticket, data={'name': 'doe'}) - self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.object.assigned.name, 'doe') + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'name': 'Related object does not exist.'})