mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-30 18:09:59 +03:00
Merge 74643b08b6
into f1251e8c58
This commit is contained in:
commit
7b91afb616
|
@ -2,6 +2,7 @@ from rest_framework.views import APIView
|
|||
from rest_framework import status
|
||||
from rest_framework import parsers
|
||||
from rest_framework import renderers
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||
|
@ -20,7 +21,7 @@ class ObtainAuthToken(APIView):
|
|||
if serializer.is_valid():
|
||||
token, created = Token.objects.get_or_create(user=serializer.object['user'])
|
||||
return Response({'token': token.key})
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
raise exceptions.TokenAuthenticationError(serializer.errors)
|
||||
|
||||
|
||||
obtain_auth_token = ObtainAuthToken.as_view()
|
||||
|
|
|
@ -11,49 +11,57 @@ from rest_framework import status
|
|||
class APIException(Exception):
|
||||
"""
|
||||
Base class for REST framework exceptions.
|
||||
Subclasses should provide `.status_code` and `.detail` properties.
|
||||
Subclasses should provide `.status_code` and `.data` properties.
|
||||
|
||||
The `.data` is a dictionary that usually contains just one
|
||||
field: "detail". However, some exception classes may override
|
||||
it.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init__(self, detail=None):
|
||||
self.data = {'detail': detail or self.default_detail}
|
||||
|
||||
|
||||
class ParseError(APIException):
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
default_detail = 'Malformed request.'
|
||||
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail or self.default_detail
|
||||
|
||||
class DeserializeError(APIException):
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def __init__(self, errors):
|
||||
self.data = dict(errors)
|
||||
|
||||
|
||||
class TokenAuthenticationError(DeserializeError):
|
||||
"""Raised when incorrect data is posted during Token Authentication."""
|
||||
# TODO: Change status code to HTTP_401
|
||||
# TODO: Make data look like {'detail': 'Reason of failure'}
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationFailed(APIException):
|
||||
status_code = status.HTTP_401_UNAUTHORIZED
|
||||
default_detail = 'Incorrect authentication credentials.'
|
||||
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail or self.default_detail
|
||||
|
||||
|
||||
class NotAuthenticated(APIException):
|
||||
status_code = status.HTTP_401_UNAUTHORIZED
|
||||
default_detail = 'Authentication credentials were not provided.'
|
||||
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail or self.default_detail
|
||||
|
||||
|
||||
class PermissionDenied(APIException):
|
||||
status_code = status.HTTP_403_FORBIDDEN
|
||||
default_detail = 'You do not have permission to perform this action.'
|
||||
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail or self.default_detail
|
||||
|
||||
|
||||
class MethodNotAllowed(APIException):
|
||||
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
default_detail = "Method '%s' not allowed."
|
||||
|
||||
def __init__(self, method, detail=None):
|
||||
self.detail = (detail or self.default_detail) % method
|
||||
self.data = {'detail': (detail or self.default_detail) % method}
|
||||
|
||||
|
||||
class NotAcceptable(APIException):
|
||||
|
@ -61,7 +69,9 @@ class NotAcceptable(APIException):
|
|||
default_detail = "Could not satisfy the request's Accept header"
|
||||
|
||||
def __init__(self, detail=None, available_renderers=None):
|
||||
self.detail = detail or self.default_detail
|
||||
super(NotAcceptable, self).__init__(detail)
|
||||
# TODO: self.available_renderers not used anywhere
|
||||
# across the code
|
||||
self.available_renderers = available_renderers
|
||||
|
||||
|
||||
|
@ -70,7 +80,7 @@ class UnsupportedMediaType(APIException):
|
|||
default_detail = "Unsupported media type '%s' in request."
|
||||
|
||||
def __init__(self, media_type, detail=None):
|
||||
self.detail = (detail or self.default_detail) % media_type
|
||||
self.data = {'detail': (detail or self.default_detail) % media_type}
|
||||
|
||||
|
||||
class Throttled(APIException):
|
||||
|
@ -83,9 +93,10 @@ class Throttled(APIException):
|
|||
self.wait = wait and math.ceil(wait) or None
|
||||
if wait is not None:
|
||||
format = detail or self.default_detail + self.extra_detail
|
||||
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
|
||||
self.data = {'detail':
|
||||
format % (self.wait, self.wait != 1 and 's' or '')}
|
||||
else:
|
||||
self.detail = detail or self.default_detail
|
||||
self.data = {'detail': detail or self.default_detail}
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
|
|
43
rest_framework/handlers.py
Normal file
43
rest_framework/handlers.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
"""Default handlers, configurable in settings"""
|
||||
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import exceptions
|
||||
from rest_framework import status
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
|
||||
|
||||
def handle_exception(view_instance, exc):
|
||||
"""
|
||||
Default exception handler for APIView.
|
||||
|
||||
Handle any exception that occurs, by returning an appropriate response,
|
||||
or re-raising the error.
|
||||
"""
|
||||
if isinstance(exc, exceptions.Throttled):
|
||||
# Throttle wait header
|
||||
view_instance.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
|
||||
|
||||
if isinstance(exc, (exceptions.NotAuthenticated,
|
||||
exceptions.AuthenticationFailed)):
|
||||
# WWW-Authenticate header for 401 responses, else coerce to 403
|
||||
auth_header = view_instance.get_authenticate_header(
|
||||
view_instance.request)
|
||||
|
||||
if auth_header:
|
||||
view_instance.headers['WWW-Authenticate'] = auth_header
|
||||
else:
|
||||
exc.status_code = status.HTTP_403_FORBIDDEN
|
||||
|
||||
if isinstance(exc, exceptions.APIException):
|
||||
return Response(exc.data, status=exc.status_code,
|
||||
exception=True)
|
||||
elif isinstance(exc, Http404):
|
||||
return Response({'detail': 'Not found'},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
exception=True)
|
||||
elif isinstance(exc, PermissionDenied):
|
||||
return Response({'detail': 'Permission denied'},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
exception=True)
|
||||
raise
|
|
@ -7,7 +7,7 @@ which allows mixin classes to be composed in interesting ways.
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.http import Http404
|
||||
from rest_framework import status
|
||||
from rest_framework import status, exceptions
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.request import clone_request
|
||||
import warnings
|
||||
|
@ -55,7 +55,7 @@ class CreateModelMixin(object):
|
|||
return Response(serializer.data, status=status.HTTP_201_CREATED,
|
||||
headers=headers)
|
||||
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
raise exceptions.DeserializeError(serializer.errors)
|
||||
|
||||
def get_success_headers(self, data):
|
||||
try:
|
||||
|
@ -132,7 +132,7 @@ class UpdateModelMixin(object):
|
|||
self.post_save(self.object, created=created)
|
||||
return Response(serializer.data, status=success_status_code)
|
||||
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
raise exceptions.DeserializeError(serializer.errors)
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
kwargs['partial'] = True
|
||||
|
|
|
@ -59,6 +59,9 @@ DEFAULTS = {
|
|||
'rest_framework.pagination.PaginationSerializer',
|
||||
'DEFAULT_FILTER_BACKENDS': (),
|
||||
|
||||
# Exception handling
|
||||
'DEFAULT_EXCEPTION_HANDLER': 'rest_framework.handlers.handle_exception',
|
||||
|
||||
# Throttling
|
||||
'DEFAULT_THROTTLE_RATES': {
|
||||
'user': None,
|
||||
|
@ -117,6 +120,7 @@ IMPORT_STRINGS = (
|
|||
'FILTER_BACKEND',
|
||||
'UNAUTHENTICATED_USER',
|
||||
'UNAUTHENTICATED_TOKEN',
|
||||
'DEFAULT_EXCEPTION_HANDLER',
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -158,6 +158,19 @@ class TestRootView(TestCase):
|
|||
created = self.objects.get(id=4)
|
||||
self.assertEqual(created.text, 'foobar')
|
||||
|
||||
def test_post_wrong_data(self):
|
||||
"""
|
||||
POST requests with wrong JSON data should raise HTTP 400
|
||||
"""
|
||||
content = {'id': 999, 'wrongtext': 'foobar'}
|
||||
request = factory.post('/', json.dumps(content),
|
||||
content_type='application/json')
|
||||
with self.assertNumQueries(0):
|
||||
response = self.view(request, pk=1).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn('text', response.data)
|
||||
self.assertEqual(response.data['text'], ['This field is required.'])
|
||||
|
||||
|
||||
class TestInstanceView(TestCase):
|
||||
def setUp(self):
|
||||
|
@ -303,6 +316,35 @@ class TestInstanceView(TestCase):
|
|||
updated = self.objects.get(id=1)
|
||||
self.assertEqual(updated.text, 'foobar')
|
||||
|
||||
def test_put_wrong_data(self):
|
||||
"""
|
||||
PUT requests with wrong JSON data should raise HTTP 400
|
||||
"""
|
||||
content = {'id': 999, 'wrongtext': 'foobar'}
|
||||
request = factory.put('/1', json.dumps(content),
|
||||
content_type='application/json')
|
||||
with self.assertNumQueries(1):
|
||||
response = self.view(request, pk=1).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn('text', response.data)
|
||||
self.assertEqual(response.data['text'], ['This field is required.'])
|
||||
|
||||
def test_patch_wrong_data(self):
|
||||
"""
|
||||
PATCH requests with wrong JSON data should raise HTTP 400
|
||||
"""
|
||||
content = {'text': 'foobar' * 20} # too long
|
||||
request = factory.patch('/1', json.dumps(content),
|
||||
content_type='application/json')
|
||||
|
||||
with self.assertNumQueries(1):
|
||||
response = self.view(request, pk=1).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn('text', response.data)
|
||||
self.assertEqual(
|
||||
response.data['text'],
|
||||
['Ensure this value has at most 100 characters (it has 120).'])
|
||||
|
||||
def test_put_to_deleted_instance(self):
|
||||
"""
|
||||
PUT requests to RetrieveUpdateDestroyAPIView should create an object
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test.client import RequestFactory
|
||||
|
@ -10,6 +11,7 @@ from rest_framework.decorators import api_view
|
|||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.exceptions import ParseError
|
||||
|
||||
factory = RequestFactory()
|
||||
|
||||
|
@ -19,6 +21,11 @@ class BasicView(APIView):
|
|||
return Response({'method': 'GET'})
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
if 'raise_400_error' in request.DATA:
|
||||
raise ParseError('Bad request')
|
||||
if 'return_400_error' in request.DATA:
|
||||
return Response({'detail': 'Bad request'},
|
||||
status=status.HTTP_400_BAD_REQUEST)
|
||||
return Response({'method': 'POST', 'data': request.DATA})
|
||||
|
||||
|
||||
|
@ -26,12 +33,13 @@ class BasicView(APIView):
|
|||
def basic_view(request):
|
||||
if request.method == 'GET':
|
||||
return {'method': 'GET'}
|
||||
elif request.method == 'POST':
|
||||
return {'method': 'POST', 'data': request.DATA}
|
||||
elif request.method == 'PUT':
|
||||
return {'method': 'PUT', 'data': request.DATA}
|
||||
elif request.method == 'PATCH':
|
||||
return {'method': 'PATCH', 'data': request.DATA}
|
||||
if request.method in ('POST', 'PUT', 'PATCH'):
|
||||
if 'raise_400_error' in request.DATA:
|
||||
raise ParseError('Bad request')
|
||||
if 'return_400_error' in request.DATA:
|
||||
return Response({'detail': 'Bad request'},
|
||||
status=status.HTTP_400_BAD_REQUEST)
|
||||
return {'method': request.method, 'data': request.DATA}
|
||||
|
||||
|
||||
def sanitise_json_error(error_dict):
|
||||
|
@ -73,6 +81,29 @@ class ClassBasedViewIntegrationTests(TestCase):
|
|||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(sanitise_json_error(response.data), expected)
|
||||
|
||||
def test_raise_400_error(self):
|
||||
request = factory.post('/', '{"raise_400_error": true}',
|
||||
content_type='application/json')
|
||||
response = self.view(request)
|
||||
expected = {
|
||||
'detail': 'Bad request'
|
||||
}
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(sanitise_json_error(response.data), expected)
|
||||
|
||||
def test_return_400_error(self):
|
||||
request = factory.post('/', '{"return_400_error": true}',
|
||||
content_type='application/json')
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter('always')
|
||||
response = self.view(request)
|
||||
self.assertEqual(len(w), 1)
|
||||
expected = {
|
||||
'detail': 'Bad request'
|
||||
}
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(sanitise_json_error(response.data), expected)
|
||||
|
||||
|
||||
class FunctionBasedViewIntegrationTests(TestCase):
|
||||
def setUp(self):
|
||||
|
@ -101,3 +132,26 @@ class FunctionBasedViewIntegrationTests(TestCase):
|
|||
}
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(sanitise_json_error(response.data), expected)
|
||||
|
||||
def test_raise_400_error(self):
|
||||
request = factory.post('/', '{"raise_400_error": true}',
|
||||
content_type='application/json')
|
||||
response = self.view(request)
|
||||
expected = {
|
||||
'detail': 'Bad request'
|
||||
}
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(sanitise_json_error(response.data), expected)
|
||||
|
||||
def test_return_400_error(self):
|
||||
request = factory.post('/', '{"return_400_error": true}',
|
||||
content_type='application/json')
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter('always')
|
||||
response = self.view(request)
|
||||
self.assertEqual(len(w), 1)
|
||||
expected = {
|
||||
'detail': 'Bad request'
|
||||
}
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(sanitise_json_error(response.data), expected)
|
||||
|
|
|
@ -3,8 +3,9 @@ Provides an APIView class that is the base of all views in REST framework.
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404, HttpResponse
|
||||
import warnings
|
||||
|
||||
from django.http import HttpResponse
|
||||
from django.utils.datastructures import SortedDict
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from rest_framework import status, exceptions
|
||||
|
@ -24,6 +25,7 @@ class APIView(View):
|
|||
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
|
||||
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
|
||||
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
|
||||
handle_exception = api_settings.DEFAULT_EXCEPTION_HANDLER
|
||||
|
||||
@classmethod
|
||||
def as_view(cls, **initkwargs):
|
||||
|
@ -263,39 +265,6 @@ class APIView(View):
|
|||
|
||||
return response
|
||||
|
||||
def handle_exception(self, exc):
|
||||
"""
|
||||
Handle any exception that occurs, by returning an appropriate response,
|
||||
or re-raising the error.
|
||||
"""
|
||||
if isinstance(exc, exceptions.Throttled):
|
||||
# Throttle wait header
|
||||
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
|
||||
|
||||
if isinstance(exc, (exceptions.NotAuthenticated,
|
||||
exceptions.AuthenticationFailed)):
|
||||
# WWW-Authenticate header for 401 responses, else coerce to 403
|
||||
auth_header = self.get_authenticate_header(self.request)
|
||||
|
||||
if auth_header:
|
||||
self.headers['WWW-Authenticate'] = auth_header
|
||||
else:
|
||||
exc.status_code = status.HTTP_403_FORBIDDEN
|
||||
|
||||
if isinstance(exc, exceptions.APIException):
|
||||
return Response({'detail': exc.detail},
|
||||
status=exc.status_code,
|
||||
exception=True)
|
||||
elif isinstance(exc, Http404):
|
||||
return Response({'detail': 'Not found'},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
exception=True)
|
||||
elif isinstance(exc, PermissionDenied):
|
||||
return Response({'detail': 'Permission denied'},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
exception=True)
|
||||
raise
|
||||
|
||||
# Note: session based authentication is explicitly CSRF validated,
|
||||
# all other authentication is CSRF exempt.
|
||||
@csrf_exempt
|
||||
|
@ -322,6 +291,11 @@ class APIView(View):
|
|||
|
||||
response = handler(request, *args, **kwargs)
|
||||
|
||||
if response.status_code >= 400:
|
||||
warnings.warn('Status code >= 400 returned. You should raise'
|
||||
' `rest_framework.exceptions.APIException`'
|
||||
' instead.', stacklevel=2)
|
||||
|
||||
except Exception as exc:
|
||||
response = self.handle_exception(exc)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user