diff --git a/dj_rest_auth/tests/test_api.py b/dj_rest_auth/tests/test_api.py index 9554d05..f25f5fc 100644 --- a/dj_rest_auth/tests/test_api.py +++ b/dj_rest_auth/tests/test_api.py @@ -21,6 +21,16 @@ except ImportError: from rest_framework_simplejwt.serializers import TokenObtainPairSerializer from jwt import decode as decode_jwt +class TESTTokenObtainPairSerializer(TokenObtainPairSerializer): + @classmethod + def get_token(cls, user): + token = super().get_token(user) + # Add custom claims + token['name'] = user.username + token['email'] = user.email + + return token + @override_settings(ROOT_URLCONF="tests.urls") class APIBasicTests(TestsMixin, TestCase): @@ -610,18 +620,6 @@ class APIBasicTests(TestsMixin, TestCase): self.assertEqual(resp.status_code, 500) - - class TESTTokenObtainPairSerializer(TokenObtainPairSerializer): - @classmethod - def get_token(cls, user): - token = super().get_token(user) - # Add custom claims - token['name'] = user.username - token['email'] = user.email - - return token - - @override_settings(REST_USE_JWT=True) @override_settings(JWT_AUTH_COOKIE=None) @override_settings(REST_FRAMEWORK=dict( @@ -630,7 +628,7 @@ class APIBasicTests(TestsMixin, TestCase): ] )) @override_settings(REST_SESSION_LOGIN=False) - @override_settings(JWT_TOKEN_CLAIMS_SERIALIZER = TESTTokenObtainPairSerializer) + @override_settings(JWT_TOKEN_CLAIMS_SERIALIZER = 'tests.test_api.TESTTokenObtainPairSerializer') def test_custom_jwt_claims(self): payload = { "username": self.USERNAME, @@ -655,7 +653,7 @@ class APIBasicTests(TestsMixin, TestCase): ] )) @override_settings(REST_SESSION_LOGIN=False) - @override_settings(JWT_TOKEN_CLAIMS_SERIALIZER = TESTTokenObtainPairSerializer) + @override_settings(JWT_TOKEN_CLAIMS_SERIALIZER = 'tests.test_api.TESTTokenObtainPairSerializer') def test_custom_jwt_claims_cookie_w_authentication(self): payload = { "username": self.USERNAME, diff --git a/dj_rest_auth/utils.py b/dj_rest_auth/utils.py index 5f229af..fa971de 100644 --- a/dj_rest_auth/utils.py +++ b/dj_rest_auth/utils.py @@ -21,7 +21,7 @@ try: from rest_framework_simplejwt.serializers import TokenObtainPairSerializer def jwt_encode(user): - TOPS = getattr(settings, 'JWT_TOKEN_CLAIMS_SERIALIZER', TokenObtainPairSerializer) + TOPS = import_callable(getattr(settings, 'JWT_TOKEN_CLAIMS_SERIALIZER', TokenObtainPairSerializer)) refresh = TOPS.get_token(user) return refresh.access_token, refresh