From edf30cff01eb4a5e33f807b2933c4b35632e96e5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Jun 2016 17:01:03 +0100 Subject: [PATCH] More robust form rendering in the browsable API --- rest_framework/renderers.py | 50 +++++++++++--------- tests/browsable_api/test_form_rendering.py | 53 ++++++++++++++++++++++ 2 files changed, 81 insertions(+), 22 deletions(-) create mode 100644 tests/browsable_api/test_form_rendering.py diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 264f7ac3b..7ca680e74 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -472,31 +472,37 @@ class BrowsableAPIRenderer(BaseRenderer): return if existing_serializer is not None: - serializer = existing_serializer - else: - if has_serializer: - if method in ('PUT', 'PATCH'): - serializer = view.get_serializer(instance=instance, **kwargs) - else: - serializer = view.get_serializer(**kwargs) + try: + return self.render_form_for_serializer(existing_serializer) + except TypeError: + pass + + if has_serializer: + if method in ('PUT', 'PATCH'): + serializer = view.get_serializer(instance=instance, **kwargs) else: - # at this point we must have a serializer_class - if method in ('PUT', 'PATCH'): - serializer = self._get_serializer(view.serializer_class, view, - request, instance=instance, **kwargs) - else: - serializer = self._get_serializer(view.serializer_class, view, - request, **kwargs) + serializer = view.get_serializer(**kwargs) + else: + # at this point we must have a serializer_class + if method in ('PUT', 'PATCH'): + serializer = self._get_serializer(view.serializer_class, view, + request, instance=instance, **kwargs) + else: + serializer = self._get_serializer(view.serializer_class, view, + request, **kwargs) - if hasattr(serializer, 'initial_data'): - serializer.is_valid() + return self.render_form_for_serializer(serializer) - form_renderer = self.form_renderer_class() - return form_renderer.render( - serializer.data, - self.accepted_media_type, - {'style': {'template_pack': 'rest_framework/horizontal'}} - ) + def render_form_for_serializer(self, serializer): + if hasattr(serializer, 'initial_data'): + serializer.is_valid() + + form_renderer = self.form_renderer_class() + return form_renderer.render( + serializer.data, + self.accepted_media_type, + {'style': {'template_pack': 'rest_framework/horizontal'}} + ) def get_raw_data_form(self, data, view, method, request): """ diff --git a/tests/browsable_api/test_form_rendering.py b/tests/browsable_api/test_form_rendering.py new file mode 100644 index 000000000..5a31ae0dd --- /dev/null +++ b/tests/browsable_api/test_form_rendering.py @@ -0,0 +1,53 @@ +from django.test import TestCase + +from rest_framework import generics, renderers, serializers, status +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from tests.models import BasicModel + +factory = APIRequestFactory() + + +class BasicSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + +class ManyPostView(generics.GenericAPIView): + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer + renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + + def post(self, request, *args, **kwargs): + serializer = self.get_serializer(self.get_queryset(), many=True) + return Response(serializer.data, status.HTTP_200_OK) + + +class TestManyPostView(TestCase): + def setUp(self): + """ + Create 3 BasicModel instances. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = ManyPostView.as_view() + + def test_post_many_post_view(self): + """ + POST request to a view that returns a list of objects should + still successfully return the browsable API with a rendered form. + + Regression test for https://github.com/tomchristie/django-rest-framework/pull/3164 + """ + data = {} + request = factory.post('/', data, format='json') + with self.assertNumQueries(1): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 3)