mirror of
https://github.com/Tivix/django-rest-auth.git
synced 2025-07-15 10:22:18 +03:00
Merge pull request #171 from jazzband/refresh
Add support for Refresh Token Cookie
This commit is contained in:
commit
94e3805535
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -37,6 +37,7 @@ pip-delete-this-directory.txt
|
||||||
|
|
||||||
# Unit test / coverage reports
|
# Unit test / coverage reports
|
||||||
htmlcov/
|
htmlcov/
|
||||||
|
coverage_html/
|
||||||
.tox/
|
.tox/
|
||||||
.coverage
|
.coverage
|
||||||
.coverage.*
|
.coverage.*
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from dj_rest_auth.serializers import JWTSerializer as DefaultJWTSerializer
|
from dj_rest_auth.serializers import JWTSerializer as DefaultJWTSerializer
|
||||||
|
from dj_rest_auth.serializers import JWTSerializerWithExpiration as DefaultJWTSerializerWithExpiration
|
||||||
from dj_rest_auth.serializers import LoginSerializer as DefaultLoginSerializer
|
from dj_rest_auth.serializers import LoginSerializer as DefaultLoginSerializer
|
||||||
from dj_rest_auth.serializers import \
|
from dj_rest_auth.serializers import \
|
||||||
PasswordChangeSerializer as DefaultPasswordChangeSerializer
|
PasswordChangeSerializer as DefaultPasswordChangeSerializer
|
||||||
|
@ -21,6 +22,8 @@ TokenSerializer = import_callable(serializers.get('TOKEN_SERIALIZER', DefaultTok
|
||||||
|
|
||||||
JWTSerializer = import_callable(serializers.get('JWT_SERIALIZER', DefaultJWTSerializer))
|
JWTSerializer = import_callable(serializers.get('JWT_SERIALIZER', DefaultJWTSerializer))
|
||||||
|
|
||||||
|
JWTSerializerWithExpiration = import_callable(serializers.get('JWT_SERIALIZER_WITH_EXPIRATION', DefaultJWTSerializerWithExpiration))
|
||||||
|
|
||||||
UserDetailsSerializer = import_callable(serializers.get('USER_DETAILS_SERIALIZER', DefaultUserDetailsSerializer))
|
UserDetailsSerializer = import_callable(serializers.get('USER_DETAILS_SERIALIZER', DefaultUserDetailsSerializer))
|
||||||
|
|
||||||
LoginSerializer = import_callable(serializers.get('LOGIN_SERIALIZER', DefaultLoginSerializer))
|
LoginSerializer = import_callable(serializers.get('LOGIN_SERIALIZER', DefaultLoginSerializer))
|
||||||
|
@ -38,3 +41,4 @@ PasswordChangeSerializer = import_callable(
|
||||||
)
|
)
|
||||||
|
|
||||||
JWT_AUTH_COOKIE = getattr(settings, 'JWT_AUTH_COOKIE', None)
|
JWT_AUTH_COOKIE = getattr(settings, 'JWT_AUTH_COOKIE', None)
|
||||||
|
JWT_AUTH_REFRESH_COOKIE = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
|
||||||
|
|
|
@ -178,6 +178,14 @@ class JWTSerializer(serializers.Serializer):
|
||||||
return user_data
|
return user_data
|
||||||
|
|
||||||
|
|
||||||
|
class JWTSerializerWithExpiration(JWTSerializer):
|
||||||
|
"""
|
||||||
|
Serializer for JWT authentication with expiration times.
|
||||||
|
"""
|
||||||
|
access_token_expiration = serializers.DateTimeField()
|
||||||
|
refresh_token_expiration = serializers.DateTimeField()
|
||||||
|
|
||||||
|
|
||||||
class PasswordResetSerializer(serializers.Serializer):
|
class PasswordResetSerializer(serializers.Serializer):
|
||||||
"""
|
"""
|
||||||
Serializer for requesting a password reset e-mail.
|
Serializer for requesting a password reset e-mail.
|
||||||
|
|
|
@ -11,6 +11,7 @@ from rest_framework.test import APIRequestFactory
|
||||||
|
|
||||||
from dj_rest_auth.registration.app_settings import register_permission_classes
|
from dj_rest_auth.registration.app_settings import register_permission_classes
|
||||||
from dj_rest_auth.registration.views import RegisterView
|
from dj_rest_auth.registration.views import RegisterView
|
||||||
|
|
||||||
from .mixins import CustomPermissionClass, TestsMixin
|
from .mixins import CustomPermissionClass, TestsMixin
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -18,8 +19,9 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from django.core.urlresolvers import reverse
|
from django.core.urlresolvers import reverse
|
||||||
|
|
||||||
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
|
|
||||||
from jwt import decode as decode_jwt
|
from jwt import decode as decode_jwt
|
||||||
|
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
|
||||||
|
|
||||||
|
|
||||||
class TESTTokenObtainPairSerializer(TokenObtainPairSerializer):
|
class TESTTokenObtainPairSerializer(TokenObtainPairSerializer):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -71,8 +73,8 @@ class APIBasicTests(TestsMixin, TestCase):
|
||||||
|
|
||||||
def _generate_uid_and_token(self, user):
|
def _generate_uid_and_token(self, user):
|
||||||
result = {}
|
result = {}
|
||||||
from django.utils.encoding import force_bytes
|
|
||||||
from django.contrib.auth.tokens import default_token_generator
|
from django.contrib.auth.tokens import default_token_generator
|
||||||
|
from django.utils.encoding import force_bytes
|
||||||
from django.utils.http import urlsafe_base64_encode
|
from django.utils.http import urlsafe_base64_encode
|
||||||
|
|
||||||
result['uid'] = urlsafe_base64_encode(force_bytes(user.pk))
|
result['uid'] = urlsafe_base64_encode(force_bytes(user.pk))
|
||||||
|
@ -559,6 +561,20 @@ class APIBasicTests(TestsMixin, TestCase):
|
||||||
resp = self.post(self.logout_url, status=200)
|
resp = self.post(self.logout_url, status=200)
|
||||||
self.assertEqual('', resp.cookies.get('jwt-auth').value)
|
self.assertEqual('', resp.cookies.get('jwt-auth').value)
|
||||||
|
|
||||||
|
@override_settings(JWT_AUTH_REFRESH_COOKIE='jwt-auth-refresh')
|
||||||
|
@override_settings(REST_USE_JWT=True)
|
||||||
|
@override_settings(JWT_AUTH_COOKIE='jwt-auth')
|
||||||
|
def test_logout_jwt_deletes_cookie_refresh(self):
|
||||||
|
payload = {
|
||||||
|
"username": self.USERNAME,
|
||||||
|
"password": self.PASS
|
||||||
|
}
|
||||||
|
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
|
||||||
|
self.post(self.login_url, data=payload, status_code=200)
|
||||||
|
resp = self.post(self.logout_url, status=200)
|
||||||
|
self.assertEqual('', resp.cookies.get('jwt-auth').value)
|
||||||
|
self.assertEqual('', resp.cookies.get('jwt-auth-refresh').value)
|
||||||
|
|
||||||
@override_settings(REST_USE_JWT=True)
|
@override_settings(REST_USE_JWT=True)
|
||||||
@override_settings(JWT_AUTH_COOKIE='jwt-auth')
|
@override_settings(JWT_AUTH_COOKIE='jwt-auth')
|
||||||
@override_settings(REST_FRAMEWORK=dict(
|
@override_settings(REST_FRAMEWORK=dict(
|
||||||
|
@ -604,21 +620,15 @@ class APIBasicTests(TestsMixin, TestCase):
|
||||||
resp = self.post(self.login_url, data=payload, status_code=200)
|
resp = self.post(self.login_url, data=payload, status_code=200)
|
||||||
token = resp.data['refresh_token']
|
token = resp.data['refresh_token']
|
||||||
# test refresh token not included in request data
|
# test refresh token not included in request data
|
||||||
resp = self.post(self.logout_url, status=200)
|
self.post(self.logout_url, status_code=401)
|
||||||
self.assertEqual(resp.status_code, 401)
|
|
||||||
# test token is invalid or expired
|
# test token is invalid or expired
|
||||||
resp = self.post(self.logout_url, status=200, data={'refresh': '1'})
|
self.post(self.logout_url, status_code=401, data={'refresh': '1'})
|
||||||
self.assertEqual(resp.status_code, 401)
|
|
||||||
# test successful logout
|
# test successful logout
|
||||||
resp = self.post(self.logout_url, status=200, data={'refresh': token})
|
self.post(self.logout_url, status_code=200, data={'refresh': token})
|
||||||
self.assertEqual(resp.status_code, 200)
|
|
||||||
# test token is blacklisted
|
# test token is blacklisted
|
||||||
resp = self.post(self.logout_url, status=200, data={'refresh': token})
|
self.post(self.logout_url, status_code=401, data={'refresh': token})
|
||||||
self.assertEqual(resp.status_code, 401)
|
|
||||||
# test other TokenError, AttributeError, TypeError (invalid format)
|
# test other TokenError, AttributeError, TypeError (invalid format)
|
||||||
resp = self.post(self.logout_url, status=200, data=json.dumps({'refresh': token}))
|
self.post(self.logout_url, status_code=500, data=json.dumps({'refresh': token}))
|
||||||
self.assertEqual(resp.status_code, 500)
|
|
||||||
|
|
||||||
|
|
||||||
@override_settings(REST_USE_JWT=True)
|
@override_settings(REST_USE_JWT=True)
|
||||||
@override_settings(JWT_AUTH_COOKIE=None)
|
@override_settings(JWT_AUTH_COOKIE=None)
|
||||||
|
@ -868,3 +878,38 @@ class APIBasicTests(TestsMixin, TestCase):
|
||||||
resp = client.post('/protected-view/', csrfparam)
|
resp = client.post('/protected-view/', csrfparam)
|
||||||
self.assertEquals(resp.status_code, 200)
|
self.assertEquals(resp.status_code, 200)
|
||||||
|
|
||||||
|
@override_settings(JWT_AUTH_RETURN_EXPIRATION=True)
|
||||||
|
@override_settings(REST_USE_JWT=True)
|
||||||
|
@override_settings(ACCOUNT_LOGOUT_ON_GET=True)
|
||||||
|
def test_return_expiration(self):
|
||||||
|
payload = {
|
||||||
|
"username": self.USERNAME,
|
||||||
|
"password": self.PASS
|
||||||
|
}
|
||||||
|
|
||||||
|
# create user
|
||||||
|
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
|
||||||
|
|
||||||
|
resp = self.post(self.login_url, data=payload, status_code=200)
|
||||||
|
self.assertIn('access_token_expiration', resp.data.keys())
|
||||||
|
self.assertIn('refresh_token_expiration', resp.data.keys())
|
||||||
|
|
||||||
|
@override_settings(JWT_AUTH_RETURN_EXPIRATION=True)
|
||||||
|
@override_settings(REST_USE_JWT=True)
|
||||||
|
@override_settings(JWT_AUTH_COOKIE='xxx')
|
||||||
|
@override_settings(ACCOUNT_LOGOUT_ON_GET=True)
|
||||||
|
@override_settings(JWT_AUTH_REFRESH_COOKIE='refresh-xxx')
|
||||||
|
@override_settings(JWT_AUTH_REFRESH_COOKIE_PATH='/foo/bar')
|
||||||
|
def test_refresh_cookie_name(self):
|
||||||
|
payload = {
|
||||||
|
"username": self.USERNAME,
|
||||||
|
"password": self.PASS
|
||||||
|
}
|
||||||
|
|
||||||
|
# create user
|
||||||
|
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
|
||||||
|
|
||||||
|
resp = self.post(self.login_url, data=payload, status_code=200)
|
||||||
|
self.assertIn('xxx', resp.cookies.keys())
|
||||||
|
self.assertIn('refresh-xxx', resp.cookies.keys())
|
||||||
|
self.assertEqual(resp.cookies.get('refresh-xxx').get('path'), '/foo/bar')
|
||||||
|
|
|
@ -2,6 +2,7 @@ from django.conf import settings
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.contrib.auth import login as django_login
|
from django.contrib.auth import login as django_login
|
||||||
from django.contrib.auth import logout as django_logout
|
from django.contrib.auth import logout as django_logout
|
||||||
|
from django.utils import timezone
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
@ -12,7 +13,7 @@ from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
from .app_settings import (JWTSerializer, LoginSerializer,
|
from .app_settings import (JWTSerializer, JWTSerializerWithExpiration, LoginSerializer,
|
||||||
PasswordChangeSerializer,
|
PasswordChangeSerializer,
|
||||||
PasswordResetConfirmSerializer,
|
PasswordResetConfirmSerializer,
|
||||||
PasswordResetSerializer, TokenSerializer,
|
PasswordResetSerializer, TokenSerializer,
|
||||||
|
@ -51,7 +52,12 @@ class LoginView(GenericAPIView):
|
||||||
|
|
||||||
def get_response_serializer(self):
|
def get_response_serializer(self):
|
||||||
if getattr(settings, 'REST_USE_JWT', False):
|
if getattr(settings, 'REST_USE_JWT', False):
|
||||||
response_serializer = JWTSerializer
|
|
||||||
|
if getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False):
|
||||||
|
response_serializer = JWTSerializerWithExpiration
|
||||||
|
else:
|
||||||
|
response_serializer = JWTSerializer
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response_serializer = TokenSerializer
|
response_serializer = TokenSerializer
|
||||||
return response_serializer
|
return response_serializer
|
||||||
|
@ -71,12 +77,24 @@ class LoginView(GenericAPIView):
|
||||||
def get_response(self):
|
def get_response(self):
|
||||||
serializer_class = self.get_response_serializer()
|
serializer_class = self.get_response_serializer()
|
||||||
|
|
||||||
|
access_token_expiration = None
|
||||||
|
refresh_token_expiration = None
|
||||||
if getattr(settings, 'REST_USE_JWT', False):
|
if getattr(settings, 'REST_USE_JWT', False):
|
||||||
|
from rest_framework_simplejwt.settings import api_settings as jwt_settings
|
||||||
|
access_token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
|
||||||
|
refresh_token_expiration = (timezone.now() + jwt_settings.REFRESH_TOKEN_LIFETIME)
|
||||||
|
return_expiration_times = getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'user': self.user,
|
'user': self.user,
|
||||||
'access_token': self.access_token,
|
'access_token': self.access_token,
|
||||||
'refresh_token': self.refresh_token
|
'refresh_token': self.refresh_token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if return_expiration_times:
|
||||||
|
data['access_token_expiration'] = access_token_expiration
|
||||||
|
data['refresh_token_expiration'] = refresh_token_expiration
|
||||||
|
|
||||||
serializer = serializer_class(instance=data,
|
serializer = serializer_class(instance=data,
|
||||||
context=self.get_serializer_context())
|
context=self.get_serializer_context())
|
||||||
else:
|
else:
|
||||||
|
@ -86,21 +104,32 @@ class LoginView(GenericAPIView):
|
||||||
response = Response(serializer.data, status=status.HTTP_200_OK)
|
response = Response(serializer.data, status=status.HTTP_200_OK)
|
||||||
if getattr(settings, 'REST_USE_JWT', False):
|
if getattr(settings, 'REST_USE_JWT', False):
|
||||||
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
|
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
|
||||||
|
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
|
||||||
|
refresh_cookie_path = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE_PATH', '/')
|
||||||
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
|
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
|
||||||
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
|
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
|
||||||
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')
|
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')
|
||||||
from rest_framework_simplejwt.settings import api_settings as jwt_settings
|
|
||||||
if cookie_name:
|
if cookie_name:
|
||||||
from datetime import datetime
|
|
||||||
expiration = (datetime.utcnow() + jwt_settings.ACCESS_TOKEN_LIFETIME)
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
cookie_name,
|
cookie_name,
|
||||||
self.access_token,
|
self.access_token,
|
||||||
expires=expiration,
|
expires=access_token_expiration,
|
||||||
secure=cookie_secure,
|
secure=cookie_secure,
|
||||||
httponly=cookie_httponly,
|
httponly=cookie_httponly,
|
||||||
samesite=cookie_samesite
|
samesite=cookie_samesite
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if refresh_cookie_name:
|
||||||
|
response.set_cookie(
|
||||||
|
refresh_cookie_name,
|
||||||
|
self.refresh_token,
|
||||||
|
expires=refresh_token_expiration,
|
||||||
|
secure=cookie_secure,
|
||||||
|
httponly=cookie_httponly,
|
||||||
|
samesite=cookie_samesite,
|
||||||
|
path=refresh_cookie_path
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
|
@ -142,8 +171,10 @@ class LogoutView(APIView):
|
||||||
if getattr(settings, 'REST_SESSION_LOGIN', True):
|
if getattr(settings, 'REST_SESSION_LOGIN', True):
|
||||||
django_logout(request)
|
django_logout(request)
|
||||||
|
|
||||||
response = Response({"detail": _("Successfully logged out.")},
|
response = Response(
|
||||||
status=status.HTTP_200_OK)
|
{"detail": _("Successfully logged out.")},
|
||||||
|
status=status.HTTP_200_OK
|
||||||
|
)
|
||||||
|
|
||||||
if getattr(settings, 'REST_USE_JWT', False):
|
if getattr(settings, 'REST_USE_JWT', False):
|
||||||
# NOTE: this import occurs here rather than at the top level
|
# NOTE: this import occurs here rather than at the top level
|
||||||
|
@ -155,37 +186,38 @@ class LogoutView(APIView):
|
||||||
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
|
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
|
||||||
if cookie_name:
|
if cookie_name:
|
||||||
response.delete_cookie(cookie_name)
|
response.delete_cookie(cookie_name)
|
||||||
|
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
|
||||||
|
if refresh_cookie_name:
|
||||||
|
response.delete_cookie(refresh_cookie_name)
|
||||||
|
|
||||||
elif 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS:
|
if 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS:
|
||||||
# add refresh token to blacklist
|
# add refresh token to blacklist
|
||||||
try:
|
try:
|
||||||
token = RefreshToken(request.data['refresh'])
|
token = RefreshToken(request.data['refresh'])
|
||||||
token.blacklist()
|
token.blacklist()
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
response = Response({"detail": _("Refresh token was not included in request data.")},
|
response.data = {"detail": _("Refresh token was not included in request data.")}
|
||||||
status=status.HTTP_401_UNAUTHORIZED)
|
response.status_code =status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
except (TokenError, AttributeError, TypeError) as error:
|
except (TokenError, AttributeError, TypeError) as error:
|
||||||
if hasattr(error, 'args'):
|
if hasattr(error, 'args'):
|
||||||
if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args:
|
if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args:
|
||||||
response = Response({"detail": _(error.args[0])},
|
response.data = {"detail": _(error.args[0])}
|
||||||
status=status.HTTP_401_UNAUTHORIZED)
|
response.status_code = status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = Response({"detail": _("An error has occurred.")},
|
response.data = {"detail": _("An error has occurred.")}
|
||||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = Response({"detail": _("An error has occurred.")},
|
response.data = {"detail": _("An error has occurred.")}
|
||||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = Response({
|
message = _(
|
||||||
"detail": _("Neither cookies or blacklist are enabled, so the token has not been deleted server "
|
"Neither cookies or blacklist are enabled, so the token "
|
||||||
"side. Please make sure the token is deleted client side."
|
"has not been deleted server side. Please make sure the token is deleted client side."
|
||||||
)}, status=status.HTTP_200_OK)
|
)
|
||||||
|
response.data = {"detail": message}
|
||||||
|
response.status_code = status.HTTP_200_OK
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user