diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6caae9242..cda1748d0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -256,6 +256,7 @@ class WritableField(Field): default_error_messages = { 'required': _('This field is required.'), 'invalid': _('Invalid value.'), + 'missing': _('Related object does not exist.'), } widget = widgets.TextInput default = None diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 43d339da0..606b115ab 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -21,6 +21,7 @@ from django.core.paginator import Page from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict +from django.utils.translation import ugettext_lazy as _ from django.core.exceptions import ObjectDoesNotExist from rest_framework.compat import get_concrete_model, six from rest_framework.settings import api_settings @@ -941,6 +942,7 @@ class ModelSerializer(Serializer): """ m2m_data = {} related_data = {} + related_models_to_save = {} nested_forward_relations = {} meta = self.opts.model._meta @@ -969,13 +971,33 @@ class ModelSerializer(Serializer): if isinstance(self.fields.get(field_name, None), Serializer): nested_forward_relations[field_name] = attrs[field_name] - # Create an empty instance of the model + # Update an existing instance... if instance is None: instance = self.opts.model() for key, val in attrs.items(): + keys = key.split('.') + + # Work on the current instance for this attribute + attr_instance = instance + + # Raise an error if we span more than one relation + if len(keys) > 2: + self._errors[key] = 'Can not span more than a relation.' + continue + + # Mark the related instance as the one to save + if len(keys) == 2: + try: + attr_instance = getattr(instance, keys[0]) + related_models_to_save[key] = attr_instance + except (AttributeError, ObjectDoesNotExist): + self._errors[key] = self.error_messages['missing'] + continue + + # Assign the value try: - setattr(instance, key, val) + setattr(attr_instance, keys[-1], val) except ValueError: self._errors[key] = [self.error_messages['required']] @@ -986,6 +1008,7 @@ class ModelSerializer(Serializer): instance._related_data = related_data instance._m2m_data = m2m_data instance._nested_forward_relations = nested_forward_relations + instance._related_models_to_save = related_models_to_save return instance @@ -1044,6 +1067,10 @@ class ModelSerializer(Serializer): setattr(obj, accessor_name, related) del(obj._related_data) + if getattr(obj, '_related_models_to_save', None): + for related in obj._related_models_to_save.values(): + self.save_object(related) + class HyperlinkedModelSerializerOptions(ModelSerializerOptions): """ diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index fb2eac0ba..e7cbcedbe 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1974,3 +1974,68 @@ class BoolenFieldTypeTest(TestCase): ''' bfield = self.serializer.get_fields()['started'] self.assertEqual(type(bfield), fields.BooleanField) + + +class RelationSpanningSerializerTest(TestCase): + def test_model_traversal_creation(self): + """Update a field through a foreign key during a creation.""" + class TicketSerializer(serializers.ModelSerializer): + username = fields.CharField(source='assigned.name') + + class Meta: + model = Ticket + fields = ('username',) + + serializer = TicketSerializer(data={'username': 'doe'}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'username': 'You can not set dotted sources during creation.'}) + + def test_model_traversal_update(self): + """Update a field through a foreign key during an update.""" + class TicketSerializer(serializers.ModelSerializer): + username = fields.CharField(source='assigned.name') + + class Meta: + model = Ticket + fields = ('username',) + + owner = Person.objects.create(name='john') + reviewer = Person.objects.create(name='reviewer') + ticket = Ticket.objects.create(assigned=owner, reviewer=reviewer) + serializer = TicketSerializer(ticket, data={'username': 'doe'}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.assigned.name, 'doe') + serializer.save() + self.assertEqual(Person.objects.get(id=owner.id).name, 'doe') + + def test_failing_model_traversal(self): + """Update a field through an unknown relation.""" + class TicketSerializer(serializers.ModelSerializer): + username = fields.CharField(source='demo.name') + + class Meta: + model = Ticket + fields = ('username',) + + owner = Person(name='john') + reviewer = Person(name='reviewer') + ticket = Ticket(assigned=owner, reviewer=reviewer) + serializer = TicketSerializer(ticket, data={'username': 'doe'}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'username': 'Related object does not exist.'}) + + def test_multiple_model_traversal_update(self): + """Update a field through a foreign key during an update.""" + class TicketSerializer(serializers.ModelSerializer): + username = fields.CharField(source='assigned.demo.name') + + class Meta: + model = Ticket + fields = ('username',) + + owner = Person.objects.create(name='john') + reviewer = Person.objects.create(name='reviewer') + ticket = Ticket.objects.create(assigned=owner, reviewer=reviewer) + serializer = TicketSerializer(ticket, data={'username': 'doe'}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'username': 'Can not span more than a relation.'})