From 3006e3825f29e920f881b816fd71566bf0e8d341 Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Tue, 12 Mar 2013 20:59:25 -0700 Subject: [PATCH 1/7] One-to-one writable, nested serializer support --- rest_framework/serializers.py | 44 ++++++-- rest_framework/tests/nesting.py | 125 ++++++++++++++++++++++ rest_framework/tests/serializer_nested.py | 4 +- 3 files changed, 160 insertions(+), 13 deletions(-) create mode 100644 rest_framework/tests/nesting.py diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f83451d37..893db2ece 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -26,13 +26,17 @@ class NestedValidationError(ValidationError): if the messages are a list of error messages. In the case of nested serializers, where the parent has many children, - then the child's `serializer.errors` will be a list of dicts. + then the child's `serializer.errors` will be a list of dicts. In the case + of a single child, the `serializer.errors` will be a dict. We need to override the default behavior to get properly nested error dicts. """ def __init__(self, message): - self.messages = message + if isinstance(message, dict): + self.messages = [message] + else: + self.messages = message class DictWithMetadata(dict): @@ -143,6 +147,7 @@ class BaseSerializer(WritableField): self._data = None self._files = None self._errors = None + self._delete = False ##### # Methods to determine which fields to use when (de)serializing objects. @@ -354,15 +359,19 @@ class BaseSerializer(WritableField): raise ValidationError(self.error_messages['required']) return - if self.parent.object: - # Set the serializer object if it exists - obj = getattr(self.parent.object, field_name) - self.object = obj + # Set the serializer object if it exists + obj = getattr(self.parent.object, field_name) if self.parent.object else None if value in (None, ''): - into[(self.source or field_name)] = None + if isinstance(self, ModelSerializer): + self._delete = True + self.object = obj + into[(self.source or field_name)] = self + else: + into[(self.source or field_name)] = None else: kwargs = { + 'instance': obj, 'data': value, 'context': self.context, 'partial': self.partial, @@ -371,8 +380,10 @@ class BaseSerializer(WritableField): serializer = self.__class__(**kwargs) if serializer.is_valid(): - self.object = serializer.object - into[self.source or field_name] = serializer.object + if isinstance(serializer, ModelSerializer): + into[self.source or field_name] = serializer + else: + into[self.source or field_name] = serializer.object else: # Propagate errors up to our parent raise NestedValidationError(serializer.errors) @@ -664,10 +675,17 @@ class ModelSerializer(Serializer): if instance: return self.full_clean(instance) - def save_object(self, obj): + def save_object(self, obj, parent=None, fk_field=None): """ Save the deserialized object and return it. """ + if self._delete: + obj.delete() + return + + if parent and fk_field: + setattr(self.object, fk_field, parent) + obj.save() if getattr(self, 'm2m_data', None): @@ -677,7 +695,11 @@ class ModelSerializer(Serializer): if getattr(self, 'related_data', None): for accessor_name, object_list in self.related_data.items(): - setattr(self.object, accessor_name, object_list) + if isinstance(object_list, ModelSerializer): + fk_field = self.object._meta.get_field_by_name(accessor_name)[0].field.name + object_list.save_object(object_list.object, parent=self.object, fk_field=fk_field) + else: + setattr(self.object, accessor_name, object_list) self.related_data = {} diff --git a/rest_framework/tests/nesting.py b/rest_framework/tests/nesting.py new file mode 100644 index 000000000..35b7a365d --- /dev/null +++ b/rest_framework/tests/nesting.py @@ -0,0 +1,125 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class OneToOneTarget(models.Model): + name = models.CharField(max_length=100) + + +class OneToOneTargetSource(models.Model): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, null=True, blank=True, + related_name='target_source') + + +class OneToOneSource(models.Model): + name = models.CharField(max_length=100) + target_source = models.OneToOneField(OneToOneTargetSource, related_name='source') + + +class OneToOneSourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneSource + exclude = ('target_source', ) + + +class OneToOneTargetSourceSerializer(serializers.ModelSerializer): + source = OneToOneSourceSerializer() + + class Meta: + model = OneToOneTargetSource + exclude = ('target', ) + +class OneToOneTargetSerializer(serializers.ModelSerializer): + target_source = OneToOneTargetSourceSerializer() + + class Meta: + model = OneToOneTarget + + +class NestedOneToOneTests(TestCase): + def setUp(self): + for idx in range(1, 4): + target = OneToOneTarget(name='target-%d' % idx) + target.save() + target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target) + target_source.save() + source = OneToOneSource(name='source-%d' % idx, target_source=target_source) + source.save() + + def test_one_to_one_retrieve(self): + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}} + ] + self.assertEqual(serializer.data, expected) + + + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} + serializer = OneToOneTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}, + {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}} + serializer = OneToOneTargetSerializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]}) + + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + instance = OneToOneTarget.objects.get(pk=3) + serializer = OneToOneTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, + # and everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_delete(self): + data = {'id': 3, 'name': 'target-3', 'target_source': None} + instance = OneToOneTarget.objects.get(pk=3) + serializer = OneToOneTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + + # Ensure (target_source 3, source 3) are deleted, + # and everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': None} + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py index fcf644c75..299c3bc5a 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/serializer_nested.py @@ -124,7 +124,7 @@ class WritableNestedSerializerObjectTests(TestCase): def __init__(self, order, title, duration): self.order, self.title, self.duration = order, title, duration - def __cmp__(self, other): + def __eq__(self, other): return ( self.order == other.order and self.title == other.title and @@ -135,7 +135,7 @@ class WritableNestedSerializerObjectTests(TestCase): def __init__(self, album_name, artist, tracks): self.album_name, self.artist, self.tracks = album_name, artist, tracks - def __cmp__(self, other): + def __eq__(self, other): return ( self.album_name == other.album_name and self.artist == other.artist and From 47492e3ef4e24ecd155091247e479851789ee8e9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 15 Mar 2013 19:22:31 +0000 Subject: [PATCH 2/7] Clean out ModelSerializer special casing from Serializer.field_from_native --- rest_framework/serializers.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f073e00aa..5dadebb28 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -379,11 +379,7 @@ class BaseSerializer(WritableField): serializer = self.__class__(**kwargs) if serializer.is_valid(): - if isinstance(serializer, ModelSerializer): - into[self.source or field_name] = serializer - else: - into[self.source or field_name] = serializer.object - # into[self.source or field_name] = serializer.object + into[self.source or field_name] = serializer.object else: # Propagate errors up to our parent raise NestedValidationError(serializer.errors) @@ -681,12 +677,6 @@ class ModelSerializer(Serializer): if instance: return self.full_clean(instance) -# def save_object(self, obj, **kwargs): -# """ -# Save the deserialized object and return it. -# """ -# obj.save(**kwargs) -# ======= def save_object(self, obj, parent=None, fk_field=None, **kwargs): """ Save the deserialized object and return it. @@ -706,13 +696,10 @@ class ModelSerializer(Serializer): if related is None: previous = getattr(self.object, accessor_name, related) previous.delete() - elif isinstance(related, ModelSerializer): - # print related.object - # print related.related_data, related.m2m_data + elif isinstance(related, models.Model): fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - related.save_object(related.object, parent=self.object, fk_field=fk_field) - # setattr(related, fk_field, obj) - # related.save(**kwargs) + setattr(related, fk_field, obj) + self.save_object(related) else: setattr(self.object, accessor_name, related) obj._related_data = {} From 32e0e5e18c84e7b720c74df8aeba26e0f335bbf6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 15 Mar 2013 19:55:32 +0000 Subject: [PATCH 3/7] Remove erronous _delete attribute --- rest_framework/serializers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5dadebb28..691d2aab9 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -147,7 +147,6 @@ class BaseSerializer(WritableField): self._data = None self._files = None self._errors = None - self._delete = False ##### # Methods to determine which fields to use when (de)serializing objects. From 56653111a6848f6ef5d4bb645b87cbcaf5bffba1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 15 Mar 2013 19:57:57 +0000 Subject: [PATCH 4/7] Remove unneeded arguments to save_object --- rest_framework/serializers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 691d2aab9..ebc2eec95 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -676,13 +676,10 @@ class ModelSerializer(Serializer): if instance: return self.full_clean(instance) - def save_object(self, obj, parent=None, fk_field=None, **kwargs): + def save_object(self, obj, **kwargs): """ Save the deserialized object and return it. """ - if parent and fk_field: - setattr(self.object, fk_field, parent) - obj.save(**kwargs) if getattr(obj, '_m2m_data', None): From ccf551201feb96451ffdc5d824bb0681596bcdae Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 16 Mar 2013 07:32:50 +0000 Subject: [PATCH 5/7] Clean up and comment `restore_object` --- rest_framework/serializers.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ebc2eec95..fb7722626 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -638,31 +638,38 @@ class ModelSerializer(Serializer): """ m2m_data = {} related_data = {} + meta = self.opts.model._meta - # Reverse fk relations - for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + # Reverse fk or one-to-one relations + for (obj, model) in meta.get_all_related_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: related_data[field_name] = attrs.pop(field_name) # Reverse m2m relations - for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): + for (obj, model) in meta.get_all_related_m2m_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: m2m_data[field_name] = attrs.pop(field_name) # Forward m2m relations - for field in self.opts.model._meta.many_to_many: + for field in meta.many_to_many: if field.name in attrs: m2m_data[field.name] = attrs.pop(field.name) + # Update an existing instance... if instance is not None: for key, val in attrs.items(): setattr(instance, key, val) + # ...or create a new instance else: instance = self.opts.model(**attrs) + # Any relations that cannot be set until we've + # saved the model get hidden away on these + # private attributes, so we can deal with them + # at the point of save. instance._related_data = related_data instance._m2m_data = m2m_data From 3ff103ad043420b430cc2052241994d597b1fe8a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 16 Mar 2013 07:35:27 +0000 Subject: [PATCH 6/7] Fixes to save_object --- rest_framework/serializers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index fb7722626..5826a1730 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -691,21 +691,23 @@ class ModelSerializer(Serializer): if getattr(obj, '_m2m_data', None): for accessor_name, object_list in obj._m2m_data.items(): - setattr(self.object, accessor_name, object_list) - obj._m2m_data = {} + setattr(obj, accessor_name, object_list) + del(obj._m2m_data) if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): if related is None: - previous = getattr(self.object, accessor_name, related) - previous.delete() + previous = getattr(obj, accessor_name, related) + if previous: + previous.delete() elif isinstance(related, models.Model): fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name setattr(related, fk_field, obj) self.save_object(related) else: setattr(self.object, accessor_name, related) - obj._related_data = {} + setattr(obj, accessor_name, related) + del(obj._related_data) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): From 66bdd608e1e4bbb02a815104572b80034d73aa6b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 16 Mar 2013 07:35:44 +0000 Subject: [PATCH 7/7] Fixes to save_object --- rest_framework/serializers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5826a1730..21336dc29 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -705,7 +705,6 @@ class ModelSerializer(Serializer): setattr(related, fk_field, obj) self.save_object(related) else: - setattr(self.object, accessor_name, related) setattr(obj, accessor_name, related) del(obj._related_data)