diff --git a/rest_framework/middleware.py b/rest_framework/middleware.py new file mode 100644 index 000000000..1385c7694 --- /dev/null +++ b/rest_framework/middleware.py @@ -0,0 +1,26 @@ +from django.core.exceptions import ImproperlyConfigured + +from rest_framework.settings import api_settings +from rest_framework.views import APIView + +try: + from django.contrib.auth.middleware import \ + LoginRequiredMiddleware as DjangoLoginRequiredMiddleware +except ImportError: + DjangoLoginRequiredMiddleware = None + + +if DjangoLoginRequiredMiddleware: + class LoginRequiredMiddleware(DjangoLoginRequiredMiddleware): + def process_view(self, request, view_func, view_args, view_kwargs): + if ( + hasattr(view_func, "cls") + and issubclass(view_func.cls, APIView) + ): + if 'rest_framework.permissions.AllowAny' in api_settings.DEFAULT_PERMISSION_CLASSES: + raise ImproperlyConfigured( + "You cannot use 'rest_framework.permissions.AllowAny' in `DEFAULT_PERMISSION_CLASSES` " + "with `LoginRequiredMiddleware`." + ) + return None + return super().process_view(request, view_func, view_args, view_kwargs) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index b8733b7dd..e1c1de6f6 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -3,11 +3,12 @@ import unittest import django from django.contrib.auth.models import User -from django.http import HttpRequest +from django.http import HttpRequest, HttpResponse from django.test import override_settings from django.urls import path +from django.views import View -from rest_framework import HTTP_HEADER_ENCODING +from rest_framework import HTTP_HEADER_ENCODING, status from rest_framework.authentication import ( BasicAuthentication, TokenAuthentication ) @@ -18,15 +19,28 @@ from rest_framework.test import APITestCase from rest_framework.views import APIView -class PostView(APIView): +class PostAPIView(APIView): def post(self, request): return Response(data=request.data, status=200) +class GetAPIView(APIView): + def get(self, request): + return Response(data={"status": "ok"}, status=200) + + +class GetView(View): + def get(self, request): + return HttpResponse("OK", status=200) + + urlpatterns = [ - path('auth', APIView.as_view(authentication_classes=(TokenAuthentication,))), - path('basic', APIView.as_view(authentication_classes=(BasicAuthentication,))), - path('post', PostView.as_view()), + path('api/auth', APIView.as_view(authentication_classes=(TokenAuthentication,))), + path('api/post', PostAPIView.as_view()), + path('api/get', GetAPIView.as_view()), + path('api/basic', GetAPIView.as_view(authentication_classes=(BasicAuthentication,))), + path('api/token', GetAPIView.as_view(authentication_classes=(TokenAuthentication,))), + path('get', GetView.as_view()), ] @@ -73,14 +87,14 @@ class TestMiddleware(APITestCase): key = 'abcd1234' Token.objects.create(key=key, user=user) - self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key) + self.client.get('/api/auth', HTTP_AUTHORIZATION='Token %s' % key) @override_settings(MIDDLEWARE=('tests.test_middleware.RequestPOSTMiddleware',)) def test_middleware_can_access_request_post_when_processing_response(self): - response = self.client.post('/post', {'foo': 'bar'}) + response = self.client.post('/api/post', {'foo': 'bar'}) assert response.status_code == 200 - response = self.client.post('/post', {'foo': 'bar'}, format='json') + response = self.client.post('/api/post', {'foo': 'bar'}, format='json') assert response.status_code == 200 @@ -88,36 +102,56 @@ class TestMiddleware(APITestCase): @override_settings( ROOT_URLCONF='tests.test_middleware', MIDDLEWARE=( + # Needed for AuthenticationMiddleware 'django.contrib.sessions.middleware.SessionMiddleware', + # Needed for LoginRequiredMiddleware 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.auth.middleware.LoginRequiredMiddleware', + 'rest_framework.middleware.LoginRequiredMiddleware', ), + REST_FRAMEWORK={ + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.IsAuthenticated', + ], + } ) class TestLoginRequiredMiddleware(APITestCase): - def test_redirects_to_login_when_user_is_anonymous(self): - response = self.client.post('/post') - self.assertRedirects(response, '/accounts/login/?next=/post', fetch_redirect_response=False) + def test_unauthorized_when_user_is_anonymous_on_public_view(self): + response = self.client.get('/api/get') + assert response.status_code == status.HTTP_401_UNAUTHORIZED - def test_process_request_when_session_authenticated(self): + def test_unauthorized_when_user_is_anonymous_on_basic_auth_view(self): + response = self.client.get('/api/basic') + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_unauthorized_when_user_is_anonymous_on_token_auth_view(self): + response = self.client.get('/api/token') + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_allows_request_when_session_authenticated(self): user = User.objects.create_user('john', 'john@example.com', 'password') self.client.force_login(user) - response = self.client.post('/post') - assert response.status_code == 200 + response = self.client.get('/api/get') + assert response.status_code == status.HTTP_200_OK - def test_compat_with_token_auth(self): + def test_allows_request_when_token_authenticated(self): user = User.objects.create_user('john', 'john@example.com', 'password') key = 'abcd1234' Token.objects.create(key=key, user=user) - response = self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key) - assert response.status_code == 200 + response = self.client.get('/api/token', headers={"Authorization": f'Token {key}'}) + assert response.status_code == status.HTTP_200_OK - def test_compat_with_basic_auth(self): + def test_allows_request_when_basic_authenticated(self): user = User.objects.create_user('john', 'john@example.com', 'password') credentials = ('%s:%s' % (user.username, user.password)) base64_credentials = base64.b64encode( credentials.encode(HTTP_HEADER_ENCODING) ).decode(HTTP_HEADER_ENCODING) - response = self.client.get('/basic', HTTP_AUTHORIZATION='Basic %s' % base64_credentials) - assert response.status_code == 200 + auth = f'Basic {base64_credentials}' + response = self.client.get('/api/basic', headers={"Authorization": auth}) + assert response.status_code == status.HTTP_200_OK + + def test_works_as_base_middleware_for_django_view(self): + response = self.client.get('/get') + self.assertRedirects(response, '/accounts/login/?next=/get', fetch_redirect_response=False)