diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 9836c9660..fc81489e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,5 +1,4 @@ from functools import wraps -from django.http import Http404 from django.utils.decorators import available_attrs from django.core.exceptions import PermissionDenied from rest_framework import exceptions @@ -7,47 +6,78 @@ from rest_framework import status from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings +from rest_framework.views import APIView -def api_view(allowed_methods): +def api_view(http_method_names): + """ - Decorator for function based views. - - @api_view(['GET', 'POST']) - def my_view(request): - # request will be an instance of `Request` - # `Response` objects will have .request set automatically - # APIException instances will be handled + Decorator that converts a function-based view into an APIView subclass. + Takes a list of allowed methods for the view as an argument. """ - allowed_methods = [method.upper() for method in allowed_methods] def decorator(func): - @wraps(func, assigned=available_attrs(func)) - def inner(request, *args, **kwargs): - try: - request = Request(request) + class WrappedAPIView(APIView): + pass - if request.method not in allowed_methods: - raise exceptions.MethodNotAllowed(request.method) + WrappedAPIView.http_method_names = [method.lower() for method in http_method_names] - response = func(request, *args, **kwargs) + def handler(self, *args, **kwargs): + return func(*args, **kwargs) - if isinstance(response, Response): - response.request = request - if api_settings.FORMAT_SUFFIX_KWARG: - response.format = kwargs.get(api_settings.FORMAT_SUFFIX_KWARG, None) - return response + for method in http_method_names: + setattr(WrappedAPIView, method.lower(), handler) - except exceptions.APIException as exc: - return Response({'detail': exc.detail}, status=exc.status_code) + WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes', + APIView.renderer_classes) - except Http404 as exc: - return Response({'detail': 'Not found'}, - status=status.HTTP_404_NOT_FOUND) + WrappedAPIView.parser_classes = getattr(func, 'parser_classes', + APIView.parser_classes) - except PermissionDenied as exc: - return Response({'detail': 'Permission denied'}, - status=status.HTTP_403_FORBIDDEN) - return inner + WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', + APIView.authentication_classes) + + WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes', + APIView.throttle_classes) + + WrappedAPIView.permission_classes = getattr(func, 'permission_classes', + APIView.permission_classes) + + return WrappedAPIView.as_view() + return decorator + + +def renderer_classes(renderer_classes): + def decorator(func): + func.renderer_classes = renderer_classes + return func + return decorator + + +def parser_classes(parser_classes): + def decorator(func): + func.parser_classes = parser_classes + return func + return decorator + + +def authentication_classes(authentication_classes): + def decorator(func): + func.authentication_classes = authentication_classes + return func + return decorator + + +def throttle_classes(throttle_classes): + def decorator(func): + func.throttle_classes = throttle_classes + return func + return decorator + + +def permission_classes(permission_classes): + def decorator(func): + func.permission_classes = permission_classes + return func return decorator diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py new file mode 100644 index 000000000..d41f05d43 --- /dev/null +++ b/rest_framework/tests/decorators.py @@ -0,0 +1,107 @@ +from django.test import TestCase +from rest_framework.response import Response +from rest_framework.compat import RequestFactory +from rest_framework.renderers import JSONRenderer +from rest_framework.parsers import JSONParser +from rest_framework.authentication import BasicAuthentication +from rest_framework.throttling import SimpleRateThottle +from rest_framework.permissions import IsAuthenticated +from rest_framework.views import APIView +from rest_framework.decorators import ( + api_view, + renderer_classes, + parser_classes, + authentication_classes, + throttle_classes, + permission_classes, +) + + +class DecoratorTestCase(TestCase): + + def setUp(self): + self.factory = RequestFactory() + + def _finalize_response(self, request, response, *args, **kwargs): + print "HAI" + response.request = request + return APIView.finalize_response(self, request, response, *args, **kwargs) + + def test_wrap_view(self): + + @api_view(['GET']) + def view(request): + return Response({}) + + self.assertTrue(isinstance(view.cls_instance, APIView)) + + def test_calling_method(self): + + @api_view(['GET']) + def view(request): + return Response({}) + + request = self.factory.get('/') + response = view(request) + self.assertEqual(response.status_code, 200) + + request = self.factory.post('/') + response = view(request) + self.assertEqual(response.status_code, 405) + + def test_renderer_classes(self): + + @api_view(['GET']) + @renderer_classes([JSONRenderer]) + def view(request): + return Response({}) + + request = self.factory.get('/') + response = view(request) + self.assertTrue(isinstance(response.renderer, JSONRenderer)) + + def test_parser_classes(self): + + @api_view(['GET']) + @parser_classes([JSONParser]) + def view(request): + self.assertEqual(request.parser_classes, [JSONParser]) + return Response({}) + + request = self.factory.get('/') + view(request) + + def test_authentication_classes(self): + + @api_view(['GET']) + @authentication_classes([BasicAuthentication]) + def view(request): + self.assertEqual(request.authentication_classes, [BasicAuthentication]) + return Response({}) + + request = self.factory.get('/') + view(request) + + def test_permission_classes(self): + + @api_view(['GET']) + @permission_classes([IsAuthenticated]) + def view(request): + self.assertEqual(request.permission_classes, [IsAuthenticated]) + return Response({}) + + request = self.factory.get('/') + view(request) + +# Doesn't look like this bits are working quite yet + +# def test_throttle_classes(self): + +# @api_view(['GET']) +# @throttle_classes([SimpleRateThottle]) +# def view(request): +# self.assertEqual(request.throttle_classes, [SimpleRateThottle]) +# return Response({}) + +# request = self.factory.get('/') +# view(request)