This commit is contained in:
Alex Louden 2014-08-18 19:13:37 +00:00
commit d96426fdd7
3 changed files with 143 additions and 2 deletions

View File

@ -16,7 +16,7 @@ import datetime
import inspect
import types
from decimal import Decimal
from django.contrib.contenttypes.generic import GenericForeignKey
from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
from django.core.paginator import Page
from django.db import models
from django.forms import widgets
@ -527,6 +527,11 @@ class BaseSerializer(WritableField):
# Determine which object we're updating
identity = self.get_identity(item)
self.object = identity_to_objects.pop(identity, None)
if not self.object and getattr(self.opts, 'model', None):
try:
self.object = self.opts.model.objects.get(id=self.get_identity(item))
except ObjectDoesNotExist:
pass
if self.object is None and not self.allow_add_remove:
ret.append(None)
errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
@ -942,6 +947,7 @@ class ModelSerializer(Serializer):
m2m_data = {}
related_data = {}
nested_forward_relations = {}
generic_fk = []
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
@ -960,6 +966,8 @@ class ModelSerializer(Serializer):
for field in meta.many_to_many + meta.virtual_fields:
if isinstance(field, GenericForeignKey):
continue
if isinstance(field, GenericRelation):
generic_fk.append(field.name)
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
@ -983,6 +991,7 @@ class ModelSerializer(Serializer):
# saved the model get hidden away on these
# private attributes, so we can deal with them
# at the point of save.
instance._generic_fk = generic_fk
instance._related_data = related_data
instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations
@ -1013,7 +1022,13 @@ class ModelSerializer(Serializer):
if getattr(obj, '_m2m_data', None):
for accessor_name, object_list in obj._m2m_data.items():
setattr(obj, accessor_name, object_list)
if accessor_name in getattr(obj, '_generic_fk', []):
# We are dealing with a reversed generic FK
setattr(obj, accessor_name, object_list)
[self.save_object(o) for o in object_list if not isinstance(o, GenericRelation)]
if accessor_name not in getattr(obj, '_generic_fk', []):
# We need to save m2m data before linking the objects together
setattr(obj, accessor_name, object_list)
del(obj._m2m_data)
if getattr(obj, '_related_data', None):

View File

@ -0,0 +1,55 @@
from django.db import models
from rest_framework import serializers
from django.test import TestCase
class TestModel(models.Model):
class Meta:
app_label = 'tests'
class Person(TestModel):
name = models.CharField(max_length=200)
class Group(TestModel):
name = models.TextField()
members = models.ManyToManyField(Person)
class GroupSerialiser(serializers.ModelSerializer):
members = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
model = Group
fields = (
'id',
'name',
'members'
)
class TestPrimaryKeyRelatedRelation(TestCase):
def test_deserialize_group(self):
person = Person.objects.create(name='Person')
data = {
'name': 'Group Name',
'members': [person.id]
}
serializer = GroupSerialiser(data=data, files=None)
self.assertTrue(serializer.is_valid())
obj = serializer.object
serializer.save_object(obj)
self.assertEquals(obj.members.count(), 1)
member = obj.members.all()[0]
self.assertEqual(member, person)

View File

@ -345,3 +345,74 @@ class NestedModelSerializerUpdateTests(TestCase):
result = deserialize.object
result.save()
self.assertEqual(result.id, john.id)
def test_creation_with_nested_many_to_many_relation(self):
class ManyToManyTargetSerializer(serializers.ModelSerializer):
class Meta:
model = models.ManyToManyTarget
class ManyToManySourceSerializer(serializers.ModelSerializer):
targets = ManyToManyTargetSerializer(many=True, allow_add_remove=True)
class Meta:
model = models.ManyToManySource
data = {
'name': 'source',
'targets': [{
'name': 'target1'
}, {
'name': 'another target'
}]
}
source_count = models.ManyToManySource.objects.count()
target_count = models.ManyToManyTarget.objects.count()
deserialize = ManyToManySourceSerializer(data=data)
self.assertTrue(deserialize.is_valid(), deserialize.errors)
deserialize.save()
self.assertEqual(models.ManyToManySource.objects.count(), source_count + 1)
self.assertEqual(models.ManyToManyTarget.objects.count(), target_count + 2)
def test_update_with_nested_many_to_many_relation(self):
class ManyToManyTargetSerializer(serializers.ModelSerializer):
class Meta:
model = models.ManyToManyTarget
class ManyToManySourceSerializer(serializers.ModelSerializer):
targets = ManyToManyTargetSerializer(many=True, allow_add_remove=True)
class Meta:
model = models.ManyToManySource
source = models.ManyToManySource.objects.create(name='source')
target1 = models.ManyToManyTarget.objects.create(name='target1')
target2 = models.ManyToManyTarget.objects.create(name='target2')
source.targets = [target1]
data = {
'id': source.id,
'name': source.name + '0',
'targets': [{
'id': target1.id,
'name': target1.name + '1',
}, {
'id': target2.id,
'name': target2.name + '2',
}]
}
source_count = models.ManyToManySource.objects.count()
target_count = models.ManyToManyTarget.objects.count()
deserialize = ManyToManySourceSerializer(data=data, instance=source)
self.assertTrue(deserialize.is_valid(), deserialize.errors)
deserialize.save()
self.assertEqual(models.ManyToManySource.objects.count(), source_count)
self.assertEqual(models.ManyToManyTarget.objects.count(), target_count)
# Were the models updated ?
self.assertEqual(source.name, 'source0')
alt_target1 = models.ManyToManyTarget.objects.get(id=target1.id)
self.assertEqual(alt_target1.name, target1.name + '1')
alt_target2 = models.ManyToManyTarget.objects.get(id=target2.id)
self.assertEqual(alt_target2.name, target2.name + '2')