diff --git a/tests/test_middleware.py b/tests/test_middleware.py index e1c1de6f6..09c24d523 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -13,6 +13,7 @@ from rest_framework.authentication import ( BasicAuthentication, TokenAuthentication ) from rest_framework.authtoken.models import Token +from rest_framework.decorators import api_view from rest_framework.request import is_form_media_type from rest_framework.response import Response from rest_framework.test import APITestCase @@ -24,24 +25,34 @@ class PostAPIView(APIView): return Response(data=request.data, status=200) -class GetAPIView(APIView): - def get(self, request): - return Response(data={"status": "ok"}, status=200) +with override_settings( + REST_FRAMEWORK={ + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.IsAuthenticated', + ], + } +): + class GetAPIView(APIView): + def get(self, request): + return Response(data={"status": "ok"}, status=200) + class GetView(View): + def get(self, request): + return HttpResponse("OK", status=200) -class GetView(View): - def get(self, request): + @api_view(['GET']) + def get_func_view(request): return HttpResponse("OK", status=200) - -urlpatterns = [ - path('api/auth', APIView.as_view(authentication_classes=(TokenAuthentication,))), - path('api/post', PostAPIView.as_view()), - path('api/get', GetAPIView.as_view()), - path('api/basic', GetAPIView.as_view(authentication_classes=(BasicAuthentication,))), - path('api/token', GetAPIView.as_view(authentication_classes=(TokenAuthentication,))), - path('get', GetView.as_view()), -] + urlpatterns = [ + path('api/auth', APIView.as_view(authentication_classes=(TokenAuthentication,))), + path('api/post', PostAPIView.as_view()), + path('api/get', GetAPIView.as_view()), + path('api/get-func', get_func_view), + path('api/basic', GetAPIView.as_view(authentication_classes=(BasicAuthentication,))), + path('api/token', GetAPIView.as_view(authentication_classes=(TokenAuthentication,))), + path('get', GetView.as_view()), + ] class RequestUserMiddleware: @@ -134,6 +145,13 @@ class TestLoginRequiredMiddleware(APITestCase): response = self.client.get('/api/get') assert response.status_code == status.HTTP_200_OK + def test_allows_request_when_authenticated_function_view(self): + user = User.objects.create_user('john', 'john@example.com', 'password') + self.client.force_login(user) + + response = self.client.get('/api/get-func') + assert response.status_code == status.HTTP_200_OK + def test_allows_request_when_token_authenticated(self): user = User.objects.create_user('john', 'john@example.com', 'password') key = 'abcd1234'