mirror of
				https://github.com/Tivix/django-rest-auth.git
				synced 2025-11-04 01:27:36 +03:00 
			
		
		
		
	added the ability to customise claims in the jwt token - has tests
JWT claim serializer now can be set to something custom in settings: JWT_TOKEN_CLAIMS_SERIALIZER = myTokenObtainSerializer Ideally JWT_TOKEN_CLAIMS_SERIALIZER would be a key in REST_AUTH_SERIALIZERS and assigned through import_callable, as with the other serializers; however, I could not quite figure out how to implement it that way
This commit is contained in:
		
							parent
							
								
									9dbbef4640
								
							
						
					
					
						commit
						0722ec4aee
					
				| 
						 | 
					@ -18,6 +18,9 @@ try:
 | 
				
			||||||
except ImportError:
 | 
					except ImportError:
 | 
				
			||||||
    from django.core.urlresolvers import reverse
 | 
					    from django.core.urlresolvers import reverse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
 | 
				
			||||||
 | 
					from jwt import decode as decode_jwt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@override_settings(ROOT_URLCONF="tests.urls")
 | 
					@override_settings(ROOT_URLCONF="tests.urls")
 | 
				
			||||||
class APIBasicTests(TestsMixin, TestCase):
 | 
					class APIBasicTests(TestsMixin, TestCase):
 | 
				
			||||||
| 
						 | 
					@ -605,3 +608,66 @@ class APIBasicTests(TestsMixin, TestCase):
 | 
				
			||||||
        # test other TokenError, AttributeError, TypeError (invalid format)
 | 
					        # test other TokenError, AttributeError, TypeError (invalid format)
 | 
				
			||||||
        resp = self.post(self.logout_url, status=200, data=json.dumps({'refresh': token}))
 | 
					        resp = self.post(self.logout_url, status=200, data=json.dumps({'refresh': token}))
 | 
				
			||||||
        self.assertEqual(resp.status_code, 500)
 | 
					        self.assertEqual(resp.status_code, 500)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class TESTTokenObtainPairSerializer(TokenObtainPairSerializer):
 | 
				
			||||||
 | 
					        @classmethod
 | 
				
			||||||
 | 
					        def get_token(cls, user):
 | 
				
			||||||
 | 
					            token = super().get_token(user)
 | 
				
			||||||
 | 
					            # Add custom claims
 | 
				
			||||||
 | 
					            token['name'] = user.username
 | 
				
			||||||
 | 
					            token['email'] = user.email
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @override_settings(REST_USE_JWT=True)
 | 
				
			||||||
 | 
					    @override_settings(JWT_AUTH_COOKIE=None)
 | 
				
			||||||
 | 
					    @override_settings(REST_FRAMEWORK=dict(
 | 
				
			||||||
 | 
					        DEFAULT_AUTHENTICATION_CLASSES=[
 | 
				
			||||||
 | 
					            'dj_rest_auth.utils.JWTCookieAuthentication'
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    ))
 | 
				
			||||||
 | 
					    @override_settings(REST_SESSION_LOGIN=False)
 | 
				
			||||||
 | 
					    @override_settings(JWT_TOKEN_CLAIMS_SERIALIZER = TESTTokenObtainPairSerializer)
 | 
				
			||||||
 | 
					    def test_custom_jwt_claims(self):
 | 
				
			||||||
 | 
					        payload = {
 | 
				
			||||||
 | 
					            "username": self.USERNAME,
 | 
				
			||||||
 | 
					            "password": self.PASS
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        get_user_model().objects.create_user(self.USERNAME, self.EMAIL, self.PASS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.post(self.login_url, data=payload, status_code=200)
 | 
				
			||||||
 | 
					        self.assertEqual('access_token' in self.response.json.keys(), True)
 | 
				
			||||||
 | 
					        self.token = self.response.json['access_token']
 | 
				
			||||||
 | 
					        claims = decode_jwt(self.token, settings.SECRET_KEY, algorithms='HS256')
 | 
				
			||||||
 | 
					        self.assertEquals(claims['user_id'], 1)
 | 
				
			||||||
 | 
					        self.assertEquals(claims['name'], 'person')
 | 
				
			||||||
 | 
					        self.assertEquals(claims['email'], 'person1@world.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @override_settings(REST_USE_JWT=True)
 | 
				
			||||||
 | 
					    @override_settings(JWT_AUTH_COOKIE='jwt-auth')
 | 
				
			||||||
 | 
					    @override_settings(REST_FRAMEWORK=dict(
 | 
				
			||||||
 | 
					        DEFAULT_AUTHENTICATION_CLASSES=[
 | 
				
			||||||
 | 
					            'dj_rest_auth.utils.JWTCookieAuthentication'
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    ))
 | 
				
			||||||
 | 
					    @override_settings(REST_SESSION_LOGIN=False)
 | 
				
			||||||
 | 
					    @override_settings(JWT_TOKEN_CLAIMS_SERIALIZER = TESTTokenObtainPairSerializer)
 | 
				
			||||||
 | 
					    def test_custom_jwt_claims_cookie_w_authentication(self):
 | 
				
			||||||
 | 
					        payload = {
 | 
				
			||||||
 | 
					            "username": self.USERNAME,
 | 
				
			||||||
 | 
					            "password": self.PASS
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        get_user_model().objects.create_user(self.USERNAME, self.EMAIL, self.PASS)
 | 
				
			||||||
 | 
					        resp = self.post(self.login_url, data=payload, status_code=200)
 | 
				
			||||||
 | 
					        self.assertEqual(['jwt-auth'], list(resp.cookies.keys()))
 | 
				
			||||||
 | 
					        token = resp.cookies.get('jwt-auth').value
 | 
				
			||||||
 | 
					        claims = decode_jwt(token, settings.SECRET_KEY, algorithms='HS256')
 | 
				
			||||||
 | 
					        self.assertEquals(claims['user_id'], 1)
 | 
				
			||||||
 | 
					        self.assertEquals(claims['name'], 'person')
 | 
				
			||||||
 | 
					        self.assertEquals(claims['email'], 'person1@world.com')
 | 
				
			||||||
 | 
					        resp = self.get('/protected-view/')
 | 
				
			||||||
 | 
					        self.assertEquals(resp.status_code, 200)
 | 
				
			||||||
| 
						 | 
					@ -15,18 +15,15 @@ def default_create_token(token_model, user, serializer):
 | 
				
			||||||
    return token
 | 
					    return token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def jwt_encode(user):
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
 | 
					 | 
				
			||||||
    except ImportError:
 | 
					 | 
				
			||||||
        raise ImportError("rest-framework-simplejwt needs to be installed")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    refresh = TokenObtainPairSerializer.get_token(user)
 | 
					 | 
				
			||||||
    return refresh.access_token, refresh
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
 | 
					    from django.conf import settings
 | 
				
			||||||
    from rest_framework_simplejwt.authentication import JWTAuthentication
 | 
					    from rest_framework_simplejwt.authentication import JWTAuthentication
 | 
				
			||||||
 | 
					    from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def jwt_encode(user):
 | 
				
			||||||
 | 
					        TOPS = getattr(settings, 'JWT_TOKEN_CLAIMS_SERIALIZER', TokenObtainPairSerializer)
 | 
				
			||||||
 | 
					        refresh = TOPS.get_token(user)
 | 
				
			||||||
 | 
					        return refresh.access_token, refresh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class JWTCookieAuthentication(JWTAuthentication):
 | 
					    class JWTCookieAuthentication(JWTAuthentication):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -35,7 +32,6 @@ try:
 | 
				
			||||||
        preference to the header).
 | 
					        preference to the header).
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        def authenticate(self, request):
 | 
					        def authenticate(self, request):
 | 
				
			||||||
            from django.conf import settings
 | 
					 | 
				
			||||||
            cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
 | 
					            cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
 | 
				
			||||||
            header = self.get_header(request)
 | 
					            header = self.get_header(request)
 | 
				
			||||||
            if header is None:
 | 
					            if header is None:
 | 
				
			||||||
| 
						 | 
					@ -53,4 +49,4 @@ try:
 | 
				
			||||||
            return self.get_user(validated_token), validated_token
 | 
					            return self.get_user(validated_token), validated_token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
except ImportError:
 | 
					except ImportError:
 | 
				
			||||||
    pass
 | 
					    raise ImportError("rest-framework-simplejwt needs to be installed")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user