mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-08 06:14:47 +03:00
Merge 9ad8d08683
into 9b56dda918
This commit is contained in:
commit
3199a6f4b5
|
@ -7,6 +7,7 @@ import base64
|
||||||
import binascii
|
import binascii
|
||||||
|
|
||||||
from django.contrib.auth import authenticate, get_user_model
|
from django.contrib.auth import authenticate, get_user_model
|
||||||
|
from django.contrib.auth.models import AnonymousUser
|
||||||
from django.middleware.csrf import CsrfViewMiddleware
|
from django.middleware.csrf import CsrfViewMiddleware
|
||||||
from django.utils.six import text_type
|
from django.utils.six import text_type
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
@ -125,19 +126,21 @@ class SessionAuthentication(BaseAuthentication):
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self.enforce_csrf(request)
|
if self.check_csrf(request):
|
||||||
|
# CSRF passed with authenticated user
|
||||||
|
return (user, None)
|
||||||
|
else:
|
||||||
|
return (AnonymousUser(), None)
|
||||||
|
|
||||||
# CSRF passed with authenticated user
|
def check_csrf(self, request):
|
||||||
return (user, None)
|
|
||||||
|
|
||||||
def enforce_csrf(self, request):
|
|
||||||
"""
|
"""
|
||||||
Enforce CSRF validation for session based authentication.
|
return True if csrf is correct.
|
||||||
"""
|
"""
|
||||||
reason = CSRFCheck().process_view(request, None, (), {})
|
reason = CSRFCheck().process_view(request, None, (), {})
|
||||||
if reason:
|
if reason:
|
||||||
# CSRF failed, bail with explicit error message
|
request._csrf_failed_reason = reason
|
||||||
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
|
|
||||||
|
return not reason
|
||||||
|
|
||||||
|
|
||||||
class TokenAuthentication(BaseAuthentication):
|
class TokenAuthentication(BaseAuthentication):
|
||||||
|
|
|
@ -8,6 +8,7 @@ from django.contrib.auth.models import User
|
||||||
from django.shortcuts import redirect
|
from django.shortcuts import redirect
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from rest_framework import exceptions
|
||||||
from rest_framework.decorators import api_view
|
from rest_framework.decorators import api_view
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.test import (
|
from rest_framework.test import (
|
||||||
|
@ -23,6 +24,21 @@ def view(request):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@api_view(['GET', 'POST'])
|
||||||
|
def authenticated_view(request):
|
||||||
|
if not request.user.is_authenticated():
|
||||||
|
reason = getattr(request, '_csrf_failed_reason', None)
|
||||||
|
if reason:
|
||||||
|
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
|
||||||
|
else:
|
||||||
|
raise exceptions.PermissionDenied()
|
||||||
|
|
||||||
|
return Response({
|
||||||
|
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
|
||||||
|
'user': request.user.username
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
@api_view(['GET', 'POST'])
|
@api_view(['GET', 'POST'])
|
||||||
def session_view(request):
|
def session_view(request):
|
||||||
active_session = request.session.get('active_session', False)
|
active_session = request.session.get('active_session', False)
|
||||||
|
@ -39,6 +55,7 @@ def redirect_view(request):
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
url(r'^view/$', view),
|
url(r'^view/$', view),
|
||||||
|
url(r'^authenticated-view/$', authenticated_view),
|
||||||
url(r'^session-view/$', session_view),
|
url(r'^session-view/$', session_view),
|
||||||
url(r'^redirect-view/$', redirect_view),
|
url(r'^redirect-view/$', redirect_view),
|
||||||
]
|
]
|
||||||
|
@ -104,7 +121,7 @@ class TestAPITestClient(TestCase):
|
||||||
client = APIClient(enforce_csrf_checks=True)
|
client = APIClient(enforce_csrf_checks=True)
|
||||||
User.objects.create_user('example', 'example@example.com', 'password')
|
User.objects.create_user('example', 'example@example.com', 'password')
|
||||||
client.login(username='example', password='password')
|
client.login(username='example', password='password')
|
||||||
response = client.post('/view/')
|
response = client.post('/authenticated-view/')
|
||||||
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
|
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
|
||||||
self.assertEqual(response.status_code, 403)
|
self.assertEqual(response.status_code, 403)
|
||||||
self.assertEqual(response.data, expected)
|
self.assertEqual(response.data, expected)
|
||||||
|
@ -201,9 +218,9 @@ class TestAPIRequestFactory(TestCase):
|
||||||
"""
|
"""
|
||||||
user = User.objects.create_user('example', 'example@example.com', 'password')
|
user = User.objects.create_user('example', 'example@example.com', 'password')
|
||||||
factory = APIRequestFactory(enforce_csrf_checks=True)
|
factory = APIRequestFactory(enforce_csrf_checks=True)
|
||||||
request = factory.post('/view/')
|
request = factory.post('/authenticated-view/')
|
||||||
request.user = user
|
request.user = user
|
||||||
response = view(request)
|
response = authenticated_view(request)
|
||||||
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
|
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
|
||||||
self.assertEqual(response.status_code, 403)
|
self.assertEqual(response.status_code, 403)
|
||||||
self.assertEqual(response.data, expected)
|
self.assertEqual(response.data, expected)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user