From 3f0c04f88749e6006aa16a626bf574be217b6634 Mon Sep 17 00:00:00 2001 From: Jascha Geerds Date: Tue, 25 Jul 2017 13:35:26 +0200 Subject: [PATCH] Call Django's authenticate function with the request object MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As of Django 1.11 the `authenticate` function accepts a request as an additional argument. This commit fixes compatibility between newer Django versions and custom authentication backends which already depend on the request object. See also: [Django 1.11 release](https://docs.djangoproject.com/en/1.11/releases/1.11/) ``` authenticate() now passes a request argument to the authenticate() method of authentication backends. Support for methods that don’t accept request as the first positional argument will be removed in Django 2.1. ``` --- rest_framework/authentication.py | 12 +++++++----- rest_framework/authtoken/serializers.py | 5 +++-- rest_framework/authtoken/views.py | 3 ++- rest_framework/compat.py | 8 ++++++++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index cb9608a3c..c2a722f0c 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -6,12 +6,13 @@ from __future__ import unicode_literals import base64 import binascii -from django.contrib.auth import authenticate, get_user_model +from django.contrib.auth import get_user_model from django.middleware.csrf import CsrfViewMiddleware from django.utils.six import text_type from django.utils.translation import ugettext_lazy as _ from rest_framework import HTTP_HEADER_ENCODING, exceptions +from rest_framework.compat import authenticate def get_authorization_header(request): @@ -83,17 +84,18 @@ class BasicAuthentication(BaseAuthentication): raise exceptions.AuthenticationFailed(msg) userid, password = auth_parts[0], auth_parts[2] - return self.authenticate_credentials(userid, password) + return self.authenticate_credentials(userid, password, request) - def authenticate_credentials(self, userid, password): + def authenticate_credentials(self, userid, password, request=None): """ - Authenticate the userid and password against username and password. + Authenticate the userid and password against username and password + with optional request for context. """ credentials = { get_user_model().USERNAME_FIELD: userid, 'password': password } - user = authenticate(**credentials) + user = authenticate(request=request, **credentials) if user is None: raise exceptions.AuthenticationFailed(_('Invalid username/password.')) diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 7590fdb75..301b6a0cb 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -1,7 +1,7 @@ -from django.contrib.auth import authenticate from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers +from rest_framework.compat import authenticate class AuthTokenSerializer(serializers.Serializer): @@ -17,7 +17,8 @@ class AuthTokenSerializer(serializers.Serializer): password = attrs.get('password') if username and password: - user = authenticate(username=username, password=password) + user = authenticate(request=self.context.get('request'), + username=username, password=password) if user: # From Django 1.10 onwards the `authenticate` call simply diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 0c6596a7e..6254d2f7f 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -13,7 +13,8 @@ class ObtainAuthToken(APIView): serializer_class = AuthTokenSerializer def post(self, request, *args, **kwargs): - serializer = self.serializer_class(data=request.data) + serializer = self.serializer_class(data=request.data, + context={'request': request}) serializer.is_valid(raise_exception=True) user = serializer.validated_data['user'] token, created = Token.objects.get_or_create(user=user) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 168bccf83..0666af322 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -354,3 +354,11 @@ def include(module, namespace=None, app_name=None): return include(module, namespace, app_name) else: return include((module, app_name), namespace) + + +def authenticate(request=None, **credentials): + from django.contrib.auth import authenticate + if django.VERSION < (1, 11): + return authenticate(**credentials) + else: + return authenticate(request=request, **credentials)