diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 668bcc49d..73cad00f6 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -130,14 +130,14 @@ class BaseSerializer(WritableField): def __init__(self, instance=None, data=None, files=None, context=None, partial=False, many=None, - allow_delete=False, **kwargs): + allow_add_remove=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) self.parent = None self.root = None self.partial = partial self.many = many - self.allow_delete = allow_delete + self.allow_add_remove = allow_add_remove self.context = context or {} @@ -154,8 +154,8 @@ class BaseSerializer(WritableField): if many and instance is not None and not hasattr(instance, '__iter__'): raise ValueError('instance should be a queryset or other iterable with many=True') - if allow_delete and not many: - raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True') + if allow_add_remove and not many: + raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') ##### # Methods to determine which fields to use when (de)serializing objects. @@ -288,8 +288,15 @@ class BaseSerializer(WritableField): You should override this method to control how deserialized objects are instantiated. """ + removed_relations = [] + + # Deleted related objects + if self._deleted: + removed_relations = list(self._deleted) + if instance is not None: instance.update(attrs) + instance._removed_relations = removed_relations return instance return attrs @@ -377,6 +384,7 @@ class BaseSerializer(WritableField): # Set the serializer object if it exists obj = getattr(self.parent.object, field_name) if self.parent.object else None + obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj if value in (None, ''): into[(self.source or field_name)] = None @@ -386,7 +394,8 @@ class BaseSerializer(WritableField): 'data': value, 'context': self.context, 'partial': self.partial, - 'many': self.many + 'many': self.many, + 'allow_add_remove': self.allow_add_remove } serializer = self.__class__(**kwargs) @@ -496,6 +505,9 @@ class BaseSerializer(WritableField): def save_object(self, obj, **kwargs): obj.save(**kwargs) + if self.allow_add_remove and hasattr(obj, '_removed_relations'): + [self.delete_object(item) for item in obj._removed_relations] + def delete_object(self, obj): obj.delete() @@ -508,7 +520,7 @@ class BaseSerializer(WritableField): else: self.save_object(self.object, **kwargs) - if self.allow_delete and self._deleted: + if self.allow_add_remove and self._deleted: [self.delete_object(item) for item in self._deleted] return self.object @@ -699,6 +711,7 @@ class ModelSerializer(Serializer): m2m_data = {} related_data = {} nested_forward_relations = {} + removed_relations = [] meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -724,6 +737,10 @@ class ModelSerializer(Serializer): if isinstance(self.fields.get(field_name, None), Serializer): nested_forward_relations[field_name] = attrs[field_name] + # Deleted related objects + if self._deleted: + removed_relations = list(self._deleted) + # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -740,6 +757,7 @@ class ModelSerializer(Serializer): instance._related_data = related_data instance._m2m_data = m2m_data instance._nested_forward_relations = nested_forward_relations + instance._removed_relations = removed_relations return instance @@ -764,6 +782,9 @@ class ModelSerializer(Serializer): obj.save(**kwargs) + if self.allow_add_remove and hasattr(obj, '_removed_relations'): + [self.delete_object(item) for item in obj._removed_relations] + if getattr(obj, '_m2m_data', None): for accessor_name, object_list in obj._m2m_data.items(): setattr(obj, accessor_name, object_list) @@ -773,18 +794,17 @@ class ModelSerializer(Serializer): for accessor_name, related in obj._related_data.items(): field = self.fields.get(accessor_name, None) if isinstance(field, Serializer): - # TODO: Following will be needed for reverse FK - # if field.many: - # # Nested reverse fk relationship - # for related_item in related: - # fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - # setattr(related_item, fk_field, obj) - # self.save_object(related_item) - # else: - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) + if field.many: + # Nested reverse fk relationship + for related_item in related: + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related_item, fk_field, obj) + self.save_object(related_item) + else: + # Nested reverse one-one relationship + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) else: # Reverse FK or reverse one-one setattr(obj, accessor_name, related) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index e7af65651..20683d4a6 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -13,6 +13,15 @@ class OneToOneSource(models.Model): target = models.OneToOneField(OneToOneTarget, related_name='source') +class OneToManyTarget(models.Model): + name = models.CharField(max_length=100) + + +class OneToManySource(models.Model): + name = models.CharField(max_length=100) + target = models.ForeignKey(OneToManyTarget, related_name='sources') + + class ReverseNestedOneToOneTests(TestCase): def setUp(self): class OneToOneSourceSerializer(serializers.ModelSerializer): @@ -189,3 +198,92 @@ class ForwardNestedOneToOneTests(TestCase): # {'id': 3, 'name': 'target-3', 'source': None} # ] # self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase): + def setUp(self): + class OneToManySourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToManySource + fields = ('id', 'name') + + class OneToManyTargetSerializer(serializers.ModelSerializer): + sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + + class Meta: + model = OneToManyTarget + fields = ('id', 'name', 'sources') + + self.Serializer = OneToManyTargetSerializer + + target = OneToManyTarget(name='target-1') + target.save() + for idx in range(1, 4): + source = OneToManySource(name='source-%d' % idx, target=target) + source.save() + + def test_one_to_many_retrieve(self): + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]}, + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1') + + # Ensure source 4 is added, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create_with_invalid_data(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4}]} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + + def test_one_to_many_update(self): + data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1-updated') + + # Ensure (target 1, source 1) are updated, + # and everything else is as expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py index afc1a1a9f..5328e7331 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/serializer_bulk_update.py @@ -201,7 +201,7 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() @@ -223,7 +223,7 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() @@ -249,6 +249,6 @@ class BulkUpdateSerializerTests(TestCase): {}, {'id': ['Enter a whole number.']} ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors)