mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-12 04:37:13 +03:00
Merge pull request #753 from maspwr/writable-nested-modelserializer
one-many writable nested modelserializer
This commit is contained in:
commit
ce8ffd390a
|
@ -130,14 +130,14 @@ class BaseSerializer(WritableField):
|
||||||
|
|
||||||
def __init__(self, instance=None, data=None, files=None,
|
def __init__(self, instance=None, data=None, files=None,
|
||||||
context=None, partial=False, many=None,
|
context=None, partial=False, many=None,
|
||||||
allow_delete=False, **kwargs):
|
allow_add_remove=False, **kwargs):
|
||||||
super(BaseSerializer, self).__init__(**kwargs)
|
super(BaseSerializer, self).__init__(**kwargs)
|
||||||
self.opts = self._options_class(self.Meta)
|
self.opts = self._options_class(self.Meta)
|
||||||
self.parent = None
|
self.parent = None
|
||||||
self.root = None
|
self.root = None
|
||||||
self.partial = partial
|
self.partial = partial
|
||||||
self.many = many
|
self.many = many
|
||||||
self.allow_delete = allow_delete
|
self.allow_add_remove = allow_add_remove
|
||||||
|
|
||||||
self.context = context or {}
|
self.context = context or {}
|
||||||
|
|
||||||
|
@ -154,8 +154,8 @@ class BaseSerializer(WritableField):
|
||||||
if many and instance is not None and not hasattr(instance, '__iter__'):
|
if many and instance is not None and not hasattr(instance, '__iter__'):
|
||||||
raise ValueError('instance should be a queryset or other iterable with many=True')
|
raise ValueError('instance should be a queryset or other iterable with many=True')
|
||||||
|
|
||||||
if allow_delete and not many:
|
if allow_add_remove and not many:
|
||||||
raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True')
|
raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
|
||||||
|
|
||||||
#####
|
#####
|
||||||
# Methods to determine which fields to use when (de)serializing objects.
|
# Methods to determine which fields to use when (de)serializing objects.
|
||||||
|
@ -288,8 +288,15 @@ class BaseSerializer(WritableField):
|
||||||
You should override this method to control how deserialized objects
|
You should override this method to control how deserialized objects
|
||||||
are instantiated.
|
are instantiated.
|
||||||
"""
|
"""
|
||||||
|
removed_relations = []
|
||||||
|
|
||||||
|
# Deleted related objects
|
||||||
|
if self._deleted:
|
||||||
|
removed_relations = list(self._deleted)
|
||||||
|
|
||||||
if instance is not None:
|
if instance is not None:
|
||||||
instance.update(attrs)
|
instance.update(attrs)
|
||||||
|
instance._removed_relations = removed_relations
|
||||||
return instance
|
return instance
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
|
@ -377,6 +384,7 @@ class BaseSerializer(WritableField):
|
||||||
|
|
||||||
# Set the serializer object if it exists
|
# Set the serializer object if it exists
|
||||||
obj = getattr(self.parent.object, field_name) if self.parent.object else None
|
obj = getattr(self.parent.object, field_name) if self.parent.object else None
|
||||||
|
obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
|
||||||
|
|
||||||
if value in (None, ''):
|
if value in (None, ''):
|
||||||
into[(self.source or field_name)] = None
|
into[(self.source or field_name)] = None
|
||||||
|
@ -386,7 +394,8 @@ class BaseSerializer(WritableField):
|
||||||
'data': value,
|
'data': value,
|
||||||
'context': self.context,
|
'context': self.context,
|
||||||
'partial': self.partial,
|
'partial': self.partial,
|
||||||
'many': self.many
|
'many': self.many,
|
||||||
|
'allow_add_remove': self.allow_add_remove
|
||||||
}
|
}
|
||||||
serializer = self.__class__(**kwargs)
|
serializer = self.__class__(**kwargs)
|
||||||
|
|
||||||
|
@ -496,6 +505,9 @@ class BaseSerializer(WritableField):
|
||||||
def save_object(self, obj, **kwargs):
|
def save_object(self, obj, **kwargs):
|
||||||
obj.save(**kwargs)
|
obj.save(**kwargs)
|
||||||
|
|
||||||
|
if self.allow_add_remove and hasattr(obj, '_removed_relations'):
|
||||||
|
[self.delete_object(item) for item in obj._removed_relations]
|
||||||
|
|
||||||
def delete_object(self, obj):
|
def delete_object(self, obj):
|
||||||
obj.delete()
|
obj.delete()
|
||||||
|
|
||||||
|
@ -508,7 +520,7 @@ class BaseSerializer(WritableField):
|
||||||
else:
|
else:
|
||||||
self.save_object(self.object, **kwargs)
|
self.save_object(self.object, **kwargs)
|
||||||
|
|
||||||
if self.allow_delete and self._deleted:
|
if self.allow_add_remove and self._deleted:
|
||||||
[self.delete_object(item) for item in self._deleted]
|
[self.delete_object(item) for item in self._deleted]
|
||||||
|
|
||||||
return self.object
|
return self.object
|
||||||
|
@ -699,6 +711,7 @@ class ModelSerializer(Serializer):
|
||||||
m2m_data = {}
|
m2m_data = {}
|
||||||
related_data = {}
|
related_data = {}
|
||||||
nested_forward_relations = {}
|
nested_forward_relations = {}
|
||||||
|
removed_relations = []
|
||||||
meta = self.opts.model._meta
|
meta = self.opts.model._meta
|
||||||
|
|
||||||
# Reverse fk or one-to-one relations
|
# Reverse fk or one-to-one relations
|
||||||
|
@ -724,6 +737,10 @@ class ModelSerializer(Serializer):
|
||||||
if isinstance(self.fields.get(field_name, None), Serializer):
|
if isinstance(self.fields.get(field_name, None), Serializer):
|
||||||
nested_forward_relations[field_name] = attrs[field_name]
|
nested_forward_relations[field_name] = attrs[field_name]
|
||||||
|
|
||||||
|
# Deleted related objects
|
||||||
|
if self._deleted:
|
||||||
|
removed_relations = list(self._deleted)
|
||||||
|
|
||||||
# Update an existing instance...
|
# Update an existing instance...
|
||||||
if instance is not None:
|
if instance is not None:
|
||||||
for key, val in attrs.items():
|
for key, val in attrs.items():
|
||||||
|
@ -740,6 +757,7 @@ class ModelSerializer(Serializer):
|
||||||
instance._related_data = related_data
|
instance._related_data = related_data
|
||||||
instance._m2m_data = m2m_data
|
instance._m2m_data = m2m_data
|
||||||
instance._nested_forward_relations = nested_forward_relations
|
instance._nested_forward_relations = nested_forward_relations
|
||||||
|
instance._removed_relations = removed_relations
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
@ -764,6 +782,9 @@ class ModelSerializer(Serializer):
|
||||||
|
|
||||||
obj.save(**kwargs)
|
obj.save(**kwargs)
|
||||||
|
|
||||||
|
if self.allow_add_remove and hasattr(obj, '_removed_relations'):
|
||||||
|
[self.delete_object(item) for item in obj._removed_relations]
|
||||||
|
|
||||||
if getattr(obj, '_m2m_data', None):
|
if getattr(obj, '_m2m_data', None):
|
||||||
for accessor_name, object_list in obj._m2m_data.items():
|
for accessor_name, object_list in obj._m2m_data.items():
|
||||||
setattr(obj, accessor_name, object_list)
|
setattr(obj, accessor_name, object_list)
|
||||||
|
@ -773,14 +794,13 @@ class ModelSerializer(Serializer):
|
||||||
for accessor_name, related in obj._related_data.items():
|
for accessor_name, related in obj._related_data.items():
|
||||||
field = self.fields.get(accessor_name, None)
|
field = self.fields.get(accessor_name, None)
|
||||||
if isinstance(field, Serializer):
|
if isinstance(field, Serializer):
|
||||||
# TODO: Following will be needed for reverse FK
|
if field.many:
|
||||||
# if field.many:
|
# Nested reverse fk relationship
|
||||||
# # Nested reverse fk relationship
|
for related_item in related:
|
||||||
# for related_item in related:
|
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
|
||||||
# fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
|
setattr(related_item, fk_field, obj)
|
||||||
# setattr(related_item, fk_field, obj)
|
self.save_object(related_item)
|
||||||
# self.save_object(related_item)
|
else:
|
||||||
# else:
|
|
||||||
# Nested reverse one-one relationship
|
# Nested reverse one-one relationship
|
||||||
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
|
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
|
||||||
setattr(related, fk_field, obj)
|
setattr(related, fk_field, obj)
|
||||||
|
|
|
@ -13,6 +13,15 @@ class OneToOneSource(models.Model):
|
||||||
target = models.OneToOneField(OneToOneTarget, related_name='source')
|
target = models.OneToOneField(OneToOneTarget, related_name='source')
|
||||||
|
|
||||||
|
|
||||||
|
class OneToManyTarget(models.Model):
|
||||||
|
name = models.CharField(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class OneToManySource(models.Model):
|
||||||
|
name = models.CharField(max_length=100)
|
||||||
|
target = models.ForeignKey(OneToManyTarget, related_name='sources')
|
||||||
|
|
||||||
|
|
||||||
class ReverseNestedOneToOneTests(TestCase):
|
class ReverseNestedOneToOneTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
class OneToOneSourceSerializer(serializers.ModelSerializer):
|
class OneToOneSourceSerializer(serializers.ModelSerializer):
|
||||||
|
@ -189,3 +198,92 @@ class ForwardNestedOneToOneTests(TestCase):
|
||||||
# {'id': 3, 'name': 'target-3', 'source': None}
|
# {'id': 3, 'name': 'target-3', 'source': None}
|
||||||
# ]
|
# ]
|
||||||
# self.assertEqual(serializer.data, expected)
|
# self.assertEqual(serializer.data, expected)
|
||||||
|
|
||||||
|
|
||||||
|
class ReverseNestedOneToManyTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
class OneToManySourceSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = OneToManySource
|
||||||
|
fields = ('id', 'name')
|
||||||
|
|
||||||
|
class OneToManyTargetSerializer(serializers.ModelSerializer):
|
||||||
|
sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = OneToManyTarget
|
||||||
|
fields = ('id', 'name', 'sources')
|
||||||
|
|
||||||
|
self.Serializer = OneToManyTargetSerializer
|
||||||
|
|
||||||
|
target = OneToManyTarget(name='target-1')
|
||||||
|
target.save()
|
||||||
|
for idx in range(1, 4):
|
||||||
|
source = OneToManySource(name='source-%d' % idx, target=target)
|
||||||
|
source.save()
|
||||||
|
|
||||||
|
def test_one_to_many_retrieve(self):
|
||||||
|
queryset = OneToManyTarget.objects.all()
|
||||||
|
serializer = self.Serializer(queryset)
|
||||||
|
expected = [
|
||||||
|
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
|
||||||
|
{'id': 2, 'name': 'source-2'},
|
||||||
|
{'id': 3, 'name': 'source-3'}]},
|
||||||
|
]
|
||||||
|
self.assertEqual(serializer.data, expected)
|
||||||
|
|
||||||
|
def test_one_to_many_create(self):
|
||||||
|
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
|
||||||
|
{'id': 2, 'name': 'source-2'},
|
||||||
|
{'id': 3, 'name': 'source-3'},
|
||||||
|
{'id': 4, 'name': 'source-4'}]}
|
||||||
|
instance = OneToManyTarget.objects.get(pk=1)
|
||||||
|
serializer = self.Serializer(instance, data=data)
|
||||||
|
self.assertTrue(serializer.is_valid())
|
||||||
|
obj = serializer.save()
|
||||||
|
self.assertEqual(serializer.data, data)
|
||||||
|
self.assertEqual(obj.name, 'target-1')
|
||||||
|
|
||||||
|
# Ensure source 4 is added, and everything else is as
|
||||||
|
# expected.
|
||||||
|
queryset = OneToManyTarget.objects.all()
|
||||||
|
serializer = self.Serializer(queryset)
|
||||||
|
expected = [
|
||||||
|
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
|
||||||
|
{'id': 2, 'name': 'source-2'},
|
||||||
|
{'id': 3, 'name': 'source-3'},
|
||||||
|
{'id': 4, 'name': 'source-4'}]}
|
||||||
|
]
|
||||||
|
self.assertEqual(serializer.data, expected)
|
||||||
|
|
||||||
|
def test_one_to_many_create_with_invalid_data(self):
|
||||||
|
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
|
||||||
|
{'id': 2, 'name': 'source-2'},
|
||||||
|
{'id': 3, 'name': 'source-3'},
|
||||||
|
{'id': 4}]}
|
||||||
|
serializer = self.Serializer(data=data)
|
||||||
|
self.assertFalse(serializer.is_valid())
|
||||||
|
self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
|
||||||
|
|
||||||
|
def test_one_to_many_update(self):
|
||||||
|
data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
|
||||||
|
{'id': 2, 'name': 'source-2'},
|
||||||
|
{'id': 3, 'name': 'source-3'}]}
|
||||||
|
instance = OneToManyTarget.objects.get(pk=1)
|
||||||
|
serializer = self.Serializer(instance, data=data)
|
||||||
|
self.assertTrue(serializer.is_valid())
|
||||||
|
obj = serializer.save()
|
||||||
|
self.assertEqual(serializer.data, data)
|
||||||
|
self.assertEqual(obj.name, 'target-1-updated')
|
||||||
|
|
||||||
|
# Ensure (target 1, source 1) are updated,
|
||||||
|
# and everything else is as expected.
|
||||||
|
queryset = OneToManyTarget.objects.all()
|
||||||
|
serializer = self.Serializer(queryset)
|
||||||
|
expected = [
|
||||||
|
{'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
|
||||||
|
{'id': 2, 'name': 'source-2'},
|
||||||
|
{'id': 3, 'name': 'source-3'}]}
|
||||||
|
|
||||||
|
]
|
||||||
|
self.assertEqual(serializer.data, expected)
|
||||||
|
|
|
@ -201,7 +201,7 @@ class BulkUpdateSerializerTests(TestCase):
|
||||||
'author': 'Haruki Murakami'
|
'author': 'Haruki Murakami'
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
|
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
|
||||||
self.assertEqual(serializer.is_valid(), True)
|
self.assertEqual(serializer.is_valid(), True)
|
||||||
self.assertEqual(serializer.data, data)
|
self.assertEqual(serializer.data, data)
|
||||||
serializer.save()
|
serializer.save()
|
||||||
|
@ -223,7 +223,7 @@ class BulkUpdateSerializerTests(TestCase):
|
||||||
'author': 'Haruki Murakami'
|
'author': 'Haruki Murakami'
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
|
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
|
||||||
self.assertEqual(serializer.is_valid(), True)
|
self.assertEqual(serializer.is_valid(), True)
|
||||||
self.assertEqual(serializer.data, data)
|
self.assertEqual(serializer.data, data)
|
||||||
serializer.save()
|
serializer.save()
|
||||||
|
@ -249,6 +249,6 @@ class BulkUpdateSerializerTests(TestCase):
|
||||||
{},
|
{},
|
||||||
{'id': ['Enter a whole number.']}
|
{'id': ['Enter a whole number.']}
|
||||||
]
|
]
|
||||||
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
|
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
|
||||||
self.assertEqual(serializer.is_valid(), False)
|
self.assertEqual(serializer.is_valid(), False)
|
||||||
self.assertEqual(serializer.errors, expected_errors)
|
self.assertEqual(serializer.errors, expected_errors)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user