# django-rest-framework/tests/test_decorators.py
# (158 lines, 4.7 KiB, Python)

from __future__ import unicode_literals
2012-09-14 19:07:07 +04:00
from django.test import TestCase
2012-09-27 00:47:19 +04:00
from rest_framework import status
2013-06-28 20:17:39 +04:00
from rest_framework.authentication import BasicAuthentication
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.renderers import JSONRenderer
2013-06-28 20:17:39 +04:00
from rest_framework.test import APIRequestFactory
2012-09-27 00:47:19 +04:00
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
from rest_framework.decorators import (
2012-09-14 19:07:07 +04:00
api_view,
renderer_classes,
parser_classes,
authentication_classes,
throttle_classes,
permission_classes,
)
class DecoratorTestCase(TestCase):
    """
    Tests for the function-based-view decorators in ``rest_framework.decorators``
    (``api_view``, ``renderer_classes``, ``parser_classes``,
    ``authentication_classes``, ``throttle_classes``, ``permission_classes``).
    """

    def setUp(self):
        self.factory = APIRequestFactory()

    def _finalize_response(self, request, response, *args, **kwargs):
        # Attach the request to the response, then delegate to
        # APIView.finalize_response.
        # NOTE(review): no test in this class calls this helper; it looks like
        # a leftover and could probably be removed -- confirm before deleting.
        response.request = request
        return APIView.finalize_response(self, request, response, *args, **kwargs)

    def test_api_view_incorrect(self):
        """
        If @api_view is not applied correctly, we should raise an assertion.
        """
        # Bare @api_view: used without parentheses or a list of methods.
        @api_view
        def view(request):
            return Response()

        request = self.factory.get('/')
        self.assertRaises(AssertionError, view, request)

    def test_api_view_incorrect_arguments(self):
        """
        If @api_view is missing arguments, we should raise an assertion.
        """
        # A single string is passed instead of a list of method names.
        with self.assertRaises(AssertionError):
            @api_view('GET')
            def view(request):
                return Response()

    def test_calling_method(self):
        """
        A view declared for GET only accepts GET and rejects POST with 405.
        """
        @api_view(['GET'])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        request = self.factory.post('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    def test_calling_put_method(self):
        """
        A view declared for GET and PUT accepts PUT and rejects POST with 405.
        """
        @api_view(['GET', 'PUT'])
        def view(request):
            return Response({})

        request = self.factory.put('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        request = self.factory.post('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    def test_calling_patch_method(self):
        """
        A view declared for GET and PATCH accepts PATCH and rejects POST with 405.
        """
        @api_view(['GET', 'PATCH'])
        def view(request):
            return Response({})

        request = self.factory.patch('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        request = self.factory.post('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    def test_renderer_classes(self):
        """
        @renderer_classes determines the renderer accepted for the response.
        """
        @api_view(['GET'])
        @renderer_classes([JSONRenderer])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertIsInstance(response.accepted_renderer, JSONRenderer)

    def test_parser_classes(self):
        """
        @parser_classes determines the parsers attached to the request.
        """
        @api_view(['GET'])
        @parser_classes([JSONParser])
        def view(request):
            self.assertEqual(len(request.parsers), 1)
            self.assertIsInstance(request.parsers[0], JSONParser)
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        # The checks above only run if the view body is reached; asserting the
        # status keeps the test from passing silently when it is not.
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_authentication_classes(self):
        """
        @authentication_classes determines the authenticators attached to the
        request.
        """
        @api_view(['GET'])
        @authentication_classes([BasicAuthentication])
        def view(request):
            self.assertEqual(len(request.authenticators), 1)
            self.assertIsInstance(request.authenticators[0], BasicAuthentication)
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        # The checks above only run if the view body is reached; asserting the
        # status keeps the test from passing silently when it is not.
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_permission_classes(self):
        """
        A request made without credentials to an IsAuthenticated view is
        expected to be rejected with 403.
        """
        @api_view(['GET'])
        @permission_classes([IsAuthenticated])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_throttle_classes(self):
        """
        With a one-request-per-day throttle, the first request succeeds and a
        repeat request is rejected with 429.
        """
        class OncePerDayUserThrottle(UserRateThrottle):
            rate = '1/day'

        @api_view(['GET'])
        @throttle_classes([OncePerDayUserThrottle])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Same client again: the single allowed request has been used up.
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)