From 9db71a7a486e3c9fc329718e23af1590fcb92b50 Mon Sep 17 00:00:00 2001 From: David Avsajanishvili Date: Tue, 4 Jun 2013 09:55:39 +0400 Subject: [PATCH 1/5] Add tests for wrong requests --- rest_framework/tests/test_generics.py | 42 +++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 37734195a..70da08704 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -158,6 +158,19 @@ class TestRootView(TestCase): created = self.objects.get(id=4) self.assertEqual(created.text, 'foobar') + def test_post_wrong_data(self): + """ + POST requests with wrong JSON data should raise HTTP 400 + """ + content = {'id': 999, 'wrongtext': 'foobar'} + request = factory.post('/', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(0): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('text', response.data) + self.assertEqual(response.data['text'], ['This field is required.']) + class TestInstanceView(TestCase): def setUp(self): @@ -303,6 +316,35 @@ class TestInstanceView(TestCase): updated = self.objects.get(id=1) self.assertEqual(updated.text, 'foobar') + def test_put_wrong_data(self): + """ + PUT requests with wrong JSON data should raise HTTP 400 + """ + content = {'id': 999, 'wrongtext': 'foobar'} + request = factory.put('/1', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('text', response.data) + self.assertEqual(response.data['text'], ['This field is required.']) + + def test_patch_wrong_data(self): + """ + PATCH requests with wrong JSON data should raise HTTP 400 + """ + content = {'text': 'foobar' * 20} # too long + request = factory.patch('/1', json.dumps(content), + content_type='application/json') + + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('text', response.data) + self.assertEqual( + response.data['text'], + [u'Ensure this value has at most 100 characters (it has 120).']) + def test_put_to_deleted_instance(self): """ PUT requests to RetrieveUpdateDestroyAPIView should create an object From 39477b219a478d3deee8f30bee8f6627b620a24c Mon Sep 17 00:00:00 2001 From: David Avsajanishvili Date: Tue, 4 Jun 2013 10:08:19 +0400 Subject: [PATCH 2/5] Refactor exceptions, and always return error response by raising exceptoin --- rest_framework/exceptions.py | 40 ++++++++++++++++++++---------------- rest_framework/mixins.py | 6 +++--- rest_framework/views.py | 3 +-- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 0c96ecdd5..829f11382 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -11,49 +11,50 @@ from rest_framework import status class APIException(Exception): """ Base class for REST framework exceptions. - Subclasses should provide `.status_code` and `.detail` properties. + Subclasses should provide `.status_code` and `.data` properties. + + The `.data` is a dictionary that usually contains just one + field: "detail". However, some exception classes may override + it. """ - pass + + def __init__(self, detail=None): + self.data = {'detail': detail or self.default_detail} class ParseError(APIException): status_code = status.HTTP_400_BAD_REQUEST default_detail = 'Malformed request.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail + +class DeserializeError(APIException): + status_code = status.HTTP_400_BAD_REQUEST + + def __init__(self, errors): + self.data = dict(errors) class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Incorrect authentication credentials.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail - class NotAuthenticated(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Authentication credentials were not provided.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail - class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN default_detail = 'You do not have permission to perform this action.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail - class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED default_detail = "Method '%s' not allowed." def __init__(self, method, detail=None): - self.detail = (detail or self.default_detail) % method + self.data = {'detail': (detail or self.default_detail) % method} class NotAcceptable(APIException): @@ -61,7 +62,9 @@ class NotAcceptable(APIException): default_detail = "Could not satisfy the request's Accept header" def __init__(self, detail=None, available_renderers=None): - self.detail = detail or self.default_detail + super(NotAcceptable, self).__init__(detail) + # TODO: self.available_renderers not used anywhere + # across the code self.available_renderers = available_renderers @@ -70,7 +73,7 @@ class UnsupportedMediaType(APIException): default_detail = "Unsupported media type '%s' in request." def __init__(self, media_type, detail=None): - self.detail = (detail or self.default_detail) % media_type + self.data = {'detail': (detail or self.default_detail) % media_type} class Throttled(APIException): @@ -83,9 +86,10 @@ class Throttled(APIException): self.wait = wait and math.ceil(wait) or None if wait is not None: format = detail or self.default_detail + self.extra_detail - self.detail = format % (self.wait, self.wait != 1 and 's' or '') + self.data = {'detail': + format % (self.wait, self.wait != 1 and 's' or '')} else: - self.detail = detail or self.default_detail + self.data = {'detail': detail or self.default_detail} class ConfigurationError(Exception): diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index f11def6d4..b190e5728 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -7,7 +7,7 @@ which allows mixin classes to be composed in interesting ways. from __future__ import unicode_literals from django.http import Http404 -from rest_framework import status +from rest_framework import status, exceptions from rest_framework.response import Response from rest_framework.request import clone_request import warnings @@ -55,7 +55,7 @@ class CreateModelMixin(object): return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + raise exceptions.DeserializeError(serializer.errors) def get_success_headers(self, data): try: @@ -132,7 +132,7 @@ class UpdateModelMixin(object): self.post_save(self.object, created=created) return Response(serializer.data, status=success_status_code) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + raise exceptions.DeserializeError(serializer.errors) def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True diff --git a/rest_framework/views.py b/rest_framework/views.py index e1b6705b6..193d4fb27 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -283,8 +283,7 @@ class APIView(View): exc.status_code = status.HTTP_403_FORBIDDEN if isinstance(exc, exceptions.APIException): - return Response({'detail': exc.detail}, - status=exc.status_code, + return Response(exc.data, status=exc.status_code, exception=True) elif isinstance(exc, Http404): return Response({'detail': 'Not found'}, From c8bed3562145a463b2686bbf9d77fca11f671972 Mon Sep 17 00:00:00 2001 From: David Avsajanishvili Date: Tue, 4 Jun 2013 11:25:26 +0400 Subject: [PATCH 3/5] Depreciate returning status_code >= 400 and test it --- rest_framework/authtoken/views.py | 3 +- rest_framework/exceptions.py | 7 ++++ rest_framework/tests/test_views.py | 66 +++++++++++++++++++++++++++--- rest_framework/views.py | 7 ++++ 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 7c03cb766..4f90d9528 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -2,6 +2,7 @@ from rest_framework.views import APIView from rest_framework import status from rest_framework import parsers from rest_framework import renderers +from rest_framework import exceptions from rest_framework.response import Response from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer @@ -20,7 +21,7 @@ class ObtainAuthToken(APIView): if serializer.is_valid(): token, created = Token.objects.get_or_create(user=serializer.object['user']) return Response({'token': token.key}) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + raise exceptions.TokenAuthenticationError(serializer.errors) obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 829f11382..50e2959e0 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -34,6 +34,13 @@ class DeserializeError(APIException): self.data = dict(errors) +class TokenAuthenticationError(DeserializeError): + """Raised when incorrect data is posted during Token Authentication.""" + # TODO: Change status code to HTTP_401 + # TODO: Make data look like {'detail': 'Reason of failure'} + pass + + class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Incorrect authentication credentials.' diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index 2767d24c8..48e22e797 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import copy +import warnings from django.test import TestCase from django.test.client import RequestFactory @@ -10,6 +11,7 @@ from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView +from rest_framework.exceptions import ParseError factory = RequestFactory() @@ -19,6 +21,11 @@ class BasicView(APIView): return Response({'method': 'GET'}) def post(self, request, *args, **kwargs): + if 'raise_400_error' in request.DATA: + raise ParseError('Bad request') + if 'return_400_error' in request.DATA: + return Response({'detail': 'Bad request'}, + status=status.HTTP_400_BAD_REQUEST) return Response({'method': 'POST', 'data': request.DATA}) @@ -26,12 +33,13 @@ class BasicView(APIView): def basic_view(request): if request.method == 'GET': return {'method': 'GET'} - elif request.method == 'POST': - return {'method': 'POST', 'data': request.DATA} - elif request.method == 'PUT': - return {'method': 'PUT', 'data': request.DATA} - elif request.method == 'PATCH': - return {'method': 'PATCH', 'data': request.DATA} + if request.method in ('POST', 'PUT', 'PATCH'): + if 'raise_400_error' in request.DATA: + raise ParseError('Bad request') + if 'return_400_error' in request.DATA: + return Response({'detail': 'Bad request'}, + status=status.HTTP_400_BAD_REQUEST) + return {'method': request.method, 'data': request.DATA} def sanitise_json_error(error_dict): @@ -73,6 +81,29 @@ class ClassBasedViewIntegrationTests(TestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(sanitise_json_error(response.data), expected) + def test_raise_400_error(self): + request = factory.post('/', '{"raise_400_error": true}', + content_type='application/json') + response = self.view(request) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) + + def test_return_400_error(self): + request = factory.post('/', '{"return_400_error": true}', + content_type='application/json') + with warnings.catch_warnings(True) as w: + warnings.simplefilter('always') + response = self.view(request) + self.assertEqual(len(w), 1) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) + class FunctionBasedViewIntegrationTests(TestCase): def setUp(self): @@ -101,3 +132,26 @@ class FunctionBasedViewIntegrationTests(TestCase): } self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(sanitise_json_error(response.data), expected) + + def test_raise_400_error(self): + request = factory.post('/', '{"raise_400_error": true}', + content_type='application/json') + response = self.view(request) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) + + def test_return_400_error(self): + request = factory.post('/', '{"return_400_error": true}', + content_type='application/json') + with warnings.catch_warnings(True) as w: + warnings.simplefilter('always') + response = self.view(request) + self.assertEqual(len(w), 1) + expected = { + 'detail': 'Bad request' + } + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) diff --git a/rest_framework/views.py b/rest_framework/views.py index 193d4fb27..9f36597c9 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,6 +3,8 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals +import warnings + from django.core.exceptions import PermissionDenied from django.http import Http404, HttpResponse from django.utils.datastructures import SortedDict @@ -321,6 +323,11 @@ class APIView(View): response = handler(request, *args, **kwargs) + if response.status_code >= 400: + warnings.warn('Status code >= 400 returned. You should raise' + ' `rest_framework.exceptions.APIException`' + ' instead.', stacklevel=2) + except Exception as exc: response = self.handle_exception(exc) From ae9ec60ce6291333483893dcb878d1c839812315 Mon Sep 17 00:00:00 2001 From: David Avsajanishvili Date: Tue, 4 Jun 2013 13:13:04 +0400 Subject: [PATCH 4/5] Implement customizable exception handler --- rest_framework/handlers.py | 43 +++++++++++++++++++++++++++ rest_framework/settings.py | 4 +++ rest_framework/tests/test_generics.py | 2 +- rest_framework/views.py | 36 ++-------------------- 4 files changed, 50 insertions(+), 35 deletions(-) create mode 100644 rest_framework/handlers.py diff --git a/rest_framework/handlers.py b/rest_framework/handlers.py new file mode 100644 index 000000000..3dd54510c --- /dev/null +++ b/rest_framework/handlers.py @@ -0,0 +1,43 @@ +"""Default handlers, configurable in settings""" + +from rest_framework.response import Response +from rest_framework import exceptions +from rest_framework import status +from django.core.exceptions import PermissionDenied +from django.http import Http404 + + +def handle_exception(view_instance, exc): + """ + Default exception handler for APIView. + + Handle any exception that occurs, by returning an appropriate response, + or re-raising the error. + """ + if isinstance(exc, exceptions.Throttled): + # Throttle wait header + view_instance.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + + if isinstance(exc, (exceptions.NotAuthenticated, + exceptions.AuthenticationFailed)): + # WWW-Authenticate header for 401 responses, else coerce to 403 + auth_header = view_instance.get_authenticate_header( + view_instance.request) + + if auth_header: + view_instance.headers['WWW-Authenticate'] = auth_header + else: + exc.status_code = status.HTTP_403_FORBIDDEN + + if isinstance(exc, exceptions.APIException): + return Response(exc.data, status=exc.status_code, + exception=True) + elif isinstance(exc, Http404): + return Response({'detail': 'Not found'}, + status=status.HTTP_404_NOT_FOUND, + exception=True) + elif isinstance(exc, PermissionDenied): + return Response({'detail': 'Permission denied'}, + status=status.HTTP_403_FORBIDDEN, + exception=True) + raise diff --git a/rest_framework/settings.py b/rest_framework/settings.py index beb511aca..769ff2178 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -59,6 +59,9 @@ DEFAULTS = { 'rest_framework.pagination.PaginationSerializer', 'DEFAULT_FILTER_BACKENDS': (), + # Exception handling + 'DEFAULT_EXCEPTION_HANDLER': 'rest_framework.handlers.handle_exception', + # Throttling 'DEFAULT_THROTTLE_RATES': { 'user': None, @@ -117,6 +120,7 @@ IMPORT_STRINGS = ( 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', + 'DEFAULT_EXCEPTION_HANDLER', ) diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 70da08704..7dee3b78a 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -343,7 +343,7 @@ class TestInstanceView(TestCase): self.assertIn('text', response.data) self.assertEqual( response.data['text'], - [u'Ensure this value has at most 100 characters (it has 120).']) + ['Ensure this value has at most 100 characters (it has 120).']) def test_put_to_deleted_instance(self): """ diff --git a/rest_framework/views.py b/rest_framework/views.py index 9f36597c9..1304cebe7 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -5,8 +5,7 @@ from __future__ import unicode_literals import warnings -from django.core.exceptions import PermissionDenied -from django.http import Http404, HttpResponse +from django.http import HttpResponse from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions @@ -26,6 +25,7 @@ class APIView(View): throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS + handle_exception = api_settings.DEFAULT_EXCEPTION_HANDLER @classmethod def as_view(cls, **initkwargs): @@ -265,38 +265,6 @@ class APIView(View): return response - def handle_exception(self, exc): - """ - Handle any exception that occurs, by returning an appropriate response, - or re-raising the error. - """ - if isinstance(exc, exceptions.Throttled): - # Throttle wait header - self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait - - if isinstance(exc, (exceptions.NotAuthenticated, - exceptions.AuthenticationFailed)): - # WWW-Authenticate header for 401 responses, else coerce to 403 - auth_header = self.get_authenticate_header(self.request) - - if auth_header: - self.headers['WWW-Authenticate'] = auth_header - else: - exc.status_code = status.HTTP_403_FORBIDDEN - - if isinstance(exc, exceptions.APIException): - return Response(exc.data, status=exc.status_code, - exception=True) - elif isinstance(exc, Http404): - return Response({'detail': 'Not found'}, - status=status.HTTP_404_NOT_FOUND, - exception=True) - elif isinstance(exc, PermissionDenied): - return Response({'detail': 'Permission denied'}, - status=status.HTTP_403_FORBIDDEN, - exception=True) - raise - # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. @csrf_exempt From 74643b08b67c3ae79fc33c20234a951f1993d1f0 Mon Sep 17 00:00:00 2001 From: David Avsajanishvili Date: Tue, 4 Jun 2013 13:42:31 +0400 Subject: [PATCH 5/5] Fix catch_warnings for Python 3 --- rest_framework/tests/test_views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index 48e22e797..095788b42 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -94,7 +94,7 @@ class ClassBasedViewIntegrationTests(TestCase): def test_return_400_error(self): request = factory.post('/', '{"return_400_error": true}', content_type='application/json') - with warnings.catch_warnings(True) as w: + with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') response = self.view(request) self.assertEqual(len(w), 1) @@ -146,7 +146,7 @@ class FunctionBasedViewIntegrationTests(TestCase): def test_return_400_error(self): request = factory.post('/', '{"return_400_error": true}', content_type='application/json') - with warnings.catch_warnings(True) as w: + with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') response = self.view(request) self.assertEqual(len(w), 1)