This commit is contained in:
lexdene 2016-05-22 15:23:06 +00:00
commit 3199a6f4b5
2 changed files with 31 additions and 11 deletions

View File

@ -7,6 +7,7 @@ import base64
import binascii
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.models import AnonymousUser
from django.middleware.csrf import CsrfViewMiddleware
from django.utils.six import text_type
from django.utils.translation import ugettext_lazy as _
@ -125,19 +126,21 @@ class SessionAuthentication(BaseAuthentication):
if not user or not user.is_active:
return None
self.enforce_csrf(request)
if self.check_csrf(request):
# CSRF passed with authenticated user
return (user, None)
else:
return (AnonymousUser(), None)
# CSRF passed with authenticated user
return (user, None)
def enforce_csrf(self, request):
def check_csrf(self, request):
"""
Enforce CSRF validation for session based authentication.
return True if csrf is correct.
"""
reason = CSRFCheck().process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
request._csrf_failed_reason = reason
return not reason
class TokenAuthentication(BaseAuthentication):

View File

@ -8,6 +8,7 @@ from django.contrib.auth.models import User
from django.shortcuts import redirect
from django.test import TestCase
from rest_framework import exceptions
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.test import (
@ -23,6 +24,21 @@ def view(request):
})
@api_view(['GET', 'POST'])
def authenticated_view(request):
if not request.user.is_authenticated():
reason = getattr(request, '_csrf_failed_reason', None)
if reason:
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
else:
raise exceptions.PermissionDenied()
return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
'user': request.user.username
})
@api_view(['GET', 'POST'])
def session_view(request):
active_session = request.session.get('active_session', False)
@ -39,6 +55,7 @@ def redirect_view(request):
urlpatterns = [
url(r'^view/$', view),
url(r'^authenticated-view/$', authenticated_view),
url(r'^session-view/$', session_view),
url(r'^redirect-view/$', redirect_view),
]
@ -104,7 +121,7 @@ class TestAPITestClient(TestCase):
client = APIClient(enforce_csrf_checks=True)
User.objects.create_user('example', 'example@example.com', 'password')
client.login(username='example', password='password')
response = client.post('/view/')
response = client.post('/authenticated-view/')
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected)
@ -201,9 +218,9 @@ class TestAPIRequestFactory(TestCase):
"""
user = User.objects.create_user('example', 'example@example.com', 'password')
factory = APIRequestFactory(enforce_csrf_checks=True)
request = factory.post('/view/')
request = factory.post('/authenticated-view/')
request.user = user
response = view(request)
response = authenticated_view(request)
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected)