diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index c69756a43..972efd3cd 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -68,6 +68,9 @@ def api_view(http_method_names): WrappedAPIView.permission_classes = getattr(func, 'permission_classes', APIView.permission_classes) + WrappedAPIView.content_negotiation_class = getattr(func, 'content_negotiation_class', + APIView.content_negotiation_class) + return WrappedAPIView.as_view() return decorator @@ -107,6 +110,13 @@ def permission_classes(permission_classes): return decorator +def content_negotiation_class(content_negotiation_class): + def decorator(func): + func.content_negotiation_class = content_negotiation_class + return func + return decorator + + def link(**kwargs): """ Used to mark a method on a ViewSet that should be routed for GET requests. diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py index 195f0ba3e..eb98382b6 100644 --- a/rest_framework/tests/test_decorators.py +++ b/rest_framework/tests/test_decorators.py @@ -16,6 +16,7 @@ from rest_framework.decorators import ( authentication_classes, throttle_classes, permission_classes, + content_negotiation_class, ) @@ -155,3 +156,28 @@ class DecoratorTestCase(TestCase): response = view(request) self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_content_negotiation_class(self): + from rest_framework.renderers import BaseRenderer + from rest_framework.negotiation import BaseContentNegotiation + + class TestRenderer(BaseRenderer): + media_type = 'test' + format = 'test' + + class TestContentNegotiation(BaseContentNegotiation): + def select_parser(self, request, parsers): + return parsers[0] + + def select_renderer(self, request, renderers, format_suffix): + return (TestRenderer(), TestRenderer.media_type) + + @api_view(['GET']) + @content_negotiation_class(TestContentNegotiation) + def view(request): + self.assertIsInstance(request.accepted_renderer, TestRenderer) + return Response({}) + + request = self.factory.get('/') + view(request) +