mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-04 20:40:14 +03:00
Merge 2b908d2e74
into 3b899c9d57
This commit is contained in:
commit
d96426fdd7
|
@ -16,7 +16,7 @@ import datetime
|
||||||
import inspect
|
import inspect
|
||||||
import types
|
import types
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from django.contrib.contenttypes.generic import GenericForeignKey
|
from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
|
||||||
from django.core.paginator import Page
|
from django.core.paginator import Page
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.forms import widgets
|
from django.forms import widgets
|
||||||
|
@ -527,6 +527,11 @@ class BaseSerializer(WritableField):
|
||||||
# Determine which object we're updating
|
# Determine which object we're updating
|
||||||
identity = self.get_identity(item)
|
identity = self.get_identity(item)
|
||||||
self.object = identity_to_objects.pop(identity, None)
|
self.object = identity_to_objects.pop(identity, None)
|
||||||
|
if not self.object and getattr(self.opts, 'model', None):
|
||||||
|
try:
|
||||||
|
self.object = self.opts.model.objects.get(id=self.get_identity(item))
|
||||||
|
except ObjectDoesNotExist:
|
||||||
|
pass
|
||||||
if self.object is None and not self.allow_add_remove:
|
if self.object is None and not self.allow_add_remove:
|
||||||
ret.append(None)
|
ret.append(None)
|
||||||
errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
|
errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
|
||||||
|
@ -942,6 +947,7 @@ class ModelSerializer(Serializer):
|
||||||
m2m_data = {}
|
m2m_data = {}
|
||||||
related_data = {}
|
related_data = {}
|
||||||
nested_forward_relations = {}
|
nested_forward_relations = {}
|
||||||
|
generic_fk = []
|
||||||
meta = self.opts.model._meta
|
meta = self.opts.model._meta
|
||||||
|
|
||||||
# Reverse fk or one-to-one relations
|
# Reverse fk or one-to-one relations
|
||||||
|
@ -960,6 +966,8 @@ class ModelSerializer(Serializer):
|
||||||
for field in meta.many_to_many + meta.virtual_fields:
|
for field in meta.many_to_many + meta.virtual_fields:
|
||||||
if isinstance(field, GenericForeignKey):
|
if isinstance(field, GenericForeignKey):
|
||||||
continue
|
continue
|
||||||
|
if isinstance(field, GenericRelation):
|
||||||
|
generic_fk.append(field.name)
|
||||||
if field.name in attrs:
|
if field.name in attrs:
|
||||||
m2m_data[field.name] = attrs.pop(field.name)
|
m2m_data[field.name] = attrs.pop(field.name)
|
||||||
|
|
||||||
|
@ -983,6 +991,7 @@ class ModelSerializer(Serializer):
|
||||||
# saved the model get hidden away on these
|
# saved the model get hidden away on these
|
||||||
# private attributes, so we can deal with them
|
# private attributes, so we can deal with them
|
||||||
# at the point of save.
|
# at the point of save.
|
||||||
|
instance._generic_fk = generic_fk
|
||||||
instance._related_data = related_data
|
instance._related_data = related_data
|
||||||
instance._m2m_data = m2m_data
|
instance._m2m_data = m2m_data
|
||||||
instance._nested_forward_relations = nested_forward_relations
|
instance._nested_forward_relations = nested_forward_relations
|
||||||
|
@ -1013,6 +1022,12 @@ class ModelSerializer(Serializer):
|
||||||
|
|
||||||
if getattr(obj, '_m2m_data', None):
|
if getattr(obj, '_m2m_data', None):
|
||||||
for accessor_name, object_list in obj._m2m_data.items():
|
for accessor_name, object_list in obj._m2m_data.items():
|
||||||
|
if accessor_name in getattr(obj, '_generic_fk', []):
|
||||||
|
# We are dealing with a reversed generic FK
|
||||||
|
setattr(obj, accessor_name, object_list)
|
||||||
|
[self.save_object(o) for o in object_list if not isinstance(o, GenericRelation)]
|
||||||
|
if accessor_name not in getattr(obj, '_generic_fk', []):
|
||||||
|
# We need to save m2m data before linking the objects together
|
||||||
setattr(obj, accessor_name, object_list)
|
setattr(obj, accessor_name, object_list)
|
||||||
del(obj._m2m_data)
|
del(obj._m2m_data)
|
||||||
|
|
||||||
|
|
55
rest_framework/tests/test_pk_related_creation.py
Normal file
55
rest_framework/tests/test_pk_related_creation.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
from django.db import models
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(models.Model):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
app_label = 'tests'
|
||||||
|
|
||||||
|
|
||||||
|
class Person(TestModel):
|
||||||
|
name = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|
||||||
|
class Group(TestModel):
|
||||||
|
name = models.TextField()
|
||||||
|
members = models.ManyToManyField(Person)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSerialiser(serializers.ModelSerializer):
|
||||||
|
|
||||||
|
members = serializers.PrimaryKeyRelatedField(many=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Group
|
||||||
|
fields = (
|
||||||
|
'id',
|
||||||
|
'name',
|
||||||
|
'members'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrimaryKeyRelatedRelation(TestCase):
|
||||||
|
|
||||||
|
def test_deserialize_group(self):
|
||||||
|
|
||||||
|
person = Person.objects.create(name='Person')
|
||||||
|
data = {
|
||||||
|
'name': 'Group Name',
|
||||||
|
'members': [person.id]
|
||||||
|
}
|
||||||
|
|
||||||
|
serializer = GroupSerialiser(data=data, files=None)
|
||||||
|
|
||||||
|
self.assertTrue(serializer.is_valid())
|
||||||
|
|
||||||
|
obj = serializer.object
|
||||||
|
serializer.save_object(obj)
|
||||||
|
|
||||||
|
self.assertEquals(obj.members.count(), 1)
|
||||||
|
|
||||||
|
member = obj.members.all()[0]
|
||||||
|
self.assertEqual(member, person)
|
|
@ -345,3 +345,74 @@ class NestedModelSerializerUpdateTests(TestCase):
|
||||||
result = deserialize.object
|
result = deserialize.object
|
||||||
result.save()
|
result.save()
|
||||||
self.assertEqual(result.id, john.id)
|
self.assertEqual(result.id, john.id)
|
||||||
|
|
||||||
|
def test_creation_with_nested_many_to_many_relation(self):
|
||||||
|
class ManyToManyTargetSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = models.ManyToManyTarget
|
||||||
|
|
||||||
|
class ManyToManySourceSerializer(serializers.ModelSerializer):
|
||||||
|
targets = ManyToManyTargetSerializer(many=True, allow_add_remove=True)
|
||||||
|
class Meta:
|
||||||
|
model = models.ManyToManySource
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'name': 'source',
|
||||||
|
'targets': [{
|
||||||
|
'name': 'target1'
|
||||||
|
}, {
|
||||||
|
'name': 'another target'
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
source_count = models.ManyToManySource.objects.count()
|
||||||
|
target_count = models.ManyToManyTarget.objects.count()
|
||||||
|
|
||||||
|
deserialize = ManyToManySourceSerializer(data=data)
|
||||||
|
self.assertTrue(deserialize.is_valid(), deserialize.errors)
|
||||||
|
deserialize.save()
|
||||||
|
self.assertEqual(models.ManyToManySource.objects.count(), source_count + 1)
|
||||||
|
self.assertEqual(models.ManyToManyTarget.objects.count(), target_count + 2)
|
||||||
|
|
||||||
|
def test_update_with_nested_many_to_many_relation(self):
|
||||||
|
class ManyToManyTargetSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = models.ManyToManyTarget
|
||||||
|
|
||||||
|
class ManyToManySourceSerializer(serializers.ModelSerializer):
|
||||||
|
targets = ManyToManyTargetSerializer(many=True, allow_add_remove=True)
|
||||||
|
class Meta:
|
||||||
|
model = models.ManyToManySource
|
||||||
|
|
||||||
|
source = models.ManyToManySource.objects.create(name='source')
|
||||||
|
target1 = models.ManyToManyTarget.objects.create(name='target1')
|
||||||
|
target2 = models.ManyToManyTarget.objects.create(name='target2')
|
||||||
|
source.targets = [target1]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'id': source.id,
|
||||||
|
'name': source.name + '0',
|
||||||
|
'targets': [{
|
||||||
|
'id': target1.id,
|
||||||
|
'name': target1.name + '1',
|
||||||
|
}, {
|
||||||
|
'id': target2.id,
|
||||||
|
'name': target2.name + '2',
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
source_count = models.ManyToManySource.objects.count()
|
||||||
|
target_count = models.ManyToManyTarget.objects.count()
|
||||||
|
|
||||||
|
deserialize = ManyToManySourceSerializer(data=data, instance=source)
|
||||||
|
self.assertTrue(deserialize.is_valid(), deserialize.errors)
|
||||||
|
deserialize.save()
|
||||||
|
self.assertEqual(models.ManyToManySource.objects.count(), source_count)
|
||||||
|
self.assertEqual(models.ManyToManyTarget.objects.count(), target_count)
|
||||||
|
|
||||||
|
# Were the models updated ?
|
||||||
|
self.assertEqual(source.name, 'source0')
|
||||||
|
alt_target1 = models.ManyToManyTarget.objects.get(id=target1.id)
|
||||||
|
self.assertEqual(alt_target1.name, target1.name + '1')
|
||||||
|
alt_target2 = models.ManyToManyTarget.objects.get(id=target2.id)
|
||||||
|
self.assertEqual(alt_target2.name, target2.name + '2')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user