Merge pull request #171 from jazzband/refresh

Add support for Refresh Token Cookie
This commit is contained in:
Michael 2020-11-17 16:36:49 -06:00 committed by GitHub
commit 94e3805535
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 128 additions and 38 deletions

1
.gitignore vendored
View File

@ -37,6 +37,7 @@ pip-delete-this-directory.txt
# Unit test / coverage reports # Unit test / coverage reports
htmlcov/ htmlcov/
coverage_html/
.tox/ .tox/
.coverage .coverage
.coverage.* .coverage.*

View File

@ -1,4 +1,5 @@
from dj_rest_auth.serializers import JWTSerializer as DefaultJWTSerializer from dj_rest_auth.serializers import JWTSerializer as DefaultJWTSerializer
from dj_rest_auth.serializers import JWTSerializerWithExpiration as DefaultJWTSerializerWithExpiration
from dj_rest_auth.serializers import LoginSerializer as DefaultLoginSerializer from dj_rest_auth.serializers import LoginSerializer as DefaultLoginSerializer
from dj_rest_auth.serializers import \ from dj_rest_auth.serializers import \
PasswordChangeSerializer as DefaultPasswordChangeSerializer PasswordChangeSerializer as DefaultPasswordChangeSerializer
@ -21,6 +22,8 @@ TokenSerializer = import_callable(serializers.get('TOKEN_SERIALIZER', DefaultTok
JWTSerializer = import_callable(serializers.get('JWT_SERIALIZER', DefaultJWTSerializer)) JWTSerializer = import_callable(serializers.get('JWT_SERIALIZER', DefaultJWTSerializer))
JWTSerializerWithExpiration = import_callable(serializers.get('JWT_SERIALIZER_WITH_EXPIRATION', DefaultJWTSerializerWithExpiration))
UserDetailsSerializer = import_callable(serializers.get('USER_DETAILS_SERIALIZER', DefaultUserDetailsSerializer)) UserDetailsSerializer = import_callable(serializers.get('USER_DETAILS_SERIALIZER', DefaultUserDetailsSerializer))
LoginSerializer = import_callable(serializers.get('LOGIN_SERIALIZER', DefaultLoginSerializer)) LoginSerializer = import_callable(serializers.get('LOGIN_SERIALIZER', DefaultLoginSerializer))
@ -38,3 +41,4 @@ PasswordChangeSerializer = import_callable(
) )
JWT_AUTH_COOKIE = getattr(settings, 'JWT_AUTH_COOKIE', None) JWT_AUTH_COOKIE = getattr(settings, 'JWT_AUTH_COOKIE', None)
JWT_AUTH_REFRESH_COOKIE = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)

View File

@ -178,6 +178,14 @@ class JWTSerializer(serializers.Serializer):
return user_data return user_data
class JWTSerializerWithExpiration(JWTSerializer):
"""
Serializer for JWT authentication with expiration times.
"""
access_token_expiration = serializers.DateTimeField()
refresh_token_expiration = serializers.DateTimeField()
class PasswordResetSerializer(serializers.Serializer): class PasswordResetSerializer(serializers.Serializer):
""" """
Serializer for requesting a password reset e-mail. Serializer for requesting a password reset e-mail.

View File

@ -11,6 +11,7 @@ from rest_framework.test import APIRequestFactory
from dj_rest_auth.registration.app_settings import register_permission_classes from dj_rest_auth.registration.app_settings import register_permission_classes
from dj_rest_auth.registration.views import RegisterView from dj_rest_auth.registration.views import RegisterView
from .mixins import CustomPermissionClass, TestsMixin from .mixins import CustomPermissionClass, TestsMixin
try: try:
@ -18,8 +19,9 @@ try:
except ImportError: except ImportError:
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from jwt import decode as decode_jwt from jwt import decode as decode_jwt
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
class TESTTokenObtainPairSerializer(TokenObtainPairSerializer): class TESTTokenObtainPairSerializer(TokenObtainPairSerializer):
@classmethod @classmethod
@ -71,8 +73,8 @@ class APIBasicTests(TestsMixin, TestCase):
def _generate_uid_and_token(self, user): def _generate_uid_and_token(self, user):
result = {} result = {}
from django.utils.encoding import force_bytes
from django.contrib.auth.tokens import default_token_generator from django.contrib.auth.tokens import default_token_generator
from django.utils.encoding import force_bytes
from django.utils.http import urlsafe_base64_encode from django.utils.http import urlsafe_base64_encode
result['uid'] = urlsafe_base64_encode(force_bytes(user.pk)) result['uid'] = urlsafe_base64_encode(force_bytes(user.pk))
@ -559,6 +561,20 @@ class APIBasicTests(TestsMixin, TestCase):
resp = self.post(self.logout_url, status=200) resp = self.post(self.logout_url, status=200)
self.assertEqual('', resp.cookies.get('jwt-auth').value) self.assertEqual('', resp.cookies.get('jwt-auth').value)
@override_settings(JWT_AUTH_REFRESH_COOKIE='jwt-auth-refresh')
@override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE='jwt-auth')
def test_logout_jwt_deletes_cookie_refresh(self):
payload = {
"username": self.USERNAME,
"password": self.PASS
}
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
self.post(self.login_url, data=payload, status_code=200)
resp = self.post(self.logout_url, status=200)
self.assertEqual('', resp.cookies.get('jwt-auth').value)
self.assertEqual('', resp.cookies.get('jwt-auth-refresh').value)
@override_settings(REST_USE_JWT=True) @override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE='jwt-auth') @override_settings(JWT_AUTH_COOKIE='jwt-auth')
@override_settings(REST_FRAMEWORK=dict( @override_settings(REST_FRAMEWORK=dict(
@ -604,21 +620,15 @@ class APIBasicTests(TestsMixin, TestCase):
resp = self.post(self.login_url, data=payload, status_code=200) resp = self.post(self.login_url, data=payload, status_code=200)
token = resp.data['refresh_token'] token = resp.data['refresh_token']
# test refresh token not included in request data # test refresh token not included in request data
resp = self.post(self.logout_url, status=200) self.post(self.logout_url, status_code=401)
self.assertEqual(resp.status_code, 401)
# test token is invalid or expired # test token is invalid or expired
resp = self.post(self.logout_url, status=200, data={'refresh': '1'}) self.post(self.logout_url, status_code=401, data={'refresh': '1'})
self.assertEqual(resp.status_code, 401)
# test successful logout # test successful logout
resp = self.post(self.logout_url, status=200, data={'refresh': token}) self.post(self.logout_url, status_code=200, data={'refresh': token})
self.assertEqual(resp.status_code, 200)
# test token is blacklisted # test token is blacklisted
resp = self.post(self.logout_url, status=200, data={'refresh': token}) self.post(self.logout_url, status_code=401, data={'refresh': token})
self.assertEqual(resp.status_code, 401)
# test other TokenError, AttributeError, TypeError (invalid format) # test other TokenError, AttributeError, TypeError (invalid format)
resp = self.post(self.logout_url, status=200, data=json.dumps({'refresh': token})) self.post(self.logout_url, status_code=500, data=json.dumps({'refresh': token}))
self.assertEqual(resp.status_code, 500)
@override_settings(REST_USE_JWT=True) @override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE=None) @override_settings(JWT_AUTH_COOKIE=None)
@ -868,3 +878,38 @@ class APIBasicTests(TestsMixin, TestCase):
resp = client.post('/protected-view/', csrfparam) resp = client.post('/protected-view/', csrfparam)
self.assertEquals(resp.status_code, 200) self.assertEquals(resp.status_code, 200)
@override_settings(JWT_AUTH_RETURN_EXPIRATION=True)
@override_settings(REST_USE_JWT=True)
@override_settings(ACCOUNT_LOGOUT_ON_GET=True)
def test_return_expiration(self):
payload = {
"username": self.USERNAME,
"password": self.PASS
}
# create user
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
resp = self.post(self.login_url, data=payload, status_code=200)
self.assertIn('access_token_expiration', resp.data.keys())
self.assertIn('refresh_token_expiration', resp.data.keys())
@override_settings(JWT_AUTH_RETURN_EXPIRATION=True)
@override_settings(REST_USE_JWT=True)
@override_settings(JWT_AUTH_COOKIE='xxx')
@override_settings(ACCOUNT_LOGOUT_ON_GET=True)
@override_settings(JWT_AUTH_REFRESH_COOKIE='refresh-xxx')
@override_settings(JWT_AUTH_REFRESH_COOKIE_PATH='/foo/bar')
def test_refresh_cookie_name(self):
payload = {
"username": self.USERNAME,
"password": self.PASS
}
# create user
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
resp = self.post(self.login_url, data=payload, status_code=200)
self.assertIn('xxx', resp.cookies.keys())
self.assertIn('refresh-xxx', resp.cookies.keys())
self.assertEqual(resp.cookies.get('refresh-xxx').get('path'), '/foo/bar')

View File

@ -2,6 +2,7 @@ from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth import login as django_login from django.contrib.auth import login as django_login
from django.contrib.auth import logout as django_logout from django.contrib.auth import logout as django_logout
from django.utils import timezone
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -12,7 +13,7 @@ from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from .app_settings import (JWTSerializer, LoginSerializer, from .app_settings import (JWTSerializer, JWTSerializerWithExpiration, LoginSerializer,
PasswordChangeSerializer, PasswordChangeSerializer,
PasswordResetConfirmSerializer, PasswordResetConfirmSerializer,
PasswordResetSerializer, TokenSerializer, PasswordResetSerializer, TokenSerializer,
@ -51,7 +52,12 @@ class LoginView(GenericAPIView):
def get_response_serializer(self): def get_response_serializer(self):
if getattr(settings, 'REST_USE_JWT', False): if getattr(settings, 'REST_USE_JWT', False):
if getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False):
response_serializer = JWTSerializerWithExpiration
else:
response_serializer = JWTSerializer response_serializer = JWTSerializer
else: else:
response_serializer = TokenSerializer response_serializer = TokenSerializer
return response_serializer return response_serializer
@ -71,12 +77,24 @@ class LoginView(GenericAPIView):
def get_response(self): def get_response(self):
serializer_class = self.get_response_serializer() serializer_class = self.get_response_serializer()
access_token_expiration = None
refresh_token_expiration = None
if getattr(settings, 'REST_USE_JWT', False): if getattr(settings, 'REST_USE_JWT', False):
from rest_framework_simplejwt.settings import api_settings as jwt_settings
access_token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
refresh_token_expiration = (timezone.now() + jwt_settings.REFRESH_TOKEN_LIFETIME)
return_expiration_times = getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False)
data = { data = {
'user': self.user, 'user': self.user,
'access_token': self.access_token, 'access_token': self.access_token,
'refresh_token': self.refresh_token 'refresh_token': self.refresh_token
} }
if return_expiration_times:
data['access_token_expiration'] = access_token_expiration
data['refresh_token_expiration'] = refresh_token_expiration
serializer = serializer_class(instance=data, serializer = serializer_class(instance=data,
context=self.get_serializer_context()) context=self.get_serializer_context())
else: else:
@ -86,21 +104,32 @@ class LoginView(GenericAPIView):
response = Response(serializer.data, status=status.HTTP_200_OK) response = Response(serializer.data, status=status.HTTP_200_OK)
if getattr(settings, 'REST_USE_JWT', False): if getattr(settings, 'REST_USE_JWT', False):
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
refresh_cookie_path = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE_PATH', '/')
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False) cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True) cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax') cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')
from rest_framework_simplejwt.settings import api_settings as jwt_settings
if cookie_name: if cookie_name:
from datetime import datetime
expiration = (datetime.utcnow() + jwt_settings.ACCESS_TOKEN_LIFETIME)
response.set_cookie( response.set_cookie(
cookie_name, cookie_name,
self.access_token, self.access_token,
expires=expiration, expires=access_token_expiration,
secure=cookie_secure, secure=cookie_secure,
httponly=cookie_httponly, httponly=cookie_httponly,
samesite=cookie_samesite samesite=cookie_samesite
) )
if refresh_cookie_name:
response.set_cookie(
refresh_cookie_name,
self.refresh_token,
expires=refresh_token_expiration,
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite,
path=refresh_cookie_path
)
return response return response
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
@ -142,8 +171,10 @@ class LogoutView(APIView):
if getattr(settings, 'REST_SESSION_LOGIN', True): if getattr(settings, 'REST_SESSION_LOGIN', True):
django_logout(request) django_logout(request)
response = Response({"detail": _("Successfully logged out.")}, response = Response(
status=status.HTTP_200_OK) {"detail": _("Successfully logged out.")},
status=status.HTTP_200_OK
)
if getattr(settings, 'REST_USE_JWT', False): if getattr(settings, 'REST_USE_JWT', False):
# NOTE: this import occurs here rather than at the top level # NOTE: this import occurs here rather than at the top level
@ -155,37 +186,38 @@ class LogoutView(APIView):
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
if cookie_name: if cookie_name:
response.delete_cookie(cookie_name) response.delete_cookie(cookie_name)
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
if refresh_cookie_name:
response.delete_cookie(refresh_cookie_name)
elif 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS: if 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS:
# add refresh token to blacklist # add refresh token to blacklist
try: try:
token = RefreshToken(request.data['refresh']) token = RefreshToken(request.data['refresh'])
token.blacklist() token.blacklist()
except KeyError: except KeyError:
response = Response({"detail": _("Refresh token was not included in request data.")}, response.data = {"detail": _("Refresh token was not included in request data.")}
status=status.HTTP_401_UNAUTHORIZED) response.status_code =status.HTTP_401_UNAUTHORIZED
except (TokenError, AttributeError, TypeError) as error: except (TokenError, AttributeError, TypeError) as error:
if hasattr(error, 'args'): if hasattr(error, 'args'):
if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args: if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args:
response = Response({"detail": _(error.args[0])}, response.data = {"detail": _(error.args[0])}
status=status.HTTP_401_UNAUTHORIZED) response.status_code = status.HTTP_401_UNAUTHORIZED
else:
response.data = {"detail": _("An error has occurred.")}
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
else: else:
response = Response({"detail": _("An error has occurred.")}, response.data = {"detail": _("An error has occurred.")}
status=status.HTTP_500_INTERNAL_SERVER_ERROR) response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
else: else:
response = Response({"detail": _("An error has occurred.")}, message = _(
status=status.HTTP_500_INTERNAL_SERVER_ERROR) "Neither cookies or blacklist are enabled, so the token "
"has not been deleted server side. Please make sure the token is deleted client side."
else: )
response = Response({ response.data = {"detail": message}
"detail": _("Neither cookies or blacklist are enabled, so the token has not been deleted server " response.status_code = status.HTTP_200_OK
"side. Please make sure the token is deleted client side."
)}, status=status.HTTP_200_OK)
return response return response