mirror of
				https://github.com/Tivix/django-rest-auth.git
				synced 2025-10-31 15:57:34 +03:00 
			
		
		
		
	Cleans up refresh logic + Adds unit tests
This commit is contained in:
		
							parent
							
								
									5d6e8ca03b
								
							
						
					
					
						commit
						63bd99ac30
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							|  | @ -37,6 +37,7 @@ pip-delete-this-directory.txt | ||||||
| 
 | 
 | ||||||
| # Unit test / coverage reports | # Unit test / coverage reports | ||||||
| htmlcov/ | htmlcov/ | ||||||
|  | coverage_html/ | ||||||
| .tox/ | .tox/ | ||||||
| .coverage | .coverage | ||||||
| .coverage.* | .coverage.* | ||||||
|  |  | ||||||
|  | @ -11,6 +11,7 @@ from rest_framework.test import APIRequestFactory | ||||||
| 
 | 
 | ||||||
| from dj_rest_auth.registration.app_settings import register_permission_classes | from dj_rest_auth.registration.app_settings import register_permission_classes | ||||||
| from dj_rest_auth.registration.views import RegisterView | from dj_rest_auth.registration.views import RegisterView | ||||||
|  | 
 | ||||||
| from .mixins import CustomPermissionClass, TestsMixin | from .mixins import CustomPermissionClass, TestsMixin | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|  | @ -18,8 +19,9 @@ try: | ||||||
| except ImportError: | except ImportError: | ||||||
|     from django.core.urlresolvers import reverse |     from django.core.urlresolvers import reverse | ||||||
| 
 | 
 | ||||||
| from rest_framework_simplejwt.serializers import TokenObtainPairSerializer |  | ||||||
| from jwt import decode as decode_jwt | from jwt import decode as decode_jwt | ||||||
|  | from rest_framework_simplejwt.serializers import TokenObtainPairSerializer | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class TESTTokenObtainPairSerializer(TokenObtainPairSerializer): | class TESTTokenObtainPairSerializer(TokenObtainPairSerializer): | ||||||
|     @classmethod |     @classmethod | ||||||
|  | @ -71,8 +73,8 @@ class APIBasicTests(TestsMixin, TestCase): | ||||||
| 
 | 
 | ||||||
|     def _generate_uid_and_token(self, user): |     def _generate_uid_and_token(self, user): | ||||||
|         result = {} |         result = {} | ||||||
|         from django.utils.encoding import force_bytes |  | ||||||
|         from django.contrib.auth.tokens import default_token_generator |         from django.contrib.auth.tokens import default_token_generator | ||||||
|  |         from django.utils.encoding import force_bytes | ||||||
|         from django.utils.http import urlsafe_base64_encode |         from django.utils.http import urlsafe_base64_encode | ||||||
| 
 | 
 | ||||||
|         result['uid'] = urlsafe_base64_encode(force_bytes(user.pk)) |         result['uid'] = urlsafe_base64_encode(force_bytes(user.pk)) | ||||||
|  | @ -559,6 +561,20 @@ class APIBasicTests(TestsMixin, TestCase): | ||||||
|         resp = self.post(self.logout_url, status=200) |         resp = self.post(self.logout_url, status=200) | ||||||
|         self.assertEqual('', resp.cookies.get('jwt-auth').value) |         self.assertEqual('', resp.cookies.get('jwt-auth').value) | ||||||
| 
 | 
 | ||||||
|  |     @override_settings(JWT_AUTH_REFRESH_COOKIE='jwt-auth-refresh') | ||||||
|  |     @override_settings(REST_USE_JWT=True) | ||||||
|  |     @override_settings(JWT_AUTH_COOKIE='jwt-auth') | ||||||
|  |     def test_logout_jwt_deletes_cookie_refresh(self): | ||||||
|  |         payload = { | ||||||
|  |             "username": self.USERNAME, | ||||||
|  |             "password": self.PASS | ||||||
|  |         } | ||||||
|  |         get_user_model().objects.create_user(self.USERNAME, '', self.PASS) | ||||||
|  |         self.post(self.login_url, data=payload, status_code=200) | ||||||
|  |         resp = self.post(self.logout_url, status=200) | ||||||
|  |         self.assertEqual('', resp.cookies.get('jwt-auth').value) | ||||||
|  |         self.assertEqual('', resp.cookies.get('jwt-auth-refresh').value) | ||||||
|  | 
 | ||||||
|     @override_settings(REST_USE_JWT=True) |     @override_settings(REST_USE_JWT=True) | ||||||
|     @override_settings(JWT_AUTH_COOKIE='jwt-auth') |     @override_settings(JWT_AUTH_COOKIE='jwt-auth') | ||||||
|     @override_settings(REST_FRAMEWORK=dict( |     @override_settings(REST_FRAMEWORK=dict( | ||||||
|  | @ -604,21 +620,15 @@ class APIBasicTests(TestsMixin, TestCase): | ||||||
|         resp = self.post(self.login_url, data=payload, status_code=200) |         resp = self.post(self.login_url, data=payload, status_code=200) | ||||||
|         token = resp.data['refresh_token'] |         token = resp.data['refresh_token'] | ||||||
|         # test refresh token not included in request data |         # test refresh token not included in request data | ||||||
|         resp = self.post(self.logout_url, status=200) |         self.post(self.logout_url, status_code=401) | ||||||
|         self.assertEqual(resp.status_code, 401) |  | ||||||
|         # test token is invalid or expired |         # test token is invalid or expired | ||||||
|         resp = self.post(self.logout_url, status=200, data={'refresh': '1'}) |         self.post(self.logout_url, status_code=401, data={'refresh': '1'}) | ||||||
|         self.assertEqual(resp.status_code, 401) |  | ||||||
|         # test successful logout |         # test successful logout | ||||||
|         resp = self.post(self.logout_url, status=200, data={'refresh': token}) |         self.post(self.logout_url, status_code=200, data={'refresh': token}) | ||||||
|         self.assertEqual(resp.status_code, 200) |  | ||||||
|         # test token is blacklisted |         # test token is blacklisted | ||||||
|         resp = self.post(self.logout_url, status=200, data={'refresh': token}) |         self.post(self.logout_url, status_code=401, data={'refresh': token}) | ||||||
|         self.assertEqual(resp.status_code, 401) |  | ||||||
|         # test other TokenError, AttributeError, TypeError (invalid format) |         # test other TokenError, AttributeError, TypeError (invalid format) | ||||||
|         resp = self.post(self.logout_url, status=200, data=json.dumps({'refresh': token})) |         self.post(self.logout_url, status_code=500, data=json.dumps({'refresh': token})) | ||||||
|         self.assertEqual(resp.status_code, 500) |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
|     @override_settings(REST_USE_JWT=True) |     @override_settings(REST_USE_JWT=True) | ||||||
|     @override_settings(JWT_AUTH_COOKIE=None) |     @override_settings(JWT_AUTH_COOKIE=None) | ||||||
|  | @ -868,3 +878,38 @@ class APIBasicTests(TestsMixin, TestCase): | ||||||
|         resp = client.post('/protected-view/', csrfparam) |         resp = client.post('/protected-view/', csrfparam) | ||||||
|         self.assertEquals(resp.status_code, 200) |         self.assertEquals(resp.status_code, 200) | ||||||
| 
 | 
 | ||||||
|  |     @override_settings(JWT_AUTH_RETURN_EXPIRATION=True) | ||||||
|  |     @override_settings(REST_USE_JWT=True) | ||||||
|  |     @override_settings(ACCOUNT_LOGOUT_ON_GET=True) | ||||||
|  |     def test_return_expiration(self): | ||||||
|  |         payload = { | ||||||
|  |             "username": self.USERNAME, | ||||||
|  |             "password": self.PASS | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         # create user | ||||||
|  |         get_user_model().objects.create_user(self.USERNAME, '', self.PASS) | ||||||
|  | 
 | ||||||
|  |         resp = self.post(self.login_url, data=payload, status_code=200) | ||||||
|  |         self.assertIn('access_token_expiration', resp.data.keys()) | ||||||
|  |         self.assertIn('refresh_token_expiration', resp.data.keys()) | ||||||
|  | 
 | ||||||
|  |     @override_settings(JWT_AUTH_RETURN_EXPIRATION=True) | ||||||
|  |     @override_settings(REST_USE_JWT=True) | ||||||
|  |     @override_settings(JWT_AUTH_COOKIE='xxx') | ||||||
|  |     @override_settings(ACCOUNT_LOGOUT_ON_GET=True) | ||||||
|  |     @override_settings(JWT_AUTH_REFRESH_COOKIE='refresh-xxx') | ||||||
|  |     @override_settings(JWT_AUTH_REFRESH_COOKIE_PATH='/foo/bar') | ||||||
|  |     def test_refresh_cookie_name(self): | ||||||
|  |         payload = { | ||||||
|  |             "username": self.USERNAME, | ||||||
|  |             "password": self.PASS | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         # create user | ||||||
|  |         get_user_model().objects.create_user(self.USERNAME, '', self.PASS) | ||||||
|  | 
 | ||||||
|  |         resp = self.post(self.login_url, data=payload, status_code=200) | ||||||
|  |         self.assertIn('xxx', resp.cookies.keys()) | ||||||
|  |         self.assertIn('refresh-xxx', resp.cookies.keys()) | ||||||
|  |         self.assertEqual(resp.cookies.get('refresh-xxx').get('path'), '/foo/bar') | ||||||
|  |  | ||||||
|  | @ -2,6 +2,7 @@ from django.conf import settings | ||||||
| from django.contrib.auth import get_user_model | from django.contrib.auth import get_user_model | ||||||
| from django.contrib.auth import login as django_login | from django.contrib.auth import login as django_login | ||||||
| from django.contrib.auth import logout as django_logout | from django.contrib.auth import logout as django_logout | ||||||
|  | from django.utils import timezone | ||||||
| from django.core.exceptions import ObjectDoesNotExist | from django.core.exceptions import ObjectDoesNotExist | ||||||
| from django.utils.decorators import method_decorator | from django.utils.decorators import method_decorator | ||||||
| from django.utils.translation import ugettext_lazy as _ | from django.utils.translation import ugettext_lazy as _ | ||||||
|  | @ -76,12 +77,12 @@ class LoginView(GenericAPIView): | ||||||
|     def get_response(self): |     def get_response(self): | ||||||
|         serializer_class = self.get_response_serializer() |         serializer_class = self.get_response_serializer() | ||||||
| 
 | 
 | ||||||
|  |         access_token_expiration = None | ||||||
|  |         refresh_token_expiration = None | ||||||
|         if getattr(settings, 'REST_USE_JWT', False): |         if getattr(settings, 'REST_USE_JWT', False): | ||||||
|             from rest_framework_simplejwt.settings import api_settings as jwt_settings |             from rest_framework_simplejwt.settings import api_settings as jwt_settings | ||||||
|             from datetime import datetime |             access_token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME) | ||||||
| 
 |             refresh_token_expiration = (timezone.now() + jwt_settings.REFRESH_TOKEN_LIFETIME) | ||||||
|             access_token_expiration = (datetime.utcnow() + jwt_settings.ACCESS_TOKEN_LIFETIME) |  | ||||||
|             refresh_token_expiration = (datetime.utcnow() + jwt_settings.REFRESH_TOKEN_LIFETIME) |  | ||||||
|             return_expiration_times = getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False) |             return_expiration_times = getattr(settings, 'JWT_AUTH_RETURN_EXPIRATION', False) | ||||||
| 
 | 
 | ||||||
|             data = { |             data = { | ||||||
|  | @ -170,8 +171,10 @@ class LogoutView(APIView): | ||||||
|         if getattr(settings, 'REST_SESSION_LOGIN', True): |         if getattr(settings, 'REST_SESSION_LOGIN', True): | ||||||
|             django_logout(request) |             django_logout(request) | ||||||
| 
 | 
 | ||||||
|         response = Response({"detail": _("Successfully logged out.")}, |         response = Response( | ||||||
|                             status=status.HTTP_200_OK) |             {"detail": _("Successfully logged out.")}, | ||||||
|  |             status=status.HTTP_200_OK | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         if getattr(settings, 'REST_USE_JWT', False): |         if getattr(settings, 'REST_USE_JWT', False): | ||||||
|             # NOTE: this import occurs here rather than at the top level |             # NOTE: this import occurs here rather than at the top level | ||||||
|  | @ -183,41 +186,38 @@ class LogoutView(APIView): | ||||||
|             cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) |             cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) | ||||||
|             if cookie_name: |             if cookie_name: | ||||||
|                 response.delete_cookie(cookie_name) |                 response.delete_cookie(cookie_name) | ||||||
| 
 |  | ||||||
|             refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None) |             refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None) | ||||||
|             if refresh_cookie_name: |             if refresh_cookie_name: | ||||||
|                 response.delete_cookie(refresh_cookie_name) |                 response.delete_cookie(refresh_cookie_name) | ||||||
| 
 | 
 | ||||||
|             elif 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS: |             if 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS: | ||||||
|                 # add refresh token to blacklist |                 # add refresh token to blacklist | ||||||
|                 try: |                 try: | ||||||
|                     token = RefreshToken(request.data['refresh']) |                     token = RefreshToken(request.data['refresh']) | ||||||
|                     token.blacklist() |                     token.blacklist() | ||||||
| 
 |  | ||||||
|                 except KeyError: |                 except KeyError: | ||||||
|                     response = Response({"detail": _("Refresh token was not included in request data.")}, |                     response.data = {"detail": _("Refresh token was not included in request data.")} | ||||||
|                                         status=status.HTTP_401_UNAUTHORIZED) |                     response.status_code =status.HTTP_401_UNAUTHORIZED | ||||||
| 
 |  | ||||||
|                 except (TokenError, AttributeError, TypeError) as error: |                 except (TokenError, AttributeError, TypeError) as error: | ||||||
|                     if hasattr(error, 'args'): |                     if hasattr(error, 'args'): | ||||||
|                         if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args: |                         if 'Token is blacklisted' in error.args or 'Token is invalid or expired' in error.args: | ||||||
|                             response = Response({"detail": _(error.args[0])}, |                             response.data = {"detail": _(error.args[0])} | ||||||
|                                                 status=status.HTTP_401_UNAUTHORIZED) |                             response.status_code = status.HTTP_401_UNAUTHORIZED | ||||||
| 
 |  | ||||||
|                         else: |                         else: | ||||||
|                             response = Response({"detail": _("An error has occurred.")}, |                             response.data = {"detail": _("An error has occurred.")} | ||||||
|                                                 status=status.HTTP_500_INTERNAL_SERVER_ERROR) |                             response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR | ||||||
| 
 | 
 | ||||||
|                     else: |                     else: | ||||||
|                         response = Response({"detail": _("An error has occurred.")}, |                         response.data = {"detail": _("An error has occurred.")} | ||||||
|                                             status=status.HTTP_500_INTERNAL_SERVER_ERROR) |                         response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR | ||||||
| 
 | 
 | ||||||
|             else: |             else: | ||||||
|                 response = Response({ |                 message = _( | ||||||
|                     "detail": _("Neither cookies or blacklist are enabled, so the token has not been deleted server " |                     "Neither cookies or blacklist are enabled, so the token " | ||||||
|                                 "side. Please make sure the token is deleted client side." |                     "has not been deleted server side. Please make sure the token is deleted client side." | ||||||
|                                 )}, status=status.HTTP_200_OK) |                 ) | ||||||
| 
 |                 response.data = {"detail": message} | ||||||
|  |                 response.status_code = status.HTTP_200_OK | ||||||
|         return response |         return response | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user