diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 043ec3f7a..93e0751b7 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -70,6 +70,9 @@ def api_view(http_method_names=None): WrappedAPIView.permission_classes = getattr(func, 'permission_classes', APIView.permission_classes) + WrappedAPIView.content_negotiation_class = getattr(func, 'content_negotiation_class', + APIView.content_negotiation_class) + WrappedAPIView.metadata_class = getattr(func, 'metadata_class', APIView.metadata_class) @@ -119,6 +122,13 @@ def permission_classes(permission_classes): return decorator +def content_negotiation_class(content_negotiation_class): + def decorator(func): + func.content_negotiation_class = content_negotiation_class + return func + return decorator + + def metadata_class(metadata_class): def decorator(func): func.metadata_class = metadata_class diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 25b4cbe44..0c070bc10 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -6,10 +6,11 @@ from django.test import TestCase from rest_framework import status from rest_framework.authentication import BasicAuthentication from rest_framework.decorators import ( - action, api_view, authentication_classes, metadata_class, parser_classes, - permission_classes, renderer_classes, schema, throttle_classes, - versioning_class + action, api_view, authentication_classes, content_negotiation_class, + metadata_class, parser_classes, permission_classes, renderer_classes, + schema, throttle_classes, versioning_class ) +from rest_framework.negotiation import BaseContentNegotiation from rest_framework.parsers import JSONParser from rest_framework.permissions import IsAuthenticated from rest_framework.renderers import JSONRenderer @@ -174,6 +175,21 @@ class DecoratorTestCase(TestCase): assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.data == {'detail': 'Method "OPTIONS" not allowed.'} + def test_content_negotiation(self): + class CustomContentNegotiation(BaseContentNegotiation): + def select_renderer(self, request, renderers, format_suffix): + assert request.META['HTTP_ACCEPT'] == 'custom/type' + return (renderers[0], renderers[0].media_type) + + @api_view(["GET"]) + @content_negotiation_class(CustomContentNegotiation) + def view(request): + return Response({}) + + request = self.factory.get('/', HTTP_ACCEPT='custom/type') + response = view(request) + assert response.status_code == status.HTTP_200_OK + def test_schema(self): """ Checks CustomSchema class is set on view