diff --git a/dj_rest_auth/utils.py b/dj_rest_auth/utils.py index 7011c8a..4ecfbc1 100644 --- a/dj_rest_auth/utils.py +++ b/dj_rest_auth/utils.py @@ -19,7 +19,14 @@ def default_create_token(token_model, user, serializer): def jwt_encode(user): from rest_framework_simplejwt.serializers import TokenObtainPairSerializer - TOPS = import_callable(getattr(settings, 'JWT_TOKEN_CLAIMS_SERIALIZER', TokenObtainPairSerializer)) + rest_auth_serializers = getattr(settings, 'REST_AUTH_SERIALIZERS', {}) + + JWTTokenClaimsSerializer = rest_auth_serializers.get( + 'JWT_TOKEN_CLAIMS_SERIALIZER', + TokenObtainPairSerializer + ) + + TOPS = import_callable(JWTTokenClaimsSerializer) refresh = TOPS.get_token(user) return refresh.access_token, refresh