This commit is contained in:
David Avsajanishvili 2013-06-04 14:20:21 -07:00
commit 7b91afb616
8 changed files with 192 additions and 63 deletions

View File

@ -2,6 +2,7 @@ from rest_framework.views import APIView
from rest_framework import status
from rest_framework import parsers
from rest_framework import renderers
from rest_framework import exceptions
from rest_framework.response import Response
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
@ -20,7 +21,7 @@ class ObtainAuthToken(APIView):
if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user'])
return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
raise exceptions.TokenAuthenticationError(serializer.errors)
obtain_auth_token = ObtainAuthToken.as_view()

View File

@ -11,49 +11,57 @@ from rest_framework import status
class APIException(Exception):
"""
Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.detail` properties.
Subclasses should provide `.status_code` and `.data` properties.
The `.data` is a dictionary that usually contains just one
field: "detail". However, some exception classes may override
it.
"""
pass
def __init__(self, detail=None):
self.data = {'detail': detail or self.default_detail}
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Malformed request.'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
class DeserializeError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, errors):
self.data = dict(errors)
class TokenAuthenticationError(DeserializeError):
"""Raised when incorrect data is posted during Token Authentication."""
# TODO: Change status code to HTTP_401
# TODO: Make data look like {'detail': 'Reason of failure'}
pass
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Authentication credentials were not provided.'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = 'You do not have permission to perform this action.'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = "Method '%s' not allowed."
def __init__(self, method, detail=None):
self.detail = (detail or self.default_detail) % method
self.data = {'detail': (detail or self.default_detail) % method}
class NotAcceptable(APIException):
@ -61,7 +69,9 @@ class NotAcceptable(APIException):
default_detail = "Could not satisfy the request's Accept header"
def __init__(self, detail=None, available_renderers=None):
self.detail = detail or self.default_detail
super(NotAcceptable, self).__init__(detail)
# TODO: self.available_renderers not used anywhere
# across the code
self.available_renderers = available_renderers
@ -70,7 +80,7 @@ class UnsupportedMediaType(APIException):
default_detail = "Unsupported media type '%s' in request."
def __init__(self, media_type, detail=None):
self.detail = (detail or self.default_detail) % media_type
self.data = {'detail': (detail or self.default_detail) % media_type}
class Throttled(APIException):
@ -83,9 +93,10 @@ class Throttled(APIException):
self.wait = wait and math.ceil(wait) or None
if wait is not None:
format = detail or self.default_detail + self.extra_detail
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
self.data = {'detail':
format % (self.wait, self.wait != 1 and 's' or '')}
else:
self.detail = detail or self.default_detail
self.data = {'detail': detail or self.default_detail}
class ConfigurationError(Exception):

View File

@ -0,0 +1,43 @@
"""Default handlers, configurable in settings"""
from rest_framework.response import Response
from rest_framework import exceptions
from rest_framework import status
from django.core.exceptions import PermissionDenied
from django.http import Http404
def handle_exception(view_instance, exc):
"""
Default exception handler for APIView.
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
if isinstance(exc, exceptions.Throttled):
# Throttle wait header
view_instance.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = view_instance.get_authenticate_header(
view_instance.request)
if auth_header:
view_instance.headers['WWW-Authenticate'] = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
if isinstance(exc, exceptions.APIException):
return Response(exc.data, status=exc.status_code,
exception=True)
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND,
exception=True)
elif isinstance(exc, PermissionDenied):
return Response({'detail': 'Permission denied'},
status=status.HTTP_403_FORBIDDEN,
exception=True)
raise

View File

@ -7,7 +7,7 @@ which allows mixin classes to be composed in interesting ways.
from __future__ import unicode_literals
from django.http import Http404
from rest_framework import status
from rest_framework import status, exceptions
from rest_framework.response import Response
from rest_framework.request import clone_request
import warnings
@ -55,7 +55,7 @@ class CreateModelMixin(object):
return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
raise exceptions.DeserializeError(serializer.errors)
def get_success_headers(self, data):
try:
@ -132,7 +132,7 @@ class UpdateModelMixin(object):
self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
raise exceptions.DeserializeError(serializer.errors)
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True

View File

@ -59,6 +59,9 @@ DEFAULTS = {
'rest_framework.pagination.PaginationSerializer',
'DEFAULT_FILTER_BACKENDS': (),
# Exception handling
'DEFAULT_EXCEPTION_HANDLER': 'rest_framework.handlers.handle_exception',
# Throttling
'DEFAULT_THROTTLE_RATES': {
'user': None,
@ -117,6 +120,7 @@ IMPORT_STRINGS = (
'FILTER_BACKEND',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
'DEFAULT_EXCEPTION_HANDLER',
)

View File

@ -158,6 +158,19 @@ class TestRootView(TestCase):
created = self.objects.get(id=4)
self.assertEqual(created.text, 'foobar')
def test_post_wrong_data(self):
"""
POST requests with wrong JSON data should raise HTTP 400
"""
content = {'id': 999, 'wrongtext': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
with self.assertNumQueries(0):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn('text', response.data)
self.assertEqual(response.data['text'], ['This field is required.'])
class TestInstanceView(TestCase):
def setUp(self):
@ -303,6 +316,35 @@ class TestInstanceView(TestCase):
updated = self.objects.get(id=1)
self.assertEqual(updated.text, 'foobar')
def test_put_wrong_data(self):
"""
PUT requests with wrong JSON data should raise HTTP 400
"""
content = {'id': 999, 'wrongtext': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn('text', response.data)
self.assertEqual(response.data['text'], ['This field is required.'])
def test_patch_wrong_data(self):
"""
PATCH requests with wrong JSON data should raise HTTP 400
"""
content = {'text': 'foobar' * 20} # too long
request = factory.patch('/1', json.dumps(content),
content_type='application/json')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn('text', response.data)
self.assertEqual(
response.data['text'],
['Ensure this value has at most 100 characters (it has 120).'])
def test_put_to_deleted_instance(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView should create an object

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import copy
import warnings
from django.test import TestCase
from django.test.client import RequestFactory
@ -10,6 +11,7 @@ from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
from rest_framework.exceptions import ParseError
factory = RequestFactory()
@ -19,6 +21,11 @@ class BasicView(APIView):
return Response({'method': 'GET'})
def post(self, request, *args, **kwargs):
if 'raise_400_error' in request.DATA:
raise ParseError('Bad request')
if 'return_400_error' in request.DATA:
return Response({'detail': 'Bad request'},
status=status.HTTP_400_BAD_REQUEST)
return Response({'method': 'POST', 'data': request.DATA})
@ -26,12 +33,13 @@ class BasicView(APIView):
def basic_view(request):
if request.method == 'GET':
return {'method': 'GET'}
elif request.method == 'POST':
return {'method': 'POST', 'data': request.DATA}
elif request.method == 'PUT':
return {'method': 'PUT', 'data': request.DATA}
elif request.method == 'PATCH':
return {'method': 'PATCH', 'data': request.DATA}
if request.method in ('POST', 'PUT', 'PATCH'):
if 'raise_400_error' in request.DATA:
raise ParseError('Bad request')
if 'return_400_error' in request.DATA:
return Response({'detail': 'Bad request'},
status=status.HTTP_400_BAD_REQUEST)
return {'method': request.method, 'data': request.DATA}
def sanitise_json_error(error_dict):
@ -73,6 +81,29 @@ class ClassBasedViewIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
def test_raise_400_error(self):
request = factory.post('/', '{"raise_400_error": true}',
content_type='application/json')
response = self.view(request)
expected = {
'detail': 'Bad request'
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
def test_return_400_error(self):
request = factory.post('/', '{"return_400_error": true}',
content_type='application/json')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
response = self.view(request)
self.assertEqual(len(w), 1)
expected = {
'detail': 'Bad request'
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
class FunctionBasedViewIntegrationTests(TestCase):
def setUp(self):
@ -101,3 +132,26 @@ class FunctionBasedViewIntegrationTests(TestCase):
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
def test_raise_400_error(self):
request = factory.post('/', '{"raise_400_error": true}',
content_type='application/json')
response = self.view(request)
expected = {
'detail': 'Bad request'
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
def test_return_400_error(self):
request = factory.post('/', '{"return_400_error": true}',
content_type='application/json')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
response = self.view(request)
self.assertEqual(len(w), 1)
expected = {
'detail': 'Bad request'
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)

View File

@ -3,8 +3,9 @@ Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
from django.http import Http404, HttpResponse
import warnings
from django.http import HttpResponse
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
@ -24,6 +25,7 @@ class APIView(View):
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
handle_exception = api_settings.DEFAULT_EXCEPTION_HANDLER
@classmethod
def as_view(cls, **initkwargs):
@ -263,39 +265,6 @@ class APIView(View):
return response
def handle_exception(self, exc):
"""
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
if isinstance(exc, exceptions.Throttled):
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
if auth_header:
self.headers['WWW-Authenticate'] = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail},
status=exc.status_code,
exception=True)
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND,
exception=True)
elif isinstance(exc, PermissionDenied):
return Response({'detail': 'Permission denied'},
status=status.HTTP_403_FORBIDDEN,
exception=True)
raise
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
@csrf_exempt
@ -322,6 +291,11 @@ class APIView(View):
response = handler(request, *args, **kwargs)
if response.status_code >= 400:
warnings.warn('Status code >= 400 returned. You should raise'
' `rest_framework.exceptions.APIException`'
' instead.', stacklevel=2)
except Exception as exc:
response = self.handle_exception(exc)