Merge pull request #735 from tomchristie/one-to-one-nested-wip

One to one writable nested model serializers (wip)
This commit is contained in:
Tom Christie 2013-03-16 00:41:54 -07:00
commit ee20cf806b
3 changed files with 155 additions and 118 deletions

View File

@ -26,13 +26,17 @@ class NestedValidationError(ValidationError):
if the messages are a list of error messages. if the messages are a list of error messages.
In the case of nested serializers, where the parent has many children, In the case of nested serializers, where the parent has many children,
then the child's `serializer.errors` will be a list of dicts. then the child's `serializer.errors` will be a list of dicts. In the case
of a single child, the `serializer.errors` will be a dict.
We need to override the default behavior to get properly nested error dicts. We need to override the default behavior to get properly nested error dicts.
""" """
def __init__(self, message): def __init__(self, message):
self.messages = message if isinstance(message, dict):
self.messages = [message]
else:
self.messages = message
class DictWithMetadata(dict): class DictWithMetadata(dict):
@ -311,8 +315,8 @@ class BaseSerializer(WritableField):
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
""" """
Override default so that we can apply ModelSerializer as a nested Override default so that the serializer can be used as a nested field
field to relationships. across relationships.
""" """
if self.source == '*': if self.source == '*':
return self.to_native(obj) return self.to_native(obj)
@ -344,6 +348,10 @@ class BaseSerializer(WritableField):
return self.to_native(value) return self.to_native(value)
def field_from_native(self, data, files, field_name, into): def field_from_native(self, data, files, field_name, into):
"""
Override default so that the serializer can be used as a writable
nested field across relationships.
"""
if self.read_only: if self.read_only:
return return
@ -354,15 +362,14 @@ class BaseSerializer(WritableField):
raise ValidationError(self.error_messages['required']) raise ValidationError(self.error_messages['required'])
return return
if self.parent.object: # Set the serializer object if it exists
# Set the serializer object if it exists obj = getattr(self.parent.object, field_name) if self.parent.object else None
obj = getattr(self.parent.object, field_name)
self.object = obj
if value in (None, ''): if value in (None, ''):
into[(self.source or field_name)] = None into[(self.source or field_name)] = None
else: else:
kwargs = { kwargs = {
'instance': obj,
'data': value, 'data': value,
'context': self.context, 'context': self.context,
'partial': self.partial, 'partial': self.partial,
@ -371,7 +378,6 @@ class BaseSerializer(WritableField):
serializer = self.__class__(**kwargs) serializer = self.__class__(**kwargs)
if serializer.is_valid(): if serializer.is_valid():
self.object = serializer.object
into[self.source or field_name] = serializer.object into[self.source or field_name] = serializer.object
else: else:
# Propagate errors up to our parent # Propagate errors up to our parent
@ -630,33 +636,43 @@ class ModelSerializer(Serializer):
""" """
Restore the model instance. Restore the model instance.
""" """
self.m2m_data = {} m2m_data = {}
self.related_data = {} related_data = {}
meta = self.opts.model._meta
# Reverse fk relations # Reverse fk or one-to-one relations
for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): for (obj, model) in meta.get_all_related_objects_with_model():
field_name = obj.field.related_query_name() field_name = obj.field.related_query_name()
if field_name in attrs: if field_name in attrs:
self.related_data[field_name] = attrs.pop(field_name) related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations # Reverse m2m relations
for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): for (obj, model) in meta.get_all_related_m2m_objects_with_model():
field_name = obj.field.related_query_name() field_name = obj.field.related_query_name()
if field_name in attrs: if field_name in attrs:
self.m2m_data[field_name] = attrs.pop(field_name) m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations # Forward m2m relations
for field in self.opts.model._meta.many_to_many: for field in meta.many_to_many:
if field.name in attrs: if field.name in attrs:
self.m2m_data[field.name] = attrs.pop(field.name) m2m_data[field.name] = attrs.pop(field.name)
# Update an existing instance...
if instance is not None: if instance is not None:
for key, val in attrs.items(): for key, val in attrs.items():
setattr(instance, key, val) setattr(instance, key, val)
# ...or create a new instance
else: else:
instance = self.opts.model(**attrs) instance = self.opts.model(**attrs)
# Any relations that cannot be set until we've
# saved the model get hidden away on these
# private attributes, so we can deal with them
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
return instance return instance
def from_native(self, data, files): def from_native(self, data, files):
@ -673,15 +689,24 @@ class ModelSerializer(Serializer):
""" """
obj.save(**kwargs) obj.save(**kwargs)
if getattr(self, 'm2m_data', None): if getattr(obj, '_m2m_data', None):
for accessor_name, object_list in self.m2m_data.items(): for accessor_name, object_list in obj._m2m_data.items():
setattr(self.object, accessor_name, object_list) setattr(obj, accessor_name, object_list)
self.m2m_data = {} del(obj._m2m_data)
if getattr(self, 'related_data', None): if getattr(obj, '_related_data', None):
for accessor_name, object_list in self.related_data.items(): for accessor_name, related in obj._related_data.items():
setattr(self.object, accessor_name, object_list) if related is None:
self.related_data = {} previous = getattr(obj, accessor_name, related)
if previous:
previous.delete()
elif isinstance(related, models.Model):
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
setattr(related, fk_field, obj)
self.save_object(related)
else:
setattr(obj, accessor_name, related)
del(obj._related_data)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):

View File

@ -1,115 +1,125 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
class ForeignKeySourceSerializer(serializers.ModelSerializer): class OneToOneTarget(models.Model):
name = models.CharField(max_length=100)
class OneToOneTargetSource(models.Model):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='target_source')
class OneToOneSource(models.Model):
name = models.CharField(max_length=100)
target_source = models.OneToOneField(OneToOneTargetSource, related_name='source')
class OneToOneSourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
depth = 1 model = OneToOneSource
model = ForeignKeySource exclude = ('target_source', )
class FlatForeignKeySourceSerializer(serializers.ModelSerializer): class OneToOneTargetSourceSerializer(serializers.ModelSerializer):
class Meta: source = OneToOneSourceSerializer()
model = ForeignKeySource
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = FlatForeignKeySourceSerializer(many=True)
class Meta: class Meta:
model = ForeignKeyTarget model = OneToOneTargetSource
exclude = ('target', )
class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class OneToOneTargetSerializer(serializers.ModelSerializer):
class Meta: target_source = OneToOneTargetSourceSerializer()
depth = 1
model = NullableForeignKeySource
class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableOneToOneSource
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
nullable_source = NullableOneToOneSourceSerializer()
class Meta: class Meta:
model = OneToOneTarget model = OneToOneTarget
class ReverseForeignKeyTests(TestCase): class NestedOneToOneTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4): for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target) target = OneToOneTarget(name='target-%d' % idx)
target.save()
target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target)
target_source.save()
source = OneToOneSource(name='source-%d' % idx, target_source=target_source)
source.save() source.save()
def test_foreign_key_retrieve(self): def test_one_to_one_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
]
self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1},
]},
{'id': 2, 'name': 'target-2', 'sources': [
]}
]
self.assertEqual(serializer.data, expected)
class NestedNullableForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 3, 'name': 'source-3', 'target': None},
]
self.assertEqual(serializer.data, expected)
class NestedNullableOneToOneTests(TestCase):
def setUp(self):
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True) serializer = OneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}
]
self.assertEqual(serializer.data, expected)
def test_one_to_one_create(self):
data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}}
serializer = OneToOneTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-4')
# Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}},
{'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}}
]
self.assertEqual(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self):
data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}}
serializer = OneToOneTargetSerializer(data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]})
def test_one_to_one_update(self):
data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}}
instance = OneToOneTarget.objects.get(pk=3)
serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}}
]
self.assertEqual(serializer.data, expected)
def test_one_to_one_delete(self):
data = {'id': 3, 'name': 'target-3', 'target_source': None}
instance = OneToOneTarget.objects.get(pk=3)
serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
serializer.save()
# Ensure (target_source 3, source 3) are deleted,
# and everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3', 'target_source': None}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)

View File

@ -1,5 +1,7 @@
""" """
Tests to cover nested serializers. Tests to cover nested serializers.
Doesn't cover model serializers.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
@ -124,7 +126,7 @@ class WritableNestedSerializerObjectTests(TestCase):
def __init__(self, order, title, duration): def __init__(self, order, title, duration):
self.order, self.title, self.duration = order, title, duration self.order, self.title, self.duration = order, title, duration
def __cmp__(self, other): def __eq__(self, other):
return ( return (
self.order == other.order and self.order == other.order and
self.title == other.title and self.title == other.title and
@ -135,7 +137,7 @@ class WritableNestedSerializerObjectTests(TestCase):
def __init__(self, album_name, artist, tracks): def __init__(self, album_name, artist, tracks):
self.album_name, self.artist, self.tracks = album_name, artist, tracks self.album_name, self.artist, self.tracks = album_name, artist, tracks
def __cmp__(self, other): def __eq__(self, other):
return ( return (
self.album_name == other.album_name and self.album_name == other.album_name and
self.artist == other.artist and self.artist == other.artist and