diff --git a/dj_rest_auth/jwt_auth.py b/dj_rest_auth/jwt_auth.py index 39f9c78..1611cf5 100644 --- a/dj_rest_auth/jwt_auth.py +++ b/dj_rest_auth/jwt_auth.py @@ -1,6 +1,7 @@ from django.conf import settings from rest_framework_simplejwt.authentication import JWTAuthentication - +from rest_framework import exceptions +from rest_framework.authentication import CSRFCheck class JWTCookieAuthentication(JWTAuthentication): """ @@ -8,6 +9,17 @@ class JWTCookieAuthentication(JWTAuthentication): token provided in a request cookie (and through the header as normal, with a preference to the header). """ + def enforce_csrf(self, request): + """ + Enforce CSRF validation for session based authentication. + """ + check = CSRFCheck() + # populates request.META['CSRF_COOKIE'], which is used in process_view() + check.process_request(request) + reason = check.process_view(request, None, (), {}) + if reason: + # CSRF failed, bail with explicit error message + raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) def authenticate(self, request): cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None) @@ -15,6 +27,10 @@ class JWTCookieAuthentication(JWTAuthentication): if header is None: if cookie_name: raw_token = request.COOKIES.get(cookie_name) + if getattr(settings, 'JWT_AUTH_COOKIE_ENFORCE_CSRF_ON_UNAUTHENTICATED', False): #True at your own risk + self.enforce_csrf(request) + elif raw_token is not None and getattr(settings, 'JWT_AUTH_COOKIE_USE_CSRF', False): + self.enforce_csrf(request) else: return None else: diff --git a/dj_rest_auth/tests/test_api.py b/dj_rest_auth/tests/test_api.py index 3c1ff0e..a8c2048 100644 --- a/dj_rest_auth/tests/test_api.py +++ b/dj_rest_auth/tests/test_api.py @@ -668,4 +668,102 @@ class APIBasicTests(TestsMixin, TestCase): self.assertEquals(claims['name'], 'person') self.assertEquals(claims['email'], 'person1@world.com') resp = self.get('/protected-view/') - self.assertEquals(resp.status_code, 200) \ No newline at end of file + self.assertEquals(resp.status_code, 200) + + + @override_settings(REST_USE_JWT=True) + @override_settings(JWT_AUTH_COOKIE='jwt-auth') + @override_settings(JWT_AUTH_COOKIE_USE_CSRF=True) + @override_settings(JWT_AUTH_COOKIE_ENFORCE_CSRF_ON_UNAUTHENTICATED=False) + @override_settings(REST_FRAMEWORK=dict( + DEFAULT_AUTHENTICATION_CLASSES=[ + 'dj_rest_auth.jwt_auth.JWTCookieAuthentication' + ] + )) + @override_settings(REST_SESSION_LOGIN=False) + @override_settings(CSRF_COOKIE_SECURE =True) + @override_settings(CSRF_COOKIE_HTTPONLY =True) + def test_csrf_wo_login_csrf_enforcement(self): + from .mixins import APIClient + payload = { + "username": self.USERNAME, + "password": self.PASS + } + client = APIClient(enforce_csrf_checks=True) + get_user_model().objects.create_user(self.USERNAME, '', self.PASS) + + response = client.get(reverse("getcsrf")) + csrftoken = client.cookies['csrftoken'].value + + resp = client.post(self.login_url, payload) + self.assertTrue('jwt-auth' in list(client.cookies.keys())) + self.assertTrue('csrftoken' in list(client.cookies.keys())) + self.assertEquals(resp.status_code, 200) + + ## TEST WITH JWT AUTH HEADER + jwtclient = APIClient(enforce_csrf_checks=True) + token = resp.data['access_token'] + resp = jwtclient.get('/protected-view/') + self.assertEquals(resp.status_code, 403) + resp = jwtclient.get('/protected-view/', HTTP_AUTHORIZATION='Bearer '+token) + self.assertEquals(resp.status_code, 200) + resp = jwtclient.post('/protected-view/', {}) + self.assertEquals(resp.status_code, 403) + resp = jwtclient.post('/protected-view/', {}, HTTP_AUTHORIZATION='Bearer '+token) + self.assertEquals(resp.status_code, 200) + + ## TEST WITH COOKIES + #fail w/o csrftoken in payload + resp = client.post('/protected-view/', {}) + self.assertEquals(resp.status_code, 403) + + csrfparam = {"csrfmiddlewaretoken": csrftoken} + resp = client.post('/protected-view/', csrfparam) + self.assertEquals(resp.status_code, 200) + + + @override_settings(REST_USE_JWT=True) + @override_settings(JWT_AUTH_COOKIE='jwt-auth') + @override_settings(JWT_AUTH_COOKIE_USE_CSRF=True) + @override_settings(JWT_AUTH_COOKIE_ENFORCE_CSRF_ON_UNAUTHENTICATED=True) #True at your own risk + @override_settings(REST_FRAMEWORK=dict( + DEFAULT_AUTHENTICATION_CLASSES=[ + 'dj_rest_auth.jwt_auth.JWTCookieAuthentication' + ] + )) + @override_settings(REST_SESSION_LOGIN=False) + @override_settings(CSRF_COOKIE_SECURE =True) + @override_settings(CSRF_COOKIE_HTTPONLY =True) + def test_csrf_w_login_csrf_enforcement(self): + from .mixins import APIClient + payload = { + "username": self.USERNAME, + "password": self.PASS + } + client = APIClient(enforce_csrf_checks=True) + get_user_model().objects.create_user(self.USERNAME, '', self.PASS) + + response = client.get(reverse("getcsrf")) + csrftoken = client.cookies['csrftoken'].value + + #fail w/o csrftoken in payload + resp = client.post(self.login_url, payload) + self.assertEquals(resp.status_code, 403) + + payload['csrfmiddlewaretoken'] = csrftoken + resp = client.post(self.login_url, payload) + self.assertTrue('jwt-auth' in list(client.cookies.keys())) + self.assertTrue('csrftoken' in list(client.cookies.keys())) + self.assertEquals(resp.status_code, 200) + + ## TEST WITH JWT AUTH HEADER does not make sense + + ## TEST WITH COOKIES + #fail w/o csrftoken in payload + resp = client.post('/protected-view/', {}) + self.assertEquals(resp.status_code, 403) + + csrfparam = {"csrfmiddlewaretoken": csrftoken} + resp = client.post('/protected-view/', csrfparam) + self.assertEquals(resp.status_code, 200) + diff --git a/dj_rest_auth/tests/urls.py b/dj_rest_auth/tests/urls.py index f1796d6..3e3e28e 100644 --- a/dj_rest_auth/tests/urls.py +++ b/dj_rest_auth/tests/urls.py @@ -10,6 +10,8 @@ from dj_rest_auth.social_serializers import (TwitterConnectSerializer, from dj_rest_auth.urls import urlpatterns from django.conf.urls import include, url from django.views.generic import TemplateView +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import ensure_csrf_cookie from rest_framework import permissions from rest_framework.decorators import api_view from rest_framework.response import Response @@ -24,6 +26,9 @@ class ExampleProtectedView(APIView): def get(self, *args, **kwargs): return Response(dict(success=True)) + def post(self, *args, **kwargs): + return Response(dict(success=True)) + class FacebookLogin(SocialLoginView): adapter_class = FacebookOAuth2Adapter @@ -59,6 +64,11 @@ def twitter_login_view(request): class TwitterLoginNoAdapter(SocialLoginView): serializer_class = TwitterLoginSerializer +@ensure_csrf_cookie +@api_view(['GET']) +def get_csrf_cookie(request): + return Response() + urlpatterns += [ url(r'^rest-registration/', include('dj_rest_auth.registration.urls')), @@ -77,5 +87,6 @@ urlpatterns += [ url(r'^protected-view/$', ExampleProtectedView.as_view()), url(r'^socialaccounts/(?P\d+)/disconnect/$', SocialAccountDisconnectView.as_view(), name='social_account_disconnect'), - url(r'^accounts/', include('allauth.socialaccount.urls')) + url(r'^accounts/', include('allauth.socialaccount.urls')), + url(r'^getcsrf/', get_csrf_cookie, name='getcsrf'), ]