From 6b4bb48dd410d0a878b0142d351c7c41cd51f819 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 12 Mar 2013 13:33:02 +0000 Subject: [PATCH] Initial support for writable nested serialization (Not ModelSerializer) --- rest_framework/serializers.py | 70 +++++++++++++++++------ rest_framework/tests/serializer_nested.py | 62 ++++++++++++++++++++ 2 files changed, 115 insertions(+), 17 deletions(-) create mode 100644 rest_framework/tests/serializer_nested.py diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2ae7c215f..81619b3af 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,11 @@ from rest_framework.relations import * from rest_framework.fields import * +class NestedValidationError(ValidationError): + def __init__(self, message): + self.messages = message + + class DictWithMetadata(dict): """ A dict-like object, that can have additional properties attached. @@ -98,7 +103,7 @@ class SerializerOptions(object): self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField): """ This is the Serializer implementation. We need to implement it as `BaseSerializer` due to metaclass magicks. @@ -303,33 +308,64 @@ class BaseSerializer(Field): return self.to_native(obj) try: - if self.source: - for component in self.source.split('.'): - obj = getattr(obj, component) - if is_simple_callable(obj): - obj = obj() - else: - obj = getattr(obj, field_name) - if is_simple_callable(obj): - obj = obj() + source = self.source or field_name + value = obj + + for component in source.split('.'): + value = get_component(value, component) + if value is None: + break except ObjectDoesNotExist: return None - # If the object has an "all" method, assume it's a relationship - if is_simple_callable(getattr(obj, 'all', None)): - return [self.to_native(item) for item in obj.all()] + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] - if obj is None: + if value is None: return None if self.many is not None: many = self.many else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict, six.text_type)) + many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type)) if many: - return [self.to_native(item) for item in obj] - return self.to_native(obj) + return [self.to_native(item) for item in value] + return self.to_native(value) + + def field_from_native(self, data, files, field_name, into): + if self.read_only: + return + + try: + value = data[field_name] + except KeyError: + if self.required: + raise ValidationError(self.error_messages['required']) + return + + if self.parent.object: + # Set the serializer object if it exists + obj = getattr(self.parent.object, field_name) + self.object = obj + + if value in (None, ''): + into[(self.source or field_name)] = None + else: + kwargs = { + 'data': value, + 'context': self.context, + 'partial': self.partial, + 'many': self.many + } + serializer = self.__class__(**kwargs) + + if serializer.is_valid(): + self.object = serializer.object + into[self.source or field_name] = serializer.object + else: + # Propagate errors up to our parent + raise NestedValidationError(serializer.errors) @property def errors(self): diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py new file mode 100644 index 000000000..c8987bc56 --- /dev/null +++ b/rest_framework/tests/serializer_nested.py @@ -0,0 +1,62 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class TrackSerializer(serializers.Serializer): + order = serializers.IntegerField() + title = serializers.CharField(max_length=100) + duration = serializers.IntegerField() + + +class AlbumSerializer(serializers.Serializer): + album_name = serializers.CharField(max_length=100) + artist = serializers.CharField(max_length=100) + tracks = TrackSerializer(many=True) + + +class NestedSerializerTestCase(TestCase): + def test_nested_validation_success(self): + """ + Correct nested serialization should return the input data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + + serializer = AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + + def test_nested_validation_error(self): + """ + Incorrect nested serialization should return appropriate error data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} + ] + } + expected_errors = { + 'tracks': [ + {}, + {}, + {'duration': ['Enter a whole number.']} + ] + } + + serializer = AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors)