From 40442f6cc1bf800c050ce199c6674dc0ef7f628c Mon Sep 17 00:00:00 2001 From: Rob Romano Date: Tue, 13 Nov 2012 23:14:23 -0800 Subject: [PATCH] Use 401 with Authenticate header, when appropriate --- rest_framework/request.py | 2 ++ rest_framework/tests/authentication.py | 32 ++++++++++++++++++-------- rest_framework/tests/decorators.py | 2 +- rest_framework/views.py | 8 ++++++- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/rest_framework/request.py b/rest_framework/request.py index a1827ba48..8b714380a 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -86,6 +86,7 @@ class Request(object): self._method = Empty self._content_type = Empty self._stream = Empty + self._authenticated = False if self.parser_context is None: self.parser_context = {} @@ -288,6 +289,7 @@ class Request(object): for authenticator in self.authenticators: user_auth_tuple = authenticator.authenticate(self) if not user_auth_tuple is None: + self._authenticated = True return user_auth_tuple return self._not_authenticated() diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 8ab4c4e40..c8b47f56e 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -25,8 +25,13 @@ class MockView(APIView): MockView.authentication_classes += (TokenAuthentication,) +class AdminMockView(MockView): + permission_classes = (permissions.IsAdminUser,) + + urlpatterns = patterns('', (r'^$', MockView.as_view()), + (r'^admin/$', AdminMockView.as_view()), ) @@ -54,13 +59,14 @@ class BasicAuthTests(TestCase): self.assertEqual(response.status_code, 200) def test_post_form_failing_basic_auth(self): - """Ensure POSTing form over basic auth without correct credentials fails""" + """Ensure POSTing form over basic auth without correct credentials is denied with 401""" response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) - def test_post_json_failing_basic_auth(self): - """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') + def test_post_json_no_permissions(self): + """Ensure POSTing json over basic auth to restricted endpoint is denied with 403""" + auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() + response = self.csrf_client.post('/admin/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 403) @@ -105,10 +111,10 @@ class SessionAuthTests(TestCase): def test_post_form_session_auth_failing(self): """ - Ensure POSTing form over session authentication without logged in user fails. + Ensure POSTing form over session authentication without logged is denied with 401 """ response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) class TokenAuthTests(TestCase): @@ -138,13 +144,19 @@ class TokenAuthTests(TestCase): self.assertEqual(response.status_code, 200) def test_post_form_failing_token_auth(self): - """Ensure POSTing form over token auth without correct credentials fails""" + """Ensure POSTing form over token auth without correct credentials is denied with 401""" response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) def test_post_json_failing_token_auth(self): - """Ensure POSTing json over token auth without correct credentials fails""" + """Ensure POSTing json over token auth without correct credentials is denied with 401""" response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 401) + + def test_post_json_no_permissions(self): + """Ensure POSTing json over token auth to restricted endpoint is denied with 403""" + auth = "Token " + self.key + response = self.csrf_client.post('/admin/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 403) def test_token_has_auto_assigned_key_if_none_provided(self): diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 41864d71e..f4a7a20c2 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -109,7 +109,7 @@ class DecoratorTestCase(TestCase): request = self.factory.get('/') response = view(request) - self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEquals(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_throttle_classes(self): class OncePerDayUserThrottle(UserRateThrottle): diff --git a/rest_framework/views.py b/rest_framework/views.py index 1afbd6974..3f5ec2caf 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -13,7 +13,7 @@ from rest_framework.compat import View, apply_markdown from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings - +from rest_framework.exceptions import NotAuthenticated def _remove_trailing_string(content, trailing): """ @@ -148,6 +148,8 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ + if not request._authenticated: + raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() def throttled(self, request, wait): @@ -331,6 +333,10 @@ class APIView(View): return Response({'detail': 'Permission denied'}, status=status.HTTP_403_FORBIDDEN, exception=True) + elif isinstance(exc, NotAuthenticated): + return Response({'detail': self.get_authenticators()[0].authenticate_header()}, + status=status.HTTP_401_FORBIDDEN, + exception=True) raise # Note: session based authentication is explicitly CSRF validated,