diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 6b2c91db7..b8733b7dd 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,9 +1,16 @@ +import base64 +import unittest + +import django from django.contrib.auth.models import User from django.http import HttpRequest from django.test import override_settings from django.urls import path -from rest_framework.authentication import TokenAuthentication +from rest_framework import HTTP_HEADER_ENCODING +from rest_framework.authentication import ( + BasicAuthentication, TokenAuthentication +) from rest_framework.authtoken.models import Token from rest_framework.request import is_form_media_type from rest_framework.response import Response @@ -18,6 +25,7 @@ class PostView(APIView): urlpatterns = [ path('auth', APIView.as_view(authentication_classes=(TokenAuthentication,))), + path('basic', APIView.as_view(authentication_classes=(BasicAuthentication,))), path('post', PostView.as_view()), ] @@ -74,3 +82,42 @@ class TestMiddleware(APITestCase): response = self.client.post('/post', {'foo': 'bar'}, format='json') assert response.status_code == 200 + + +@unittest.skipUnless(django.VERSION >= (5, 1), 'Only for Django 5.1+') +@override_settings( + ROOT_URLCONF='tests.test_middleware', + MIDDLEWARE=( + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.auth.middleware.LoginRequiredMiddleware', + ), +) +class TestLoginRequiredMiddleware(APITestCase): + def test_redirects_to_login_when_user_is_anonymous(self): + response = self.client.post('/post') + self.assertRedirects(response, '/accounts/login/?next=/post', fetch_redirect_response=False) + + def test_process_request_when_session_authenticated(self): + user = User.objects.create_user('john', 'john@example.com', 'password') + self.client.force_login(user) + + response = self.client.post('/post') + assert response.status_code == 200 + + def test_compat_with_token_auth(self): + user = User.objects.create_user('john', 'john@example.com', 'password') + key = 'abcd1234' + Token.objects.create(key=key, user=user) + + response = self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key) + assert response.status_code == 200 + + def test_compat_with_basic_auth(self): + user = User.objects.create_user('john', 'john@example.com', 'password') + credentials = ('%s:%s' % (user.username, user.password)) + base64_credentials = base64.b64encode( + credentials.encode(HTTP_HEADER_ENCODING) + ).decode(HTTP_HEADER_ENCODING) + response = self.client.get('/basic', HTTP_AUTHORIZATION='Basic %s' % base64_credentials) + assert response.status_code == 200