Use 401 with Authenticate header, when appropriate

This commit is contained in:
Rob Romano 2012-11-13 23:14:23 -08:00
parent 8953a60196
commit 40442f6cc1
4 changed files with 32 additions and 12 deletions

View File

@ -86,6 +86,7 @@ class Request(object):
self._method = Empty
self._content_type = Empty
self._stream = Empty
self._authenticated = False
if self.parser_context is None:
self.parser_context = {}
@ -288,6 +289,7 @@ class Request(object):
for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None:
self._authenticated = True
return user_auth_tuple
return self._not_authenticated()

View File

@ -25,8 +25,13 @@ class MockView(APIView):
MockView.authentication_classes += (TokenAuthentication,)
class AdminMockView(MockView):
permission_classes = (permissions.IsAdminUser,)
urlpatterns = patterns('',
(r'^$', MockView.as_view()),
(r'^admin/$', AdminMockView.as_view()),
)
@ -54,13 +59,14 @@ class BasicAuthTests(TestCase):
self.assertEqual(response.status_code, 200)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
"""Ensure POSTing form over basic auth without correct credentials is denied with 401"""
response = self.csrf_client.post('/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
def test_post_json_no_permissions(self):
"""Ensure POSTing json over basic auth to restricted endpoint is denied with 403"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
response = self.csrf_client.post('/admin/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 403)
@ -105,10 +111,10 @@ class SessionAuthTests(TestCase):
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
Ensure POSTing form over session authentication without logged is denied with 401
"""
response = self.csrf_client.post('/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
class TokenAuthTests(TestCase):
@ -138,13 +144,19 @@ class TokenAuthTests(TestCase):
self.assertEqual(response.status_code, 200)
def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
"""Ensure POSTing form over token auth without correct credentials is denied with 401"""
response = self.csrf_client.post('/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
"""Ensure POSTing json over token auth without correct credentials is denied with 401"""
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, 401)
def test_post_json_no_permissions(self):
"""Ensure POSTing json over token auth to restricted endpoint is denied with 403"""
auth = "Token " + self.key
response = self.csrf_client.post('/admin/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 403)
def test_token_has_auto_assigned_key_if_none_provided(self):

View File

@ -109,7 +109,7 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/')
response = view(request)
self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEquals(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle):

View File

@ -13,7 +13,7 @@ from rest_framework.compat import View, apply_markdown
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
from rest_framework.exceptions import NotAuthenticated
def _remove_trailing_string(content, trailing):
"""
@ -148,6 +148,8 @@ class APIView(View):
"""
If request is not permitted, determine what kind of exception to raise.
"""
if not request._authenticated:
raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied()
def throttled(self, request, wait):
@ -331,6 +333,10 @@ class APIView(View):
return Response({'detail': 'Permission denied'},
status=status.HTTP_403_FORBIDDEN,
exception=True)
elif isinstance(exc, NotAuthenticated):
return Response({'detail': self.get_authenticators()[0].authenticate_header()},
status=status.HTTP_401_FORBIDDEN,
exception=True)
raise
# Note: session based authentication is explicitly CSRF validated,