From 9dc26825bf37a024d607b34959d720e3ee305774 Mon Sep 17 00:00:00 2001 From: Partho Debnath Date: Mon, 31 Mar 2025 02:23:56 +0600 Subject: [PATCH] Update obtain_auth_token to authenticate using 'USERNAME_FIELD' and 'password' instead of 'username' and 'password' for both the built-in and custom User models --- docs/api-guide/authentication.md | 2 +- rest_framework/authtoken/serializers.py | 50 +++++++++++++++++-------- rest_framework/authtoken/views.py | 10 +++-- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/docs/api-guide/authentication.md b/docs/api-guide/authentication.md index 84e58bf4b..83c200d8c 100644 --- a/docs/api-guide/authentication.md +++ b/docs/api-guide/authentication.md @@ -220,7 +220,7 @@ The `obtain_auth_token` view will return a JSON response when valid `username` a { 'token' : '9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b' } -Note that the default `obtain_auth_token` view explicitly uses JSON requests and responses, rather than using default renderer and parser classes in your settings. +Note that the default `obtain_auth_token` view explicitly uses JSON requests and responses, rather than using default renderer and parser classes in your settings. If you use a `custom User` model as `AUTH_USER_MODEL` in `settings.py`, authentication will use the `USERNAME_FIELD` and `password` defined in your custom model. By default, there are no permissions or throttling applied to the `obtain_auth_token` view. If you do wish to apply throttling you'll need to override the view class, and include them using the `throttle_classes` attribute. diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 63e64d668..100a0f895 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -1,32 +1,50 @@ -from django.contrib.auth import authenticate +from django.contrib.auth import authenticate, get_user_model from django.utils.translation import gettext_lazy as _ from rest_framework import serializers +USER_MODEL = get_user_model() + class AuthTokenSerializer(serializers.Serializer): - username = serializers.CharField( - label=_("Username"), - write_only=True - ) - password = serializers.CharField( - label=_("Password"), - style={'input_type': 'password'}, - trim_whitespace=False, - write_only=True - ) + def __init__(self, instance=None, data=None, **kwargs): + super().__init__(instance, data=data, **kwargs) + self.identifier_fiend_name = USER_MODEL.USERNAME_FIELD + if USER_MODEL.get_email_field_name() == self.identifier_fiend_name: + self.fields[self.identifier_fiend_name] = serializers.EmailField( + label=_(self.identifier_fiend_name.title()), + write_only=True + ) + else: + self.fields[self.identifier_fiend_name] = serializers.CharField( + label=_(self.identifier_fiend_name.title()), + write_only=True + ) + self.fields["password"] = serializers.CharField( + label=_("Password"), + style={'input_type': 'password'}, + trim_whitespace=False, + write_only=True + ) + token = serializers.CharField( label=_("Token"), read_only=True ) def validate(self, attrs): - username = attrs.get('username') + identifier_value = attrs.get(self.identifier_fiend_name) password = attrs.get('password') - if username and password: - user = authenticate(request=self.context.get('request'), - username=username, password=password) + if identifier_value and password: + credentials = { + self.identifier_fiend_name: identifier_value, + "password": password, + } + user = authenticate( + request=self.context.get('request'), + **credentials, + ) # The authenticate call simply returns None for is_active=False # users. (Assuming the default ModelBackend authentication @@ -35,7 +53,7 @@ class AuthTokenSerializer(serializers.Serializer): msg = _('Unable to log in with provided credentials.') raise serializers.ValidationError(msg, code='authorization') else: - msg = _('Must include "username" and "password".') + msg = _(f'Must include "{self.identifier_fiend_name}" and "password".') raise serializers.ValidationError(msg, code='authorization') attrs['user'] = user diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 50f9acbd9..aef3ed0c6 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -1,3 +1,5 @@ +from django.contrib.auth import get_user_model + from rest_framework import parsers, renderers from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer @@ -16,15 +18,17 @@ class ObtainAuthToken(APIView): serializer_class = AuthTokenSerializer if coreapi_schema.is_enabled(): + USER_MODEL = get_user_model() + identifier_field_name = USER_MODEL.USERNAME_FIELD schema = ManualSchema( fields=[ coreapi.Field( - name="username", + name=identifier_field_name, required=True, location='form', schema=coreschema.String( - title="Username", - description="Valid username for authentication", + title=identifier_field_name.title(), + description=f"Valid {identifier_field_name} for authentication", ), ), coreapi.Field(