This commit is contained in:
Xavier Ordoquy 2014-08-18 14:59:57 +00:00
commit a3c79693dc
2 changed files with 88 additions and 2 deletions

View File

@ -16,7 +16,7 @@ import datetime
import inspect import inspect
import types import types
from decimal import Decimal from decimal import Decimal
from django.contrib.contenttypes.generic import GenericForeignKey from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
from django.core.paginator import Page from django.core.paginator import Page
from django.db import models from django.db import models
from django.forms import widgets from django.forms import widgets
@ -527,6 +527,11 @@ class BaseSerializer(WritableField):
# Determine which object we're updating # Determine which object we're updating
identity = self.get_identity(item) identity = self.get_identity(item)
self.object = identity_to_objects.pop(identity, None) self.object = identity_to_objects.pop(identity, None)
if not self.object and getattr(self.opts, 'model', None):
try:
self.object = self.opts.model.objects.get(id=self.get_identity(item))
except ObjectDoesNotExist:
pass
if self.object is None and not self.allow_add_remove: if self.object is None and not self.allow_add_remove:
ret.append(None) ret.append(None)
errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
@ -942,6 +947,7 @@ class ModelSerializer(Serializer):
m2m_data = {} m2m_data = {}
related_data = {} related_data = {}
nested_forward_relations = {} nested_forward_relations = {}
generic_fk = []
meta = self.opts.model._meta meta = self.opts.model._meta
# Reverse fk or one-to-one relations # Reverse fk or one-to-one relations
@ -960,6 +966,8 @@ class ModelSerializer(Serializer):
for field in meta.many_to_many + meta.virtual_fields: for field in meta.many_to_many + meta.virtual_fields:
if isinstance(field, GenericForeignKey): if isinstance(field, GenericForeignKey):
continue continue
if isinstance(field, GenericRelation):
generic_fk.append(field.name)
if field.name in attrs: if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name) m2m_data[field.name] = attrs.pop(field.name)
@ -983,6 +991,7 @@ class ModelSerializer(Serializer):
# saved the model get hidden away on these # saved the model get hidden away on these
# private attributes, so we can deal with them # private attributes, so we can deal with them
# at the point of save. # at the point of save.
instance._generic_fk = generic_fk
instance._related_data = related_data instance._related_data = related_data
instance._m2m_data = m2m_data instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations instance._nested_forward_relations = nested_forward_relations
@ -1013,6 +1022,12 @@ class ModelSerializer(Serializer):
if getattr(obj, '_m2m_data', None): if getattr(obj, '_m2m_data', None):
for accessor_name, object_list in obj._m2m_data.items(): for accessor_name, object_list in obj._m2m_data.items():
if accessor_name in getattr(obj, '_generic_fk', []):
# We are dealing with a reversed generic FK
setattr(obj, accessor_name, object_list)
[self.save_object(o) for o in object_list if not isinstance(o, GenericRelation)]
if accessor_name not in getattr(obj, '_generic_fk', []):
# We need to save m2m data before linking the objects together
setattr(obj, accessor_name, object_list) setattr(obj, accessor_name, object_list)
del(obj._m2m_data) del(obj._m2m_data)

View File

@ -345,3 +345,74 @@ class NestedModelSerializerUpdateTests(TestCase):
result = deserialize.object result = deserialize.object
result.save() result.save()
self.assertEqual(result.id, john.id) self.assertEqual(result.id, john.id)
def test_creation_with_nested_many_to_many_relation(self):
class ManyToManyTargetSerializer(serializers.ModelSerializer):
class Meta:
model = models.ManyToManyTarget
class ManyToManySourceSerializer(serializers.ModelSerializer):
targets = ManyToManyTargetSerializer(many=True, allow_add_remove=True)
class Meta:
model = models.ManyToManySource
data = {
'name': 'source',
'targets': [{
'name': 'target1'
}, {
'name': 'another target'
}]
}
source_count = models.ManyToManySource.objects.count()
target_count = models.ManyToManyTarget.objects.count()
deserialize = ManyToManySourceSerializer(data=data)
self.assertTrue(deserialize.is_valid(), deserialize.errors)
deserialize.save()
self.assertEqual(models.ManyToManySource.objects.count(), source_count + 1)
self.assertEqual(models.ManyToManyTarget.objects.count(), target_count + 2)
def test_update_with_nested_many_to_many_relation(self):
class ManyToManyTargetSerializer(serializers.ModelSerializer):
class Meta:
model = models.ManyToManyTarget
class ManyToManySourceSerializer(serializers.ModelSerializer):
targets = ManyToManyTargetSerializer(many=True, allow_add_remove=True)
class Meta:
model = models.ManyToManySource
source = models.ManyToManySource.objects.create(name='source')
target1 = models.ManyToManyTarget.objects.create(name='target1')
target2 = models.ManyToManyTarget.objects.create(name='target2')
source.targets = [target1]
data = {
'id': source.id,
'name': source.name + '0',
'targets': [{
'id': target1.id,
'name': target1.name + '1',
}, {
'id': target2.id,
'name': target2.name + '2',
}]
}
source_count = models.ManyToManySource.objects.count()
target_count = models.ManyToManyTarget.objects.count()
deserialize = ManyToManySourceSerializer(data=data, instance=source)
self.assertTrue(deserialize.is_valid(), deserialize.errors)
deserialize.save()
self.assertEqual(models.ManyToManySource.objects.count(), source_count)
self.assertEqual(models.ManyToManyTarget.objects.count(), target_count)
# Were the models updated ?
self.assertEqual(source.name, 'source0')
alt_target1 = models.ManyToManyTarget.objects.get(id=target1.id)
self.assertEqual(alt_target1.name, target1.name + '1')
alt_target2 = models.ManyToManyTarget.objects.get(id=target2.id)
self.assertEqual(alt_target2.name, target2.name + '2')