diff --git a/docs/api-guide/authentication.md b/docs/api-guide/authentication.md index 8409a83c8..ea6e1cb34 100644 --- a/docs/api-guide/authentication.md +++ b/docs/api-guide/authentication.md @@ -34,6 +34,17 @@ The authentication schemes are always defined as a list of classes. REST framew If no class authenticates, `request.user` will be set to an instance of `django.contrib.auth.models.AnonymousUser`, and `request.auth` will be set to `None`. The value of `request.user` and `request.auth` for unauthenticated requests can be modified using the `UNAUTHENTICATED_USER` and `UNAUTHENTICATED_TOKEN` settings. +### MultiUserModelAuthentication +The `MultiUserModelAuthentication` class supports authentication for multiple user models. + +To use this authentication mechanism, add it to your `DEFAULT_AUTHENTICATION_CLASSES` in `settings.py`: + +```python +REST_FRAMEWORK = { + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'rest_framework.authentication.MultiUserModelAuthentication', + ], +} ## Setting the authentication scheme diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 3f3bd2227..9269cfd21 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -230,3 +230,63 @@ class RemoteUserAuthentication(BaseAuthentication): user = authenticate(request=request, remote_user=request.META.get(self.header)) if user and user.is_active: return (user, None) +class MultiUserModelAuthentication(BaseAuthentication): + """ + Custom authentication to support multiple user models. + """ + + def authenticate(self, request): + """ + Authenticate the request for multiple user models. + Returns a tuple of (user, None) or raises an exception if authentication fails. + """ + auth = get_authorization_header(request).split() + + if not auth or auth[0].lower() != b'basic': + return None + + if len(auth) == 1: + msg = _('Invalid basic header. No credentials provided.') + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = _('Invalid basic header. Credentials string should not contain spaces.') + raise exceptions.AuthenticationFailed(msg) + + try: + try: + auth_decoded = base64.b64decode(auth[1]).decode('utf-8') + except UnicodeDecodeError: + auth_decoded = base64.b64decode(auth[1]).decode('latin-1') + + userid, password = auth_decoded.split(':', 1) + except (TypeError, ValueError, UnicodeDecodeError, binascii.Error): + msg = _('Invalid basic header. Credentials not correctly base64 encoded.') + raise exceptions.AuthenticationFailed(msg) + + return self.authenticate_credentials(userid, password, request) + + def authenticate_credentials(self, userid, password, request=None): + """ + Authenticate credentials for multiple user models. + """ + # List of user models to authenticate against + user_models = ['users.User', 'admins.AdminUser'] + + for model_name in user_models: + try: + UserModel = get_user_model() # Replace with custom logic for each model + credentials = {UserModel.USERNAME_FIELD: userid, 'password': password} + user = authenticate(request=request, **credentials) + + if user: + if not user.is_active: + raise exceptions.AuthenticationFailed(_('User inactive or deleted.')) + return (user, None) + except Exception as e: + # Continue to next user model if the current one fails + continue + + raise exceptions.AuthenticationFailed(_('Invalid username/password for all user models.')) + + def authenticate_header(self, request): + return 'Basic realm="api"' diff --git a/rest_framework/settings.py b/rest_framework/settings.py index b0d7bacec..8d9c90678 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -37,7 +37,8 @@ DEFAULTS = { 'rest_framework.parsers.FormParser', 'rest_framework.parsers.MultiPartParser' ], - 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'rest_framework.authentication.MultiUserModelAuthentication', 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.BasicAuthentication' ], diff --git a/tests/authentication/test_authentication.py b/tests/authentication/test_authentication.py index 2f05ce7d1..950cfe0ef 100644 --- a/tests/authentication/test_authentication.py +++ b/tests/authentication/test_authentication.py @@ -6,7 +6,10 @@ from django.contrib.auth.models import User from django.http import HttpResponse from django.test import TestCase, override_settings from django.urls import include, path - +from rest_framework.test import APIClient +from django.test import TestCase +from users.models import User +from admins.models import AdminUser from rest_framework import ( HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status ) @@ -597,3 +600,20 @@ class RemoteUserAuthenticationUnitTests(TestCase): response = self.client.post('/remote-user/', REMOTE_USER=self.username) self.assertEqual(response.status_code, status.HTTP_200_OK) +class MultiUserModelAuthenticationTest(TestCase): + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user(username='user', password='userpass') + self.admin = AdminUser.objects.create_user(username='admin', password='adminpass') + + def test_user_authentication(self): + response = self.client.post('/api/token/', {'username': 'user', 'password': 'userpass'}) + self.assertEqual(response.status_code, 200) + + def test_admin_authentication(self): + response = self.client.post('/api/token/', {'username': 'admin', 'password': 'adminpass'}) + self.assertEqual(response.status_code, 200) + + def test_invalid_authentication(self): + response = self.client.post('/api/token/', {'username': 'invalid', 'password': 'invalid'}) + self.assertEqual(response.status_code, 401)