From 82c5a9de750f3657a2bde5fa7eaffd7fa0279a5a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 18 Mar 2013 22:24:46 +0000 Subject: [PATCH] Added bulk update functionality and test --- rest_framework/serializers.py | 34 ++++++++ .../tests/serializer_bulk_update.py | 78 +++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index a81cbc291..a39bdee76 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -147,6 +147,7 @@ class BaseSerializer(WritableField): self._data = None self._files = None self._errors = None + self._deleted = None ##### # Methods to determine which fields to use when (de)serializing objects. @@ -387,6 +388,13 @@ class BaseSerializer(WritableField): # Propagate errors up to our parent raise NestedValidationError(serializer.errors) + def get_identity(self, data): + """ + This hook is required for bulk update. + It is used to determine the canonical identity of a given object. + """ + return data['id'] + @property def errors(self): """ @@ -408,9 +416,28 @@ class BaseSerializer(WritableField): if many: ret = [] errors = [] + update = self.object is not None + + if update: + # If this is a bulk update we need to map all the objects + # to a canonical identity so we can determine which + # individual object is being updated for each item in the + # incoming data + objects = self.object + identities = [self.get_identity(self.to_native(obj)) for obj in objects] + identity_to_objects = dict(zip(identities, objects)) + for item in data: + if update: + # Determine which object we're updating + identity = self.get_identity(item) + self.object = identity_to_objects.pop(identity, None) + ret.append(self.from_native(item, None)) errors.append(self._errors) + + if update: + self._deleted = identity_to_objects.values() self._errors = any(errors) and errors or [] else: ret = self.from_native(data, files) @@ -450,6 +477,9 @@ class BaseSerializer(WritableField): def save_object(self, obj, **kwargs): obj.save(**kwargs) + def delete_object(self, obj): + obj.delete() + def save(self, **kwargs): """ Save the deserialized object and return it. @@ -458,6 +488,10 @@ class BaseSerializer(WritableField): [self.save_object(item, **kwargs) for item in self.object] else: self.save_object(self.object, **kwargs) + + if self._deleted: + [self.delete_object(item) for item in self._deleted] + return self.object diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py index 3ecb23edd..6ed8741c8 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/serializer_bulk_update.py @@ -71,3 +71,81 @@ class BulkCreateSerializerTests(TestCase): self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors) + +class BulkUpdateSerializerTests(TestCase): + + def setUp(self): + class Book(object): + object_map = {} + + def __init__(self, id, title, author): + self.id = id + self.title = title + self.author = author + + def save(self): + Book.object_map[self.id] = self + + def delete(self): + del Book.object_map[self.id] + + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) + + def restore_object(self, attrs, instance=None): + if instance: + instance.id = attrs['id'] + instance.title = attrs['title'] + instance.author = attrs['author'] + return instance + return Book(**attrs) + + self.Book = Book + self.BookSerializer = BookSerializer + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 2, + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + + for item in data: + book = Book(item['id'], item['title'], item['author']) + book.save() + + def books(self): + return self.Book.object_map.values() + + def test_bulk_update_success(self): + """ + Correct bulk update serialization should return the input data. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 2, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + serializer = self.BookSerializer(self.books(), data=data, many=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + serializer.save() + new_data = self.BookSerializer(self.books(), many=True).data + self.assertEqual(data, new_data)