One-to-one writable, nested serializer support

This commit is contained in:
Mark Aaron Shirley 2013-03-12 20:59:25 -07:00
parent b6b686d285
commit 3006e3825f
3 changed files with 160 additions and 13 deletions

View File

@ -26,13 +26,17 @@ class NestedValidationError(ValidationError):
if the messages are a list of error messages. if the messages are a list of error messages.
In the case of nested serializers, where the parent has many children, In the case of nested serializers, where the parent has many children,
then the child's `serializer.errors` will be a list of dicts. then the child's `serializer.errors` will be a list of dicts. In the case
of a single child, the `serializer.errors` will be a dict.
We need to override the default behavior to get properly nested error dicts. We need to override the default behavior to get properly nested error dicts.
""" """
def __init__(self, message): def __init__(self, message):
self.messages = message if isinstance(message, dict):
self.messages = [message]
else:
self.messages = message
class DictWithMetadata(dict): class DictWithMetadata(dict):
@ -143,6 +147,7 @@ class BaseSerializer(WritableField):
self._data = None self._data = None
self._files = None self._files = None
self._errors = None self._errors = None
self._delete = False
##### #####
# Methods to determine which fields to use when (de)serializing objects. # Methods to determine which fields to use when (de)serializing objects.
@ -354,15 +359,19 @@ class BaseSerializer(WritableField):
raise ValidationError(self.error_messages['required']) raise ValidationError(self.error_messages['required'])
return return
if self.parent.object: # Set the serializer object if it exists
# Set the serializer object if it exists obj = getattr(self.parent.object, field_name) if self.parent.object else None
obj = getattr(self.parent.object, field_name)
self.object = obj
if value in (None, ''): if value in (None, ''):
into[(self.source or field_name)] = None if isinstance(self, ModelSerializer):
self._delete = True
self.object = obj
into[(self.source or field_name)] = self
else:
into[(self.source or field_name)] = None
else: else:
kwargs = { kwargs = {
'instance': obj,
'data': value, 'data': value,
'context': self.context, 'context': self.context,
'partial': self.partial, 'partial': self.partial,
@ -371,8 +380,10 @@ class BaseSerializer(WritableField):
serializer = self.__class__(**kwargs) serializer = self.__class__(**kwargs)
if serializer.is_valid(): if serializer.is_valid():
self.object = serializer.object if isinstance(serializer, ModelSerializer):
into[self.source or field_name] = serializer.object into[self.source or field_name] = serializer
else:
into[self.source or field_name] = serializer.object
else: else:
# Propagate errors up to our parent # Propagate errors up to our parent
raise NestedValidationError(serializer.errors) raise NestedValidationError(serializer.errors)
@ -664,10 +675,17 @@ class ModelSerializer(Serializer):
if instance: if instance:
return self.full_clean(instance) return self.full_clean(instance)
def save_object(self, obj): def save_object(self, obj, parent=None, fk_field=None):
""" """
Save the deserialized object and return it. Save the deserialized object and return it.
""" """
if self._delete:
obj.delete()
return
if parent and fk_field:
setattr(self.object, fk_field, parent)
obj.save() obj.save()
if getattr(self, 'm2m_data', None): if getattr(self, 'm2m_data', None):
@ -677,7 +695,11 @@ class ModelSerializer(Serializer):
if getattr(self, 'related_data', None): if getattr(self, 'related_data', None):
for accessor_name, object_list in self.related_data.items(): for accessor_name, object_list in self.related_data.items():
setattr(self.object, accessor_name, object_list) if isinstance(object_list, ModelSerializer):
fk_field = self.object._meta.get_field_by_name(accessor_name)[0].field.name
object_list.save_object(object_list.object, parent=self.object, fk_field=fk_field)
else:
setattr(self.object, accessor_name, object_list)
self.related_data = {} self.related_data = {}

View File

@ -0,0 +1,125 @@
from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import serializers
class OneToOneTarget(models.Model):
name = models.CharField(max_length=100)
class OneToOneTargetSource(models.Model):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='target_source')
class OneToOneSource(models.Model):
name = models.CharField(max_length=100)
target_source = models.OneToOneField(OneToOneTargetSource, related_name='source')
class OneToOneSourceSerializer(serializers.ModelSerializer):
class Meta:
model = OneToOneSource
exclude = ('target_source', )
class OneToOneTargetSourceSerializer(serializers.ModelSerializer):
source = OneToOneSourceSerializer()
class Meta:
model = OneToOneTargetSource
exclude = ('target', )
class OneToOneTargetSerializer(serializers.ModelSerializer):
target_source = OneToOneTargetSourceSerializer()
class Meta:
model = OneToOneTarget
class NestedOneToOneTests(TestCase):
def setUp(self):
for idx in range(1, 4):
target = OneToOneTarget(name='target-%d' % idx)
target.save()
target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target)
target_source.save()
source = OneToOneSource(name='source-%d' % idx, target_source=target_source)
source.save()
def test_one_to_one_retrieve(self):
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}
]
self.assertEqual(serializer.data, expected)
def test_one_to_one_create(self):
data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}}
serializer = OneToOneTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-4')
# Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}},
{'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}}
]
self.assertEqual(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self):
data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}}
serializer = OneToOneTargetSerializer(data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]})
def test_one_to_one_update(self):
data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}}
instance = OneToOneTarget.objects.get(pk=3)
serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}}
]
self.assertEqual(serializer.data, expected)
def test_one_to_one_delete(self):
data = {'id': 3, 'name': 'target-3', 'target_source': None}
instance = OneToOneTarget.objects.get(pk=3)
serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
# Ensure (target_source 3, source 3) are deleted,
# and everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3', 'target_source': None}
]
self.assertEqual(serializer.data, expected)

View File

@ -124,7 +124,7 @@ class WritableNestedSerializerObjectTests(TestCase):
def __init__(self, order, title, duration): def __init__(self, order, title, duration):
self.order, self.title, self.duration = order, title, duration self.order, self.title, self.duration = order, title, duration
def __cmp__(self, other): def __eq__(self, other):
return ( return (
self.order == other.order and self.order == other.order and
self.title == other.title and self.title == other.title and
@ -135,7 +135,7 @@ class WritableNestedSerializerObjectTests(TestCase):
def __init__(self, album_name, artist, tracks): def __init__(self, album_name, artist, tracks):
self.album_name, self.artist, self.tracks = album_name, artist, tracks self.album_name, self.artist, self.tracks = album_name, artist, tracks
def __cmp__(self, other): def __eq__(self, other):
return ( return (
self.album_name == other.album_name and self.album_name == other.album_name and
self.artist == other.artist and self.artist == other.artist and