From 008dafce178181855d66981cfacb908d013c5d1d Mon Sep 17 00:00:00 2001 From: toran billups Date: Sat, 15 Dec 2012 20:55:36 -0600 Subject: [PATCH] ManyPrimaryKeyRelatedField now supports create for one-to-many rel --- rest_framework/serializers.py | 12 +++++ rest_framework/tests/models.py | 5 ++ rest_framework/tests/serializer.py | 79 ++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8026205e5..276a7db79 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -496,12 +496,19 @@ class ModelSerializer(Serializer): Restore the model instance. """ self.m2m_data = {} + self.related_data = {} if instance is not None: for key, val in attrs.items(): setattr(instance, key, val) return instance + # Related relations + for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.related_data[field_name] = attrs.pop(field_name) + # Reverse relations for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): field_name = obj.field.related_query_name() @@ -532,6 +539,11 @@ class ModelSerializer(Serializer): setattr(self.object, accessor_name, object_list) self.m2m_data = {} + if getattr(self, 'related_data', None): + for accessor_name, object_list in self.related_data.items(): + setattr(self.object, accessor_name, object_list) + self.related_data = {} + return self.object diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 428bf130d..0aa00d764 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -149,6 +149,11 @@ class BlogPostComment(RESTFrameworkModel): blog_post = models.ForeignKey(BlogPost) +class BlogPostRelatedComment(RESTFrameworkModel): + text = models.TextField() + blog_post = models.ForeignKey(BlogPost, related_name="comments") + + class Album(RESTFrameworkModel): title = models.CharField(max_length=100, unique=True) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 780177aa0..3c56f127f 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -816,3 +816,82 @@ class NestedSerializerContextTests(TestCase): # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data + + +class ManyPrimaryKeyRelatedCreateTest(TestCase): + + def test_create_is_valid_with_title_and_empty_comments_list(self): + data = {'title': 'foobar', 'comments': []} + serializer = self.build_model_serializer(data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_is_valid_with_title_and_comment(self): + data = {'title': 'foobar', 'comments': [self.comment.pk]} + serializer = self.build_model_serializer(data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_is_not_valid_when_title_is_empty_string(self): + data = {'title': '', 'comments': [self.comment.pk]} + serializer = self.build_model_serializer(data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'title': [u'This field is required.']}) + + def test_create_is_not_valid_when_title_present_but_no_comments(self): + data = {'title': 'foobar'} + serializer = self.build_model_serializer(data) + try: + self.assertEquals(serializer.is_valid(), False) + except TypeError as e: + self.assertEqual(e.message, "'NoneType' object is not iterable") + + def test_create_without_comment_returns_expected_json_result(self): + data = {'title': 'foobar', 'comments': []} + serializer = self.build_model_serializer(data) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + expected = { + 'title': u'foobar', + 'comments': [] + } + self.assertEqual(serializer.data, expected) + + def test_create_with_comment_returns_expected_json_result(self): + data = {'title': 'foobar', 'comments': [self.comment.pk]} + serializer = self.build_model_serializer(data) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + expected = { + 'title': u'foobar', + 'comments': [self.comment.pk] + } + self.assertEqual(serializer.data, expected) + + @property + def comment(self): + if not hasattr(self, '_comment'): + from rest_framework.tests.models import BlogPostRelatedComment + self._comment = BlogPostRelatedComment.objects.create(text="I love this blog post", blog_post=self.post) + return self._comment + + @property + def post(self): + if not hasattr(self, '_post'): + from rest_framework.tests.models import BlogPost + self._post = BlogPost.objects.create(title="Test blog post") + return self._post + + def build_model_serializer(self, data): + from rest_framework.tests.models import BlogPost, BlogPostRelatedComment + + class BlogPostRelatedCommentSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPostRelatedComment + fields = ("text") + + class BlogPostSerializer(serializers.ModelSerializer): + comments = serializers.ManyPrimaryKeyRelatedField() + class Meta: + model = BlogPost + fields = ("title", "comments") + + return BlogPostSerializer(data=data)