mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-16 19:41:06 +03:00
Merge pull request #261 from j4mie/improved-view-decorators
First stab at new function-based view decorators
This commit is contained in:
commit
622e001e0b
|
@ -1,5 +1,4 @@
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from django.http import Http404
|
|
||||||
from django.utils.decorators import available_attrs
|
from django.utils.decorators import available_attrs
|
||||||
from django.core.exceptions import PermissionDenied
|
from django.core.exceptions import PermissionDenied
|
||||||
from rest_framework import exceptions
|
from rest_framework import exceptions
|
||||||
|
@ -7,47 +6,78 @@ from rest_framework import status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
|
||||||
def api_view(allowed_methods):
|
def api_view(http_method_names):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Decorator for function based views.
|
Decorator that converts a function-based view into an APIView subclass.
|
||||||
|
Takes a list of allowed methods for the view as an argument.
|
||||||
@api_view(['GET', 'POST'])
|
|
||||||
def my_view(request):
|
|
||||||
# request will be an instance of `Request`
|
|
||||||
# `Response` objects will have .request set automatically
|
|
||||||
# APIException instances will be handled
|
|
||||||
"""
|
"""
|
||||||
allowed_methods = [method.upper() for method in allowed_methods]
|
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func, assigned=available_attrs(func))
|
|
||||||
def inner(request, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
|
|
||||||
request = Request(request)
|
class WrappedAPIView(APIView):
|
||||||
|
pass
|
||||||
|
|
||||||
if request.method not in allowed_methods:
|
WrappedAPIView.http_method_names = [method.lower() for method in http_method_names]
|
||||||
raise exceptions.MethodNotAllowed(request.method)
|
|
||||||
|
|
||||||
response = func(request, *args, **kwargs)
|
def handler(self, *args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
if isinstance(response, Response):
|
for method in http_method_names:
|
||||||
response.request = request
|
setattr(WrappedAPIView, method.lower(), handler)
|
||||||
if api_settings.FORMAT_SUFFIX_KWARG:
|
|
||||||
response.format = kwargs.get(api_settings.FORMAT_SUFFIX_KWARG, None)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except exceptions.APIException as exc:
|
WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
|
||||||
return Response({'detail': exc.detail}, status=exc.status_code)
|
APIView.renderer_classes)
|
||||||
|
|
||||||
except Http404 as exc:
|
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
|
||||||
return Response({'detail': 'Not found'},
|
APIView.parser_classes)
|
||||||
status=status.HTTP_404_NOT_FOUND)
|
|
||||||
|
|
||||||
except PermissionDenied as exc:
|
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
|
||||||
return Response({'detail': 'Permission denied'},
|
APIView.authentication_classes)
|
||||||
status=status.HTTP_403_FORBIDDEN)
|
|
||||||
return inner
|
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
|
||||||
|
APIView.throttle_classes)
|
||||||
|
|
||||||
|
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
|
||||||
|
APIView.permission_classes)
|
||||||
|
|
||||||
|
return WrappedAPIView.as_view()
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def renderer_classes(renderer_classes):
|
||||||
|
def decorator(func):
|
||||||
|
func.renderer_classes = renderer_classes
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def parser_classes(parser_classes):
|
||||||
|
def decorator(func):
|
||||||
|
func.parser_classes = parser_classes
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def authentication_classes(authentication_classes):
|
||||||
|
def decorator(func):
|
||||||
|
func.authentication_classes = authentication_classes
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def throttle_classes(throttle_classes):
|
||||||
|
def decorator(func):
|
||||||
|
func.throttle_classes = throttle_classes
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def permission_classes(permission_classes):
|
||||||
|
def decorator(func):
|
||||||
|
func.permission_classes = permission_classes
|
||||||
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
107
rest_framework/tests/decorators.py
Normal file
107
rest_framework/tests/decorators.py
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.compat import RequestFactory
|
||||||
|
from rest_framework.renderers import JSONRenderer
|
||||||
|
from rest_framework.parsers import JSONParser
|
||||||
|
from rest_framework.authentication import BasicAuthentication
|
||||||
|
from rest_framework.throttling import SimpleRateThottle
|
||||||
|
from rest_framework.permissions import IsAuthenticated
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
from rest_framework.decorators import (
|
||||||
|
api_view,
|
||||||
|
renderer_classes,
|
||||||
|
parser_classes,
|
||||||
|
authentication_classes,
|
||||||
|
throttle_classes,
|
||||||
|
permission_classes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DecoratorTestCase(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.factory = RequestFactory()
|
||||||
|
|
||||||
|
def _finalize_response(self, request, response, *args, **kwargs):
|
||||||
|
print "HAI"
|
||||||
|
response.request = request
|
||||||
|
return APIView.finalize_response(self, request, response, *args, **kwargs)
|
||||||
|
|
||||||
|
def test_wrap_view(self):
|
||||||
|
|
||||||
|
@api_view(['GET'])
|
||||||
|
def view(request):
|
||||||
|
return Response({})
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(view.cls_instance, APIView))
|
||||||
|
|
||||||
|
def test_calling_method(self):
|
||||||
|
|
||||||
|
@api_view(['GET'])
|
||||||
|
def view(request):
|
||||||
|
return Response({})
|
||||||
|
|
||||||
|
request = self.factory.get('/')
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
request = self.factory.post('/')
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, 405)
|
||||||
|
|
||||||
|
def test_renderer_classes(self):
|
||||||
|
|
||||||
|
@api_view(['GET'])
|
||||||
|
@renderer_classes([JSONRenderer])
|
||||||
|
def view(request):
|
||||||
|
return Response({})
|
||||||
|
|
||||||
|
request = self.factory.get('/')
|
||||||
|
response = view(request)
|
||||||
|
self.assertTrue(isinstance(response.renderer, JSONRenderer))
|
||||||
|
|
||||||
|
def test_parser_classes(self):
|
||||||
|
|
||||||
|
@api_view(['GET'])
|
||||||
|
@parser_classes([JSONParser])
|
||||||
|
def view(request):
|
||||||
|
self.assertEqual(request.parser_classes, [JSONParser])
|
||||||
|
return Response({})
|
||||||
|
|
||||||
|
request = self.factory.get('/')
|
||||||
|
view(request)
|
||||||
|
|
||||||
|
def test_authentication_classes(self):
|
||||||
|
|
||||||
|
@api_view(['GET'])
|
||||||
|
@authentication_classes([BasicAuthentication])
|
||||||
|
def view(request):
|
||||||
|
self.assertEqual(request.authentication_classes, [BasicAuthentication])
|
||||||
|
return Response({})
|
||||||
|
|
||||||
|
request = self.factory.get('/')
|
||||||
|
view(request)
|
||||||
|
|
||||||
|
def test_permission_classes(self):
|
||||||
|
|
||||||
|
@api_view(['GET'])
|
||||||
|
@permission_classes([IsAuthenticated])
|
||||||
|
def view(request):
|
||||||
|
self.assertEqual(request.permission_classes, [IsAuthenticated])
|
||||||
|
return Response({})
|
||||||
|
|
||||||
|
request = self.factory.get('/')
|
||||||
|
view(request)
|
||||||
|
|
||||||
|
# Doesn't look like this bits are working quite yet
|
||||||
|
|
||||||
|
# def test_throttle_classes(self):
|
||||||
|
|
||||||
|
# @api_view(['GET'])
|
||||||
|
# @throttle_classes([SimpleRateThottle])
|
||||||
|
# def view(request):
|
||||||
|
# self.assertEqual(request.throttle_classes, [SimpleRateThottle])
|
||||||
|
# return Response({})
|
||||||
|
|
||||||
|
# request = self.factory.get('/')
|
||||||
|
# view(request)
|
Loading…
Reference in New Issue
Block a user