This commit is contained in:
Tom Christie 2013-03-18 15:28:06 -07:00
commit 09db96bb45
2 changed files with 112 additions and 0 deletions

View File

@ -147,6 +147,7 @@ class BaseSerializer(WritableField):
self._data = None
self._files = None
self._errors = None
self._deleted = None
#####
# Methods to determine which fields to use when (de)serializing objects.
@ -387,6 +388,13 @@ class BaseSerializer(WritableField):
# Propagate errors up to our parent
raise NestedValidationError(serializer.errors)
def get_identity(self, data):
"""
This hook is required for bulk update.
It is used to determine the canonical identity of a given object.
"""
return data['id']
@property
def errors(self):
"""
@ -408,9 +416,28 @@ class BaseSerializer(WritableField):
if many:
ret = []
errors = []
update = self.object is not None
if update:
# If this is a bulk update we need to map all the objects
# to a canonical identity so we can determine which
# individual object is being updated for each item in the
# incoming data
objects = self.object
identities = [self.get_identity(self.to_native(obj)) for obj in objects]
identity_to_objects = dict(zip(identities, objects))
for item in data:
if update:
# Determine which object we're updating
identity = self.get_identity(item)
self.object = identity_to_objects.pop(identity, None)
ret.append(self.from_native(item, None))
errors.append(self._errors)
if update:
self._deleted = identity_to_objects.values()
self._errors = any(errors) and errors or []
else:
ret = self.from_native(data, files)
@ -450,6 +477,9 @@ class BaseSerializer(WritableField):
def save_object(self, obj, **kwargs):
obj.save(**kwargs)
def delete_object(self, obj):
obj.delete()
def save(self, **kwargs):
"""
Save the deserialized object and return it.
@ -458,6 +488,10 @@ class BaseSerializer(WritableField):
[self.save_object(item, **kwargs) for item in self.object]
else:
self.save_object(self.object, **kwargs)
if self._deleted:
[self.delete_object(item) for item in self._deleted]
return self.object

View File

@ -71,3 +71,81 @@ class BulkCreateSerializerTests(TestCase):
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)
class BulkUpdateSerializerTests(TestCase):
def setUp(self):
class Book(object):
object_map = {}
def __init__(self, id, title, author):
self.id = id
self.title = title
self.author = author
def save(self):
Book.object_map[self.id] = self
def delete(self):
del Book.object_map[self.id]
class BookSerializer(serializers.Serializer):
id = serializers.IntegerField()
title = serializers.CharField(max_length=100)
author = serializers.CharField(max_length=100)
def restore_object(self, attrs, instance=None):
if instance:
instance.id = attrs['id']
instance.title = attrs['title']
instance.author = attrs['author']
return instance
return Book(**attrs)
self.Book = Book
self.BookSerializer = BookSerializer
data = [
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 2,
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
for item in data:
book = Book(item['id'], item['title'], item['author'])
book.save()
def books(self):
return self.Book.object_map.values()
def test_bulk_update_success(self):
"""
Correct bulk update serialization should return the input data.
"""
data = [
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 2,
'title': 'Kafka on the shore',
'author': 'Haruki Murakami'
}
]
serializer = self.BookSerializer(self.books(), data=data, many=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data)
serializer.save()
new_data = self.BookSerializer(self.books(), many=True).data
self.assertEqual(data, new_data)