diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1483cb568..5df8ced07 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -6,95 +6,78 @@ from rest_framework import status from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings -from rest_framwork.views import APIView +from rest_framework.views import APIView -class LazyViewCreator(object): +def api_view(http_method_names): """ - This class is responsible for dynamically creating an APIView subclass that - will wrap a function-based view. Instances of this class are created - by the function-based view decorators (below), and each decorator is - responsible for setting attributes on the instance that will eventually be - copied onto the final class-based view. The CBV gets created lazily the first - time it's needed, and then cached for future use. - - This is done so that the ordering of stacked decorators is irrelevant. + Decorator that converts a function-based view into an APIView subclass. + Takes a list of allowed methods for the view as an argument. """ - def __init__(self, wrapped_view): + def decorator(func): - self.wrapped_view = wrapped_view + class WrappedAPIView(APIView): + pass - # Each item in this dictionary will be copied onto the final - # class-based view that gets created when this object is called - self.final_view_attrs = { - 'http_method_names': APIView.http_method_names, - 'renderer_classes': APIView.renderer_classes, - 'parser_classes': APIView.parser_classes, - 'authentication_classes': APIView.authentication_classes, - 'throttle_classes': APIView.throttle_classes, - 'permission_classes': APIView.permission_classes, - } - self._cached_view = None + WrappedAPIView.http_method_names = [method.lower() for method in http_method_names] - def handler(self, *args, **kwargs): - return self.wrapped_view(*args, **kwargs) + def handler(self, *args, **kwargs): + return func(*args, **kwargs) - @property - def view(self): - """ - Accessor for the dynamically created class-based view. This will - be created if necessary and cached for next time. - """ + for method in http_method_names: + setattr(WrappedAPIView, method.lower(), handler) - if self._cached_view is None: + WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes', + APIView.renderer_classes) - class WrappedAPIView(APIView): - pass + WrappedAPIView.parser_classes = getattr(func, 'parser_classes', + APIView.parser_classes) - for attr, value in self.final_view_attrs.items(): - setattr(WrappedAPIView, attr, value) + WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', + APIView.authentication_classes) - # Attach the wrapped view function for each of the - # allowed HTTP methods - for method in WrappedAPIView.http_method_names: - setattr(WrappedAPIView, method.lower(), self.handler) + WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes', + APIView.throttle_classes) - self._cached_view = WrappedAPIView.as_view() + WrappedAPIView.permission_classes = getattr(func, 'permission_classes', + APIView.permission_classes) - return self._cached_view - - def __call__(self, *args, **kwargs): - """ - This is the actual code that gets run per-request - """ - return self.view(*args, **kwargs) - - @staticmethod - def maybe_create(func_or_instance): - """ - If the argument is already an instance of LazyViewCreator, - just return it. Otherwise, create a new one. - """ - if isinstance(func_or_instance, LazyViewCreator): - return func_or_instance - return LazyViewCreator(func_or_instance) - - -def _create_attribute_setting_decorator(attribute, filter=lambda item: item): - def decorator(value): - def inner(func): - wrapper = LazyViewCreator.maybe_create(func) - wrapper.final_view_attrs[attribute] = filter(value) - return wrapper - return inner + return WrappedAPIView.as_view() return decorator -api_view = _create_attribute_setting_decorator('http_method_names', filter=lambda methods: [method.lower() for method in methods]) -renderer_classes = _create_attribute_setting_decorator('renderer_classes') -parser_classes = _create_attribute_setting_decorator('parser_classes') -authentication_classes = _create_attribute_setting_decorator('authentication_classes') -throttle_classes = _create_attribute_setting_decorator('throttle_classes') -permission_classes = _create_attribute_setting_decorator('permission_classes') +def renderer_classes(renderer_classes): + def decorator(func): + setattr(func, 'renderer_classes', renderer_classes) + return func + return decorator + + +def parser_classes(parser_classes): + def decorator(func): + setattr(func, 'parser_classes', parser_classes) + return func + return decorator + + +def authentication_classes(authentication_classes): + def decorator(func): + setattr(func, 'authentication_classes', authentication_classes) + return func + return decorator + + +def throttle_classes(throttle_classes): + def decorator(func): + setattr(func, 'throttle_classes', throttle_classes) + return func + return decorator + + +def permission_classes(permission_classes): + def decorator(func): + setattr(func, 'permission_classes', permission_classes) + return func + return decorator diff --git a/djangorestframework/tests/decorators.py b/rest_framework/tests/decorators.py similarity index 53% rename from djangorestframework/tests/decorators.py rename to rest_framework/tests/decorators.py index 0d3be8f3c..d41f05d43 100644 --- a/djangorestframework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,19 +1,19 @@ from django.test import TestCase -from djangorestframework.response import Response -from djangorestframework.compat import RequestFactory -from djangorestframework.renderers import JSONRenderer -from djangorestframework.parsers import JSONParser -from djangorestframework.authentication import BasicAuthentication -from djangorestframework.throttling import SimpleRateThottle -from djangorestframework.permissions import IsAuthenticated -from djangorestframework.decorators import ( +from rest_framework.response import Response +from rest_framework.compat import RequestFactory +from rest_framework.renderers import JSONRenderer +from rest_framework.parsers import JSONParser +from rest_framework.authentication import BasicAuthentication +from rest_framework.throttling import SimpleRateThottle +from rest_framework.permissions import IsAuthenticated +from rest_framework.views import APIView +from rest_framework.decorators import ( api_view, renderer_classes, parser_classes, authentication_classes, throttle_classes, permission_classes, - LazyViewCreator ) @@ -22,13 +22,18 @@ class DecoratorTestCase(TestCase): def setUp(self): self.factory = RequestFactory() + def _finalize_response(self, request, response, *args, **kwargs): + print "HAI" + response.request = request + return APIView.finalize_response(self, request, response, *args, **kwargs) + def test_wrap_view(self): @api_view(['GET']) def view(request): return Response({}) - self.assertTrue(isinstance(view, LazyViewCreator)) + self.assertTrue(isinstance(view.cls_instance, APIView)) def test_calling_method(self): @@ -46,57 +51,57 @@ class DecoratorTestCase(TestCase): def test_renderer_classes(self): - @renderer_classes([JSONRenderer]) @api_view(['GET']) + @renderer_classes([JSONRenderer]) def view(request): return Response({}) request = self.factory.get('/') response = view(request) - self.assertEqual(response.renderer_classes, [JSONRenderer]) + self.assertTrue(isinstance(response.renderer, JSONRenderer)) def test_parser_classes(self): - @parser_classes([JSONParser]) @api_view(['GET']) + @parser_classes([JSONParser]) def view(request): + self.assertEqual(request.parser_classes, [JSONParser]) return Response({}) request = self.factory.get('/') - response = view(request) - self.assertEqual(response.request.parser_classes, [JSONParser]) + view(request) def test_authentication_classes(self): - @authentication_classes([BasicAuthentication]) @api_view(['GET']) + @authentication_classes([BasicAuthentication]) def view(request): + self.assertEqual(request.authentication_classes, [BasicAuthentication]) return Response({}) request = self.factory.get('/') - response = view(request) - self.assertEqual(response.request.authentication_classes, [BasicAuthentication]) + view(request) -# Doesn't look like these bits are working quite yet + def test_permission_classes(self): + + @api_view(['GET']) + @permission_classes([IsAuthenticated]) + def view(request): + self.assertEqual(request.permission_classes, [IsAuthenticated]) + return Response({}) + + request = self.factory.get('/') + view(request) + +# Doesn't look like this bits are working quite yet # def test_throttle_classes(self): -# + +# @api_view(['GET']) # @throttle_classes([SimpleRateThottle]) -# @api_view(['GET']) -# def view(request): -# return Response({}) -# -# request = self.factory.get('/') -# response = view(request) -# self.assertEqual(response.request.throttle, [SimpleRateThottle]) - -# def test_permission_classes(self): - -# @permission_classes([IsAuthenticated]) -# @api_view(['GET']) # def view(request): +# self.assertEqual(request.throttle_classes, [SimpleRateThottle]) # return Response({}) # request = self.factory.get('/') -# response = view(request) -# self.assertEqual(response.request.permission_classes, [IsAuthenticated]) +# view(request)