Merge pull request #753 from maspwr/writable-nested-modelserializer

one-many writable nested modelserializer
This commit is contained in:
Tom Christie 2013-04-08 14:48:45 -07:00
commit ce8ffd390a
3 changed files with 139 additions and 21 deletions

View File

@ -130,14 +130,14 @@ class BaseSerializer(WritableField):
def __init__(self, instance=None, data=None, files=None, def __init__(self, instance=None, data=None, files=None,
context=None, partial=False, many=None, context=None, partial=False, many=None,
allow_delete=False, **kwargs): allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs) super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
self.parent = None self.parent = None
self.root = None self.root = None
self.partial = partial self.partial = partial
self.many = many self.many = many
self.allow_delete = allow_delete self.allow_add_remove = allow_add_remove
self.context = context or {} self.context = context or {}
@ -154,8 +154,8 @@ class BaseSerializer(WritableField):
if many and instance is not None and not hasattr(instance, '__iter__'): if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True') raise ValueError('instance should be a queryset or other iterable with many=True')
if allow_delete and not many: if allow_add_remove and not many:
raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True') raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
##### #####
# Methods to determine which fields to use when (de)serializing objects. # Methods to determine which fields to use when (de)serializing objects.
@ -288,8 +288,15 @@ class BaseSerializer(WritableField):
You should override this method to control how deserialized objects You should override this method to control how deserialized objects
are instantiated. are instantiated.
""" """
removed_relations = []
# Deleted related objects
if self._deleted:
removed_relations = list(self._deleted)
if instance is not None: if instance is not None:
instance.update(attrs) instance.update(attrs)
instance._removed_relations = removed_relations
return instance return instance
return attrs return attrs
@ -377,6 +384,7 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists # Set the serializer object if it exists
obj = getattr(self.parent.object, field_name) if self.parent.object else None obj = getattr(self.parent.object, field_name) if self.parent.object else None
obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if value in (None, ''): if value in (None, ''):
into[(self.source or field_name)] = None into[(self.source or field_name)] = None
@ -386,7 +394,8 @@ class BaseSerializer(WritableField):
'data': value, 'data': value,
'context': self.context, 'context': self.context,
'partial': self.partial, 'partial': self.partial,
'many': self.many 'many': self.many,
'allow_add_remove': self.allow_add_remove
} }
serializer = self.__class__(**kwargs) serializer = self.__class__(**kwargs)
@ -496,6 +505,9 @@ class BaseSerializer(WritableField):
def save_object(self, obj, **kwargs): def save_object(self, obj, **kwargs):
obj.save(**kwargs) obj.save(**kwargs)
if self.allow_add_remove and hasattr(obj, '_removed_relations'):
[self.delete_object(item) for item in obj._removed_relations]
def delete_object(self, obj): def delete_object(self, obj):
obj.delete() obj.delete()
@ -508,7 +520,7 @@ class BaseSerializer(WritableField):
else: else:
self.save_object(self.object, **kwargs) self.save_object(self.object, **kwargs)
if self.allow_delete and self._deleted: if self.allow_add_remove and self._deleted:
[self.delete_object(item) for item in self._deleted] [self.delete_object(item) for item in self._deleted]
return self.object return self.object
@ -699,6 +711,7 @@ class ModelSerializer(Serializer):
m2m_data = {} m2m_data = {}
related_data = {} related_data = {}
nested_forward_relations = {} nested_forward_relations = {}
removed_relations = []
meta = self.opts.model._meta meta = self.opts.model._meta
# Reverse fk or one-to-one relations # Reverse fk or one-to-one relations
@ -724,6 +737,10 @@ class ModelSerializer(Serializer):
if isinstance(self.fields.get(field_name, None), Serializer): if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name] nested_forward_relations[field_name] = attrs[field_name]
# Deleted related objects
if self._deleted:
removed_relations = list(self._deleted)
# Update an existing instance... # Update an existing instance...
if instance is not None: if instance is not None:
for key, val in attrs.items(): for key, val in attrs.items():
@ -740,6 +757,7 @@ class ModelSerializer(Serializer):
instance._related_data = related_data instance._related_data = related_data
instance._m2m_data = m2m_data instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations instance._nested_forward_relations = nested_forward_relations
instance._removed_relations = removed_relations
return instance return instance
@ -764,6 +782,9 @@ class ModelSerializer(Serializer):
obj.save(**kwargs) obj.save(**kwargs)
if self.allow_add_remove and hasattr(obj, '_removed_relations'):
[self.delete_object(item) for item in obj._removed_relations]
if getattr(obj, '_m2m_data', None): if getattr(obj, '_m2m_data', None):
for accessor_name, object_list in obj._m2m_data.items(): for accessor_name, object_list in obj._m2m_data.items():
setattr(obj, accessor_name, object_list) setattr(obj, accessor_name, object_list)
@ -773,18 +794,17 @@ class ModelSerializer(Serializer):
for accessor_name, related in obj._related_data.items(): for accessor_name, related in obj._related_data.items():
field = self.fields.get(accessor_name, None) field = self.fields.get(accessor_name, None)
if isinstance(field, Serializer): if isinstance(field, Serializer):
# TODO: Following will be needed for reverse FK if field.many:
# if field.many: # Nested reverse fk relationship
# # Nested reverse fk relationship for related_item in related:
# for related_item in related: fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
# fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name setattr(related_item, fk_field, obj)
# setattr(related_item, fk_field, obj) self.save_object(related_item)
# self.save_object(related_item) else:
# else: # Nested reverse one-one relationship
# Nested reverse one-one relationship fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name setattr(related, fk_field, obj)
setattr(related, fk_field, obj) self.save_object(related)
self.save_object(related)
else: else:
# Reverse FK or reverse one-one # Reverse FK or reverse one-one
setattr(obj, accessor_name, related) setattr(obj, accessor_name, related)

View File

@ -13,6 +13,15 @@ class OneToOneSource(models.Model):
target = models.OneToOneField(OneToOneTarget, related_name='source') target = models.OneToOneField(OneToOneTarget, related_name='source')
class OneToManyTarget(models.Model):
name = models.CharField(max_length=100)
class OneToManySource(models.Model):
name = models.CharField(max_length=100)
target = models.ForeignKey(OneToManyTarget, related_name='sources')
class ReverseNestedOneToOneTests(TestCase): class ReverseNestedOneToOneTests(TestCase):
def setUp(self): def setUp(self):
class OneToOneSourceSerializer(serializers.ModelSerializer): class OneToOneSourceSerializer(serializers.ModelSerializer):
@ -189,3 +198,92 @@ class ForwardNestedOneToOneTests(TestCase):
# {'id': 3, 'name': 'target-3', 'source': None} # {'id': 3, 'name': 'target-3', 'source': None}
# ] # ]
# self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
class ReverseNestedOneToManyTests(TestCase):
def setUp(self):
class OneToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = OneToManySource
fields = ('id', 'name')
class OneToManyTargetSerializer(serializers.ModelSerializer):
sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
class Meta:
model = OneToManyTarget
fields = ('id', 'name', 'sources')
self.Serializer = OneToManyTargetSerializer
target = OneToManyTarget(name='target-1')
target.save()
for idx in range(1, 4):
source = OneToManySource(name='source-%d' % idx, target=target)
source.save()
def test_one_to_many_retrieve(self):
queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]},
]
self.assertEqual(serializer.data, expected)
def test_one_to_many_create(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'},
{'id': 4, 'name': 'source-4'}]}
instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-1')
# Ensure source 4 is added, and everything else is as
# expected.
queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'},
{'id': 4, 'name': 'source-4'}]}
]
self.assertEqual(serializer.data, expected)
def test_one_to_many_create_with_invalid_data(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'},
{'id': 4}]}
serializer = self.Serializer(data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
def test_one_to_many_update(self):
data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
{'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]}
instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-1-updated')
# Ensure (target 1, source 1) are updated,
# and everything else is as expected.
queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset)
expected = [
{'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
{'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]}
]
self.assertEqual(serializer.data, expected)

View File

@ -201,7 +201,7 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami' 'author': 'Haruki Murakami'
} }
] ]
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
serializer.save() serializer.save()
@ -223,7 +223,7 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami' 'author': 'Haruki Murakami'
} }
] ]
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
serializer.save() serializer.save()
@ -249,6 +249,6 @@ class BulkUpdateSerializerTests(TestCase):
{}, {},
{'id': ['Enter a whole number.']} {'id': ['Enter a whole number.']}
] ]
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors) self.assertEqual(serializer.errors, expected_errors)