diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index 4d9da4896..cf5cfb12d 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -1,9 +1,11 @@ from __future__ import unicode_literals from django.db import models -from django.test import TestCase +from django.test import TestCase, RequestFactory from rest_framework import serializers +from rest_framework import generics +import json -from .models import OneToOneTarget +from .models import OneToOneTarget, BlogPost, BlogPostComment class OneToOneSource(models.Model): @@ -324,3 +326,90 @@ class ReverseNestedOneToManyTests(TestCase): ] self.assertEqual(serializer.data, expected) + + +class NestedWritableSerializerUpdateTests(TestCase): + + def test_post_nested_collection(self): + class BlogPostCommentSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPostComment + read_only_fields = ('id', 'blog_post') + + class BlogPostSerializer(serializers.ModelSerializer): + comments = BlogPostCommentSerializer( + many=True, required=False, + allow_add_remove=True, + source='blogpostcomment_set', + ) + class Meta: + model = BlogPost + fields = ('id', 'title', 'comments') + + class BlogList(generics.RetrieveUpdateAPIView): + serializer_class = BlogPostSerializer + + def get_queryset(self): + return BlogPost.objects.prefetch_related( + 'blogpostcomment_set', + ).all() + + BlogPost.objects.create(pk=1, title="No text") + + view = BlogList.as_view() + factory = RequestFactory() + request_data = { + "comments": [ + { + "id": None, + "text": "First comment", + "blog_post": None, + }, + { + "id": None, + "text": "Second comment", + "blog_post": None, + }, + ], + "title": "Some text", + } + request_data = json.dumps(request_data) + + request = factory.put( + '/blogs/1/', data=request_data, content_type='application/json', + ) + + response = view(request, pk=1) + + expected_comments = BlogPostComment.objects.values().all() + self.assertEqual(len(expected_comments), 2) + self.assertTrue(all(x["blog_post_id"] == 1 for x in expected_comments)) + + response.render() + response_data = json.loads(response.content) + self.assertEqual(response.status_code, 200) + self.assertEqual(response_data["id"], 1) + self.assertEqual(response_data["title"], "Some text") + + response_comments = response_data["comments"] + self.assertEqual(len(response_comments), 2) + + self.assertEqual( + response_comments[0]["id"], + expected_comments[0]["id"] + ) + + self.assertEqual( + response_comments[1]["id"], + expected_comments[1]["id"] + ) + + self.assertEqual( + response_comments[0]["blog_post"], + expected_comments[0]["blog_post_id"] + ) + self.assertEqual( + response_comments[1]["blog_post"], + expected_comments[1]["blog_post_id"] + ) +