This commit is contained in:
lexdene 2016-05-22 15:23:06 +00:00
commit 3199a6f4b5
2 changed files with 31 additions and 11 deletions

View File

@ -7,6 +7,7 @@ import base64
import binascii import binascii
from django.contrib.auth import authenticate, get_user_model from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.models import AnonymousUser
from django.middleware.csrf import CsrfViewMiddleware from django.middleware.csrf import CsrfViewMiddleware
from django.utils.six import text_type from django.utils.six import text_type
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -125,19 +126,21 @@ class SessionAuthentication(BaseAuthentication):
if not user or not user.is_active: if not user or not user.is_active:
return None return None
self.enforce_csrf(request) if self.check_csrf(request):
# CSRF passed with authenticated user
return (user, None)
else:
return (AnonymousUser(), None)
# CSRF passed with authenticated user def check_csrf(self, request):
return (user, None)
def enforce_csrf(self, request):
""" """
Enforce CSRF validation for session based authentication. return True if csrf is correct.
""" """
reason = CSRFCheck().process_view(request, None, (), {}) reason = CSRFCheck().process_view(request, None, (), {})
if reason: if reason:
# CSRF failed, bail with explicit error message request._csrf_failed_reason = reason
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
return not reason
class TokenAuthentication(BaseAuthentication): class TokenAuthentication(BaseAuthentication):

View File

@ -8,6 +8,7 @@ from django.contrib.auth.models import User
from django.shortcuts import redirect from django.shortcuts import redirect
from django.test import TestCase from django.test import TestCase
from rest_framework import exceptions
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import ( from rest_framework.test import (
@ -23,6 +24,21 @@ def view(request):
}) })
@api_view(['GET', 'POST'])
def authenticated_view(request):
if not request.user.is_authenticated():
reason = getattr(request, '_csrf_failed_reason', None)
if reason:
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
else:
raise exceptions.PermissionDenied()
return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
'user': request.user.username
})
@api_view(['GET', 'POST']) @api_view(['GET', 'POST'])
def session_view(request): def session_view(request):
active_session = request.session.get('active_session', False) active_session = request.session.get('active_session', False)
@ -39,6 +55,7 @@ def redirect_view(request):
urlpatterns = [ urlpatterns = [
url(r'^view/$', view), url(r'^view/$', view),
url(r'^authenticated-view/$', authenticated_view),
url(r'^session-view/$', session_view), url(r'^session-view/$', session_view),
url(r'^redirect-view/$', redirect_view), url(r'^redirect-view/$', redirect_view),
] ]
@ -104,7 +121,7 @@ class TestAPITestClient(TestCase):
client = APIClient(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
User.objects.create_user('example', 'example@example.com', 'password') User.objects.create_user('example', 'example@example.com', 'password')
client.login(username='example', password='password') client.login(username='example', password='password')
response = client.post('/view/') response = client.post('/authenticated-view/')
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected) self.assertEqual(response.data, expected)
@ -201,9 +218,9 @@ class TestAPIRequestFactory(TestCase):
""" """
user = User.objects.create_user('example', 'example@example.com', 'password') user = User.objects.create_user('example', 'example@example.com', 'password')
factory = APIRequestFactory(enforce_csrf_checks=True) factory = APIRequestFactory(enforce_csrf_checks=True)
request = factory.post('/view/') request = factory.post('/authenticated-view/')
request.user = user request.user = user
response = view(request) response = authenticated_view(request)
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected) self.assertEqual(response.data, expected)