diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 7c03cb766..4f90d9528 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -2,6 +2,7 @@ from rest_framework.views import APIView from rest_framework import status from rest_framework import parsers from rest_framework import renderers +from rest_framework import exceptions from rest_framework.response import Response from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer @@ -20,7 +21,7 @@ class ObtainAuthToken(APIView): if serializer.is_valid(): token, created = Token.objects.get_or_create(user=serializer.object['user']) return Response({'token': token.key}) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + raise exceptions.TokenAuthenticationError(serializer.errors) obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 829f11382..50e2959e0 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -34,6 +34,13 @@ class DeserializeError(APIException): self.data = dict(errors) +class TokenAuthenticationError(DeserializeError): + """Raised when incorrect data is posted during Token Authentication.""" + # TODO: Change status code to HTTP_401 + # TODO: Make data look like {'detail': 'Reason of failure'} + pass + + class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Incorrect authentication credentials.' diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index 2767d24c8..48e22e797 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import copy +import warnings from django.test import TestCase from django.test.client import RequestFactory @@ -10,6 +11,7 @@ from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView +from rest_framework.exceptions import ParseError factory = RequestFactory() @@ -19,6 +21,11 @@ class BasicView(APIView): return Response({'method': 'GET'}) def post(self, request, *args, **kwargs): + if 'raise_400_error' in request.DATA: + raise ParseError('Bad request') + if 'return_400_error' in request.DATA: + return Response({'detail': 'Bad request'}, + status=status.HTTP_400_BAD_REQUEST) return Response({'method': 'POST', 'data': request.DATA}) @@ -26,12 +33,13 @@ class BasicView(APIView): def basic_view(request): if request.method == 'GET': return {'method': 'GET'} - elif request.method == 'POST': - return {'method': 'POST', 'data': request.DATA} - elif request.method == 'PUT': - return {'method': 'PUT', 'data': request.DATA} - elif request.method == 'PATCH': - return {'method': 'PATCH', 'data': request.DATA} + if request.method in ('POST', 'PUT', 'PATCH'): + if 'raise_400_error' in request.DATA: + raise ParseError('Bad request') + if 'return_400_error' in request.DATA: + return Response({'detail': 'Bad request'}, + status=status.HTTP_400_BAD_REQUEST) + return {'method': request.method, 'data': request.DATA} def sanitise_json_error(error_dict): @@ -73,6 +81,29 @@ class ClassBasedViewIntegrationTests(TestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(sanitise_json_error(response.data), expected) + def test_raise_400_error(self): + request = factory.post('/', '{"raise_400_error": true}', + content_type='application/json') + response = self.view(request) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) + + def test_return_400_error(self): + request = factory.post('/', '{"return_400_error": true}', + content_type='application/json') + with warnings.catch_warnings(True) as w: + warnings.simplefilter('always') + response = self.view(request) + self.assertEqual(len(w), 1) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) + class FunctionBasedViewIntegrationTests(TestCase): def setUp(self): @@ -101,3 +132,26 @@ class FunctionBasedViewIntegrationTests(TestCase): } self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(sanitise_json_error(response.data), expected) + + def test_raise_400_error(self): + request = factory.post('/', '{"raise_400_error": true}', + content_type='application/json') + response = self.view(request) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) + + def test_return_400_error(self): + request = factory.post('/', '{"return_400_error": true}', + content_type='application/json') + with warnings.catch_warnings(True) as w: + warnings.simplefilter('always') + response = self.view(request) + self.assertEqual(len(w), 1) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) diff --git a/rest_framework/views.py b/rest_framework/views.py index 193d4fb27..9f36597c9 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,6 +3,8 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals +import warnings + from django.core.exceptions import PermissionDenied from django.http import Http404, HttpResponse from django.utils.datastructures import SortedDict @@ -321,6 +323,11 @@ class APIView(View): response = handler(request, *args, **kwargs) + if response.status_code >= 400: + warnings.warn('Status code >= 400 returned. You should raise' + ' `rest_framework.exceptions.APIException`' + ' instead.', stacklevel=2) + except Exception as exc: response = self.handle_exception(exc)