From 441b2e9962549cdc7613258a37aa354d3d73dab7 Mon Sep 17 00:00:00 2001 From: Michael <487897+iMerica@users.noreply.github.com> Date: Thu, 19 Nov 2020 09:29:01 -0600 Subject: [PATCH] Adds view for refreshing tokens with cookies (#173) --- dj_rest_auth/jwt_auth.py | 32 +++++++++++++++++++++++++++++++- dj_rest_auth/tests/test_api.py | 20 ++++++++++++++++++++ dj_rest_auth/tests/urls.py | 23 ++++++++++++++--------- dj_rest_auth/urls.py | 13 +++++++------ 4 files changed, 72 insertions(+), 16 deletions(-) diff --git a/dj_rest_auth/jwt_auth.py b/dj_rest_auth/jwt_auth.py index 1611cf5..07c8702 100644 --- a/dj_rest_auth/jwt_auth.py +++ b/dj_rest_auth/jwt_auth.py @@ -1,7 +1,37 @@ from django.conf import settings -from rest_framework_simplejwt.authentication import JWTAuthentication +from django.utils import timezone from rest_framework import exceptions from rest_framework.authentication import CSRFCheck +from rest_framework_simplejwt.authentication import JWTAuthentication + + +def get_refresh_view(): + """ Returns a Token Refresh CBV without a circular import """ + from rest_framework_simplejwt.settings import api_settings as jwt_settings + from rest_framework_simplejwt.views import TokenRefreshView + + class RefreshViewWithCookieSupport(TokenRefreshView): + def post(self, request, *args, **kwargs): + response = super().post(request, *args, **kwargs) + cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) + if cookie_name and response.status_code == 200 and 'access' in response.data: + cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False) + cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True) + cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax') + token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME) + response.set_cookie( + cookie_name, + response.data['access'], + expires=token_expiration, + secure=cookie_secure, + httponly=cookie_httponly, + samesite=cookie_samesite, + ) + + response.data['access_token_expiration'] = token_expiration + return response + return RefreshViewWithCookieSupport + class JWTCookieAuthentication(JWTAuthentication): """ diff --git a/dj_rest_auth/tests/test_api.py b/dj_rest_auth/tests/test_api.py index 64a7f1b..64031f2 100644 --- a/dj_rest_auth/tests/test_api.py +++ b/dj_rest_auth/tests/test_api.py @@ -913,3 +913,23 @@ class APIBasicTests(TestsMixin, TestCase): self.assertIn('xxx', resp.cookies.keys()) self.assertIn('refresh-xxx', resp.cookies.keys()) self.assertEqual(resp.cookies.get('refresh-xxx').get('path'), '/foo/bar') + + @override_settings(JWT_AUTH_RETURN_EXPIRATION=True) + @override_settings(REST_USE_JWT=True) + @override_settings(JWT_AUTH_COOKIE='xxx') + @override_settings(JWT_AUTH_REFRESH_COOKIE='refresh-xxx') + def test_custom_token_refresh_view(self): + payload = { + "username": self.USERNAME, + "password": self.PASS + } + + get_user_model().objects.create_user(self.USERNAME, '', self.PASS) + resp = self.post(self.login_url, data=payload, status_code=200) + refresh = resp.data.get('refresh_token') + refresh_resp = self.post( + reverse('token_refresh'), + data=dict(refresh=refresh), + status_code=200 + ) + self.assertIn('xxx', refresh_resp.cookies) diff --git a/dj_rest_auth/tests/urls.py b/dj_rest_auth/tests/urls.py index 3e3e28e..c753356 100644 --- a/dj_rest_auth/tests/urls.py +++ b/dj_rest_auth/tests/urls.py @@ -1,6 +1,16 @@ from allauth.socialaccount.providers.facebook.views import \ FacebookOAuth2Adapter from allauth.socialaccount.providers.twitter.views import TwitterOAuthAdapter +from django.conf.urls import include, url +from django.views.decorators.csrf import ensure_csrf_cookie +from django.views.generic import TemplateView +from rest_framework import permissions +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework_simplejwt.views import TokenVerifyView + +from dj_rest_auth.jwt_auth import get_refresh_view from dj_rest_auth.registration.views import (SocialAccountDisconnectView, SocialAccountListView, SocialConnectView, @@ -8,14 +18,6 @@ from dj_rest_auth.registration.views import (SocialAccountDisconnectView, from dj_rest_auth.social_serializers import (TwitterConnectSerializer, TwitterLoginSerializer) from dj_rest_auth.urls import urlpatterns -from django.conf.urls import include, url -from django.views.generic import TemplateView -from django.utils.decorators import method_decorator -from django.views.decorators.csrf import ensure_csrf_cookie -from rest_framework import permissions -from rest_framework.decorators import api_view -from rest_framework.response import Response -from rest_framework.views import APIView from . import django_urls @@ -64,6 +66,7 @@ def twitter_login_view(request): class TwitterLoginNoAdapter(SocialLoginView): serializer_class = TwitterLoginSerializer + @ensure_csrf_cookie @api_view(['GET']) def get_csrf_cookie(request): @@ -89,4 +92,6 @@ urlpatterns += [ name='social_account_disconnect'), url(r'^accounts/', include('allauth.socialaccount.urls')), url(r'^getcsrf/', get_csrf_cookie, name='getcsrf'), -] + url('^token/verify/', TokenVerifyView.as_view(), name='token_verify'), + url('^token/refresh/', get_refresh_view().as_view(), name='token_refresh'), +] \ No newline at end of file diff --git a/dj_rest_auth/urls.py b/dj_rest_auth/urls.py index 81c6111..7da0d13 100644 --- a/dj_rest_auth/urls.py +++ b/dj_rest_auth/urls.py @@ -1,8 +1,9 @@ +from django.conf import settings +from django.urls import path + from dj_rest_auth.views import (LoginView, LogoutView, PasswordChangeView, PasswordResetConfirmView, PasswordResetView, UserDetailsView) -from django.urls import path -from django.conf import settings urlpatterns = [ # URLs that do not require a session or valid token @@ -16,11 +17,11 @@ urlpatterns = [ ] if getattr(settings, 'REST_USE_JWT', False): - from rest_framework_simplejwt.views import ( - TokenRefreshView, TokenVerifyView, - ) + from rest_framework_simplejwt.views import TokenVerifyView + + from dj_rest_auth.jwt_auth import get_refresh_view urlpatterns += [ path('token/verify/', TokenVerifyView.as_view(), name='token_verify'), - path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'), + path('token/refresh/', get_refresh_view().as_view(), name='token_refresh'), ]