diff --git a/.gitignore b/.gitignore index 266f6ad..eb18197 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ +coverage_html/ .tox/ .coverage .coverage.* diff --git a/dj_rest_auth/app_settings.py b/dj_rest_auth/app_settings.py index fe6e6a2..543b2be 100644 --- a/dj_rest_auth/app_settings.py +++ b/dj_rest_auth/app_settings.py @@ -1,4 +1,5 @@ from dj_rest_auth.serializers import JWTSerializer as DefaultJWTSerializer +from dj_rest_auth.serializers import JWTSerializerWithExpiration as DefaultJWTSerializerWithExpiration from dj_rest_auth.serializers import LoginSerializer as DefaultLoginSerializer from dj_rest_auth.serializers import \ PasswordChangeSerializer as DefaultPasswordChangeSerializer @@ -21,6 +22,8 @@ TokenSerializer = import_callable(serializers.get('TOKEN_SERIALIZER', DefaultTok JWTSerializer = import_callable(serializers.get('JWT_SERIALIZER', DefaultJWTSerializer)) +JWTSerializerWithExpiration = import_callable(serializers.get('JWT_SERIALIZER_WITH_EXPIRATION', DefaultJWTSerializerWithExpiration)) + UserDetailsSerializer = import_callable(serializers.get('USER_DETAILS_SERIALIZER', DefaultUserDetailsSerializer)) LoginSerializer = import_callable(serializers.get('LOGIN_SERIALIZER', DefaultLoginSerializer)) @@ -38,3 +41,4 @@ PasswordChangeSerializer = import_callable( ) JWT_AUTH_COOKIE = getattr(settings, 'JWT_AUTH_COOKIE', None) +JWT_AUTH_REFRESH_COOKIE = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None) diff --git a/dj_rest_auth/serializers.py b/dj_rest_auth/serializers.py index f3a9614..2b879aa 100644 --- a/dj_rest_auth/serializers.py +++ b/dj_rest_auth/serializers.py @@ -178,6 +178,14 @@ class JWTSerializer(serializers.Serializer): return user_data +class JWTSerializerWithExpiration(JWTSerializer): + """ + Serializer for JWT authentication with expiration times. + """ + access_token_expiration = serializers.DateTimeField() + refresh_token_expiration = serializers.DateTimeField() + + class PasswordResetSerializer(serializers.Serializer): """ Serializer for requesting a password reset e-mail. diff --git a/dj_rest_auth/tests/test_api.py b/dj_rest_auth/tests/test_api.py index 0f8ee42..64a7f1b 100644 --- a/dj_rest_auth/tests/test_api.py +++ b/dj_rest_auth/tests/test_api.py @@ -11,6 +11,7 @@ from rest_framework.test import APIRequestFactory from dj_rest_auth.registration.app_settings import register_permission_classes from dj_rest_auth.registration.views import RegisterView + from .mixins import CustomPermissionClass, TestsMixin try: @@ -18,8 +19,9 @@ try: except ImportError: from django.core.urlresolvers import reverse -from rest_framework_simplejwt.serializers import TokenObtainPairSerializer from jwt import decode as decode_jwt +from rest_framework_simplejwt.serializers import TokenObtainPairSerializer + class TESTTokenObtainPairSerializer(TokenObtainPairSerializer): @classmethod @@ -71,8 +73,8 @@ class APIBasicTests(TestsMixin, TestCase): def _generate_uid_and_token(self, user): result = {} - from django.utils.encoding import force_bytes from django.contrib.auth.tokens import default_token_generator + from django.utils.encoding import force_bytes from django.utils.http import urlsafe_base64_encode result['uid'] = urlsafe_base64_encode(force_bytes(user.pk)) @@ -559,6 +561,20 @@ class APIBasicTests(TestsMixin, TestCase): resp = self.post(self.logout_url, status=200) self.assertEqual('', resp.cookies.get('jwt-auth').value) + @override_settings(JWT_AUTH_REFRESH_COOKIE='jwt-auth-refresh') + @override_settings(REST_USE_JWT=True) + @override_settings(JWT_AUTH_COOKIE='jwt-auth') + def test_logout_jwt_deletes_cookie_refresh(self): + payload = { + "username": self.USERNAME, + "password": self.PASS + } + get_user_model().objects.create_user(self.USERNAME, '', self.PASS) + self.post(self.login_url, data=payload, status_code=200) + resp = self.post(self.logout_url, status=200) + self.assertEqual('', resp.cookies.get('jwt-auth').value) + self.assertEqual('', resp.cookies.get('jwt-auth-refresh').value) + @override_settings(REST_USE_JWT=True) @override_settings(JWT_AUTH_COOKIE='jwt-auth') @override_settings(REST_FRAMEWORK=dict( @@ -604,21 +620,15 @@ class APIBasicTests(TestsMixin, TestCase): resp = self.post(self.login_url, data=payload, status_code=200) token = resp.data['refresh_token'] # test refresh token not included in request data - resp = self.post(self.logout_url, status=200) - self.assertEqual(resp.status_code, 401) + self.post(self.logout_url, status_code=401) # test token is invalid or expired - resp = self.post(self.logout_url, status=200, data={'refresh': '1'}) - self.assertEqual(resp.status_code, 401) + self.post(self.logout_url, status_code=401, data={'refresh': '1'}) # test successful logout - resp = self.post(self.logout_url, status=200, data={'refresh': token}) - self.assertEqual(resp.status_code, 200) + self.post(self.logout_url, status_code=200, data={'refresh': token}) # test token is blacklisted - resp = self.post(self.logout_url, status=200, data={'refresh': token}) - self.assertEqual(resp.status_code, 401) + self.post(self.logout_url, status_code=401, data={'refresh': token}) # test other TokenError, AttributeError, TypeError (invalid format) - resp = self.post(self.logout_url, status=200, data=json.dumps({'refresh': token})) - self.assertEqual(resp.status_code, 500) - + self.post(self.logout_url, status_code=500, data=json.dumps({'refresh': token})) @override_settings(REST_USE_JWT=True) @override_settings(JWT_AUTH_COOKIE=None) @@ -868,3 +878,38 @@ class APIBasicTests(TestsMixin, TestCase): resp = client.post('/protected-view/', csrfparam) self.assertEquals(resp.status_code, 200) + @override_settings(JWT_AUTH_RETURN_EXPIRATION=True) + @override_settings(REST_USE_JWT=True) + @override_settings(ACCOUNT_LOGOUT_ON_GET=True) + def test_return_expiration(self): + payload = { + "username": self.USERNAME, + "password": self.PASS + } + + # create user + get_user_model().objects.create_user(self.USERNAME, '', self.PASS) + + resp = self.post(self.login_url, data=payload, status_code=200) + self.assertIn('access_token_expiration', resp.data.keys()) + self.assertIn('refresh_token_expiration', resp.data.keys()) + + @override_settings(JWT_AUTH_RETURN_EXPIRATION=True) + @override_settings(REST_USE_JWT=True) + @override_settings(JWT_AUTH_COOKIE='xxx') + @override_settings(ACCOUNT_LOGOUT_ON_GET=True) + @override_settings(JWT_AUTH_REFRESH_COOKIE='refresh-xxx') + @override_settings(JWT_AUTH_REFRESH_COOKIE_PATH='/foo/bar') + def test_refresh_cookie_name(self): + payload = { + "username": self.USERNAME, + "password": self.PASS + } + + # create user + get_user_model().objects.create_user(self.USERNAME, '', self.PASS) + + resp = self.post(self.login_url, data=payload, status_code=200) + self.assertIn('xxx', resp.cookies.keys()) + self.assertIn('refresh-xxx', resp.cookies.keys()) + self.assertEqual(resp.cookies.get('refresh-xxx').get('path'), '/foo/bar') diff --git a/dj_rest_auth/views.py b/dj_rest_auth/views.py index 28b472d..d889f2a 100644 --- a/dj_rest_auth/views.py +++ b/dj_rest_auth/views.py @@ -2,6 +2,7 @@ from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth import login as django_login from django.contrib.auth import logout as django_logout +from django.utils import timezone from django.core.exceptions import ObjectDoesNotExist from django.utils.decorators import method_decorator from django.utils.translation import ugettext_lazy as _ @@ -12,7 +13,7 @@ from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView -from .app_settings import (JWTSerializer, LoginSerializer, +from .app_settings import (JWTSerializer, JWTSerializerWithExpiration, LoginSerializer, PasswordChangeSerializer, PasswordResetConfirmSerializer, PasswordResetSerializer, TokenSerializer, @@ -51,7 +52,12 @@ class LoginView(GenericAPIView): def get_response_serializer(self): if getattr(settings, 'REST_USE_JWT', False): - response_serializer = JWTSerializer + + if getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False): + response_serializer = JWTSerializerWithExpiration + else: + response_serializer = JWTSerializer + else: response_serializer = TokenSerializer return response_serializer @@ -71,12 +77,24 @@ class LoginView(GenericAPIView): def get_response(self): serializer_class = self.get_response_serializer() + access_token_expiration = None + refresh_token_expiration = None if getattr(settings, 'REST_USE_JWT', False): + from rest_framework_simplejwt.settings import api_settings as jwt_settings + access_token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME) + refresh_token_expiration = (timezone.now() + jwt_settings.REFRESH_TOKEN_LIFETIME) + return_expiration_times = getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False) + data = { 'user': self.user, 'access_token': self.access_token, 'refresh_token': self.refresh_token } + + if return_expiration_times: + data['access_token_expiration'] = access_token_expiration + data['refresh_token_expiration'] = refresh_token_expiration + serializer = serializer_class(instance=data, context=self.get_serializer_context()) else: @@ -86,21 +104,32 @@ class LoginView(GenericAPIView): response = Response(serializer.data, status=status.HTTP_200_OK) if getattr(settings, 'REST_USE_JWT', False): cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) + refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None) + refresh_cookie_path = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE_PATH', '/') cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False) cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True) cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax') - from rest_framework_simplejwt.settings import api_settings as jwt_settings + if cookie_name: - from datetime import datetime - expiration = (datetime.utcnow() + jwt_settings.ACCESS_TOKEN_LIFETIME) response.set_cookie( cookie_name, self.access_token, - expires=expiration, + expires=access_token_expiration, secure=cookie_secure, httponly=cookie_httponly, samesite=cookie_samesite ) + + if refresh_cookie_name: + response.set_cookie( + refresh_cookie_name, + self.refresh_token, + expires=refresh_token_expiration, + secure=cookie_secure, + httponly=cookie_httponly, + samesite=cookie_samesite, + path=refresh_cookie_path + ) return response def post(self, request, *args, **kwargs): @@ -142,8 +171,10 @@ class LogoutView(APIView): if getattr(settings, 'REST_SESSION_LOGIN', True): django_logout(request) - response = Response({"detail": _("Successfully logged out.")}, - status=status.HTTP_200_OK) + response = Response( + {"detail": _("Successfully logged out.")}, + status=status.HTTP_200_OK + ) if getattr(settings, 'REST_USE_JWT', False): # NOTE: this import occurs here rather than at the top level @@ -155,37 +186,38 @@ class LogoutView(APIView): cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) if cookie_name: response.delete_cookie(cookie_name) + refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None) + if refresh_cookie_name: + response.delete_cookie(refresh_cookie_name) - elif 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS: + if 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS: # add refresh token to blacklist try: token = RefreshToken(request.data['refresh']) token.blacklist() - except KeyError: - response = Response({"detail": _("Refresh token was not included in request data.")}, - status=status.HTTP_401_UNAUTHORIZED) - + response.data = {"detail": _("Refresh token was not included in request data.")} + response.status_code =status.HTTP_401_UNAUTHORIZED except (TokenError, AttributeError, TypeError) as error: if hasattr(error, 'args'): if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args: - response = Response({"detail": _(error.args[0])}, - status=status.HTTP_401_UNAUTHORIZED) - + response.data = {"detail": _(error.args[0])} + response.status_code = status.HTTP_401_UNAUTHORIZED else: - response = Response({"detail": _("An error has occurred.")}, - status=status.HTTP_500_INTERNAL_SERVER_ERROR) + response.data = {"detail": _("An error has occurred.")} + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR else: - response = Response({"detail": _("An error has occurred.")}, - status=status.HTTP_500_INTERNAL_SERVER_ERROR) + response.data = {"detail": _("An error has occurred.")} + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR else: - response = Response({ - "detail": _("Neither cookies or blacklist are enabled, so the token has not been deleted server " - "side. Please make sure the token is deleted client side." - )}, status=status.HTTP_200_OK) - + message = _( + "Neither cookies or blacklist are enabled, so the token " + "has not been deleted server side. Please make sure the token is deleted client side." + ) + response.data = {"detail": message} + response.status_code = status.HTTP_200_OK return response