Error codes (#4550)

Add error codes to `APIException`
This commit is contained in:
Tom Christie 2016-10-11 10:25:21 +01:00 committed by GitHub
parent 5f6ceee7a5
commit a3802504a0
9 changed files with 316 additions and 91 deletions

View File

@ -98,7 +98,7 @@ Note that the exception handler will only be called for responses generated by r
The **base class** for all exceptions raised inside an `APIView` class or `@api_view`.
To provide a custom exception, subclass `APIException` and set the `.status_code` and `.default_detail` properties on the class.
To provide a custom exception, subclass `APIException` and set the `.status_code`, `.default_detail`, and `default_code` attributes on the class.
For example, if your API relies on a third party service that may sometimes be unreachable, you might want to implement an exception for the "503 Service Unavailable" HTTP response code. You could do this like so:
@ -107,10 +107,42 @@ For example, if your API relies on a third party service that may sometimes be u
class ServiceUnavailable(APIException):
status_code = 503
default_detail = 'Service temporarily unavailable, try again later.'
default_code = 'service_unavailable'
#### Inspecting API exceptions
There are a number of different properties available for inspecting the status
of an API exception. You can use these to build custom exception handling
for your project.
The available attributes and methods are:
* `.detail` - Return the textual description of the error.
* `.get_codes()` - Return the code identifier of the error.
* `.full_details()` - Return both the textual description and the code identifier.
In most cases the error detail will be a simple item:
>>> print(exc.detail)
You do not have permission to perform this action.
>>> print(exc.get_codes())
permission_denied
>>> print(exc.full_details())
{'message':'You do not have permission to perform this action.','code':'permission_denied'}
In the case of validation errors the error detail will be either a list or
dictionary of items:
>>> print(exc.detail)
{"name":"This field is required.","age":"A valid integer is required."}
>>> print(exc.get_codes())
{"name":"required","age":"invalid"}
>>> print(exc.get_full_details())
{"name":{"message":"This field is required.","code":"required"},"age":{"message":"A valid integer is required.","code":"invalid"}}
## ParseError
**Signature:** `ParseError(detail=None)`
**Signature:** `ParseError(detail=None, code=None)`
Raised if the request contains malformed data when accessing `request.data`.
@ -118,7 +150,7 @@ By default this exception results in a response with the HTTP status code "400 B
## AuthenticationFailed
**Signature:** `AuthenticationFailed(detail=None)`
**Signature:** `AuthenticationFailed(detail=None, code=None)`
Raised when an incoming request includes incorrect authentication.
@ -126,7 +158,7 @@ By default this exception results in a response with the HTTP status code "401 U
## NotAuthenticated
**Signature:** `NotAuthenticated(detail=None)`
**Signature:** `NotAuthenticated(detail=None, code=None)`
Raised when an unauthenticated request fails the permission checks.
@ -134,7 +166,7 @@ By default this exception results in a response with the HTTP status code "401 U
## PermissionDenied
**Signature:** `PermissionDenied(detail=None)`
**Signature:** `PermissionDenied(detail=None, code=None)`
Raised when an authenticated request fails the permission checks.
@ -142,7 +174,7 @@ By default this exception results in a response with the HTTP status code "403 F
## NotFound
**Signature:** `NotFound(detail=None)`
**Signature:** `NotFound(detail=None, code=None)`
Raised when a resource does not exists at the given URL. This exception is equivalent to the standard `Http404` Django exception.
@ -150,7 +182,7 @@ By default this exception results in a response with the HTTP status code "404 N
## MethodNotAllowed
**Signature:** `MethodNotAllowed(method, detail=None)`
**Signature:** `MethodNotAllowed(method, detail=None, code=None)`
Raised when an incoming request occurs that does not map to a handler method on the view.
@ -158,7 +190,7 @@ By default this exception results in a response with the HTTP status code "405 M
## NotAcceptable
**Signature:** `NotAcceptable(detail=None)`
**Signature:** `NotAcceptable(detail=None, code=None)`
Raised when an incoming request occurs with an `Accept` header that cannot be satisfied by any of the available renderers.
@ -166,7 +198,7 @@ By default this exception results in a response with the HTTP status code "406 N
## UnsupportedMediaType
**Signature:** `UnsupportedMediaType(media_type, detail=None)`
**Signature:** `UnsupportedMediaType(media_type, detail=None, code=None)`
Raised if there are no parsers that can handle the content type of the request data when accessing `request.data`.
@ -174,7 +206,7 @@ By default this exception results in a response with the HTTP status code "415 U
## Throttled
**Signature:** `Throttled(wait=None, detail=None)`
**Signature:** `Throttled(wait=None, detail=None, code=None)`
Raised when an incoming request fails the throttling checks.
@ -182,7 +214,7 @@ By default this exception results in a response with the HTTP status code "429 T
## ValidationError
**Signature:** `ValidationError(detail)`
**Signature:** `ValidationError(detail, code=None)`
The `ValidationError` exception is slightly different from the other `APIException` classes:

View File

@ -21,13 +21,13 @@ class AuthTokenSerializer(serializers.Serializer):
# (Assuming the default `ModelBackend` authentication backend.)
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg)
raise serializers.ValidationError(msg, code='authorization')
else:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg)
raise serializers.ValidationError(msg, code='authorization')
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg)
raise serializers.ValidationError(msg, code='authorization')
attrs['user'] = user
return attrs

View File

@ -17,27 +17,61 @@ from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
def _force_text_recursive(data):
def _get_error_details(data, default_code=None):
"""
Descend into a nested data structure, forcing any
lazy translation strings into plain text.
lazy translation strings or strings into `ErrorDetail`.
"""
if isinstance(data, list):
ret = [
_force_text_recursive(item) for item in data
_get_error_details(item, default_code) for item in data
]
if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer)
return ret
elif isinstance(data, dict):
ret = {
key: _force_text_recursive(value)
key: _get_error_details(value, default_code)
for key, value in data.items()
}
if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer)
return ret
return force_text(data)
text = force_text(data)
code = getattr(data, 'code', default_code)
return ErrorDetail(text, code)
def _get_codes(detail):
if isinstance(detail, list):
return [_get_codes(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_codes(value) for key, value in detail.items()}
return detail.code
def _get_full_details(detail):
if isinstance(detail, list):
return [_get_full_details(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_full_details(value) for key, value in detail.items()}
return {
'message': detail,
'code': detail.code
}
class ErrorDetail(six.text_type):
"""
A string-like object that can additionally
"""
code = None
def __new__(cls, string, code=None):
self = super(ErrorDetail, cls).__new__(cls, string)
self.code = code
return self
class APIException(Exception):
@ -47,16 +81,35 @@ class APIException(Exception):
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = _('A server error occurred.')
default_code = 'error'
def __init__(self, detail=None):
if detail is not None:
self.detail = force_text(detail)
else:
self.detail = force_text(self.default_detail)
def __init__(self, detail=None, code=None):
if detail is None:
detail = self.default_detail
if code is None:
code = self.default_code
self.detail = _get_error_details(detail, code)
def __str__(self):
return self.detail
def get_codes(self):
"""
Return only the code part of the error details.
Eg. {"name": ["required"]}
"""
return _get_codes(self.detail)
def get_full_details(self):
"""
Return both the message & code parts of the error details.
Eg. {"name": [{"message": "This field is required.", "code": "required"}]}
"""
return _get_full_details(self.detail)
# The recommended style for using `ValidationError` is to keep it namespaced
# under `serializers`, in order to minimize potential confusion with Django's
@ -67,13 +120,21 @@ class APIException(Exception):
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Invalid input.')
default_code = 'invalid'
def __init__(self, detail):
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
def __init__(self, detail, code=None):
if detail is None:
detail = self.default_detail
if code is None:
code = self.default_code
# For validation failures, we may collect may errors together, so the
# details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
self.detail = _force_text_recursive(detail)
self.detail = _get_error_details(detail, code)
def __str__(self):
return six.text_type(self.detail)
@ -82,62 +143,63 @@ class ValidationError(APIException):
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Malformed request.')
default_code = 'parse_error'
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Incorrect authentication credentials.')
default_code = 'authentication_failed'
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Authentication credentials were not provided.')
default_code = 'not_authenticated'
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = _('You do not have permission to perform this action.')
default_code = 'permission_denied'
class NotFound(APIException):
status_code = status.HTTP_404_NOT_FOUND
default_detail = _('Not found.')
default_code = 'not_found'
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = _('Method "{method}" not allowed.')
default_code = 'method_not_allowed'
def __init__(self, method, detail=None):
if detail is not None:
self.detail = force_text(detail)
else:
self.detail = force_text(self.default_detail).format(method=method)
def __init__(self, method, detail=None, code=None):
if detail is None:
detail = force_text(self.default_detail).format(method=method)
super(MethodNotAllowed, self).__init__(detail, code)
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
default_detail = _('Could not satisfy the request Accept header.')
default_code = 'not_acceptable'
def __init__(self, detail=None, available_renderers=None):
if detail is not None:
self.detail = force_text(detail)
else:
self.detail = force_text(self.default_detail)
def __init__(self, detail=None, code=None, available_renderers=None):
self.available_renderers = available_renderers
super(NotAcceptable, self).__init__(detail, code)
class UnsupportedMediaType(APIException):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
default_detail = _('Unsupported media type "{media_type}" in request.')
default_code = 'unsupported_media_type'
def __init__(self, media_type, detail=None):
if detail is not None:
self.detail = force_text(detail)
else:
self.detail = force_text(self.default_detail).format(
media_type=media_type
)
def __init__(self, media_type, detail=None, code=None):
if detail is None:
detail = force_text(self.default_detail).format(media_type=media_type)
super(UnsupportedMediaType, self).__init__(detail, code)
class Throttled(APIException):
@ -145,12 +207,10 @@ class Throttled(APIException):
default_detail = _('Request was throttled.')
extra_detail_singular = 'Expected available in {wait} second.'
extra_detail_plural = 'Expected available in {wait} seconds.'
default_code = 'throttled'
def __init__(self, wait=None, detail=None):
if detail is not None:
self.detail = force_text(detail)
else:
self.detail = force_text(self.default_detail)
def __init__(self, wait=None, detail=None, code=None):
super(Throttled, self).__init__(detail, code)
if wait is None:
self.wait = None

View File

@ -34,7 +34,7 @@ from rest_framework import ISO_8601
from rest_framework.compat import (
get_remote_field, unicode_repr, unicode_to_repr, value_from_object
)
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation
@ -224,6 +224,18 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None):
yield Option(value='n/a', display_text=cutoff_text, disabled=True)
def get_error_detail(exc_info):
"""
Given a Django ValidationError, return a list of ErrorDetail,
with the `code` populated.
"""
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ErrorDetail(msg, code=code)
for msg in exc_info.messages
]
class CreateOnlyDefault(object):
"""
This class may be used to provide default values that are only used
@ -525,7 +537,7 @@ class Field(object):
raise
errors.extend(exc.detail)
except DjangoValidationError as exc:
errors.extend(exc.messages)
errors.extend(get_error_detail(exc))
if errors:
raise ValidationError(errors)
@ -563,7 +575,7 @@ class Field(object):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
message_string = msg.format(**kwargs)
raise ValidationError(message_string)
raise ValidationError(message_string, code=key)
@cached_property
def root(self):

View File

@ -291,32 +291,29 @@ class SerializerMetaclass(type):
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
def get_validation_error_detail(exc):
def as_serializer_error(exc):
assert isinstance(exc, (ValidationError, DjangoValidationError))
if isinstance(exc, DjangoValidationError):
# Normally you should raise `serializers.ValidationError`
# inside your codebase, but we handle Django's validation
# exception class as well for simpler compat.
# Eg. Calling Model.clean() explicitly inside Serializer.validate()
return {
api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages)
}
elif isinstance(exc.detail, dict):
detail = get_error_detail(exc)
else:
detail = exc.detail
if isinstance(detail, dict):
# If errors may be a dict we use the standard {key: list of values}.
# Here we ensure that all the values are *lists* of errors.
return {
key: value if isinstance(value, (list, dict)) else [value]
for key, value in exc.detail.items()
for key, value in detail.items()
}
elif isinstance(exc.detail, list):
elif isinstance(detail, list):
# Errors raised as a list are non-field errors.
return {
api_settings.NON_FIELD_ERRORS_KEY: exc.detail
api_settings.NON_FIELD_ERRORS_KEY: detail
}
# Errors raised as a string are non-field errors.
return {
api_settings.NON_FIELD_ERRORS_KEY: [exc.detail]
api_settings.NON_FIELD_ERRORS_KEY: [detail]
}
@ -410,7 +407,7 @@ class Serializer(BaseSerializer):
value = self.validate(value)
assert value is not None, '.validate() should return the validated data'
except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=get_validation_error_detail(exc))
raise ValidationError(detail=as_serializer_error(exc))
return value
@ -424,7 +421,7 @@ class Serializer(BaseSerializer):
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
})
}, code='invalid')
ret = OrderedDict()
errors = OrderedDict()
@ -440,7 +437,7 @@ class Serializer(BaseSerializer):
except ValidationError as exc:
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = list(exc.messages)
errors[field.field_name] = get_error_detail(exc)
except SkipField:
pass
else:
@ -564,7 +561,7 @@ class ListSerializer(BaseSerializer):
value = self.validate(value)
assert value is not None, '.validate() should return the validated data'
except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=get_validation_error_detail(exc))
raise ValidationError(detail=as_serializer_error(exc))
return value
@ -581,13 +578,13 @@ class ListSerializer(BaseSerializer):
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
})
}, code='not_a_list')
if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
})
}, code='empty')
ret = []
errors = []

View File

@ -80,7 +80,7 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
if qs_exists(queryset):
raise ValidationError(self.message)
raise ValidationError(self.message, code='unique')
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % (
@ -120,13 +120,13 @@ class UniqueTogetherValidator(object):
if self.instance is not None:
return
missing = {
missing_items = {
field_name: self.missing_message
for field_name in self.fields
if field_name not in attrs
}
if missing:
raise ValidationError(missing)
if missing_items:
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset):
"""
@ -167,7 +167,8 @@ class UniqueTogetherValidator(object):
]
if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names))
message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique')
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -204,13 +205,13 @@ class BaseUniqueForValidator(object):
The `UniqueFor<Range>Validator` classes always force an implied
'required' state on the fields they are applied to.
"""
missing = {
missing_items = {
field_name: self.missing_message
for field_name in [self.field, self.date_field]
if field_name not in attrs
}
if missing:
raise ValidationError(missing)
if missing_items:
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.')
@ -231,7 +232,9 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({self.field: message})
raise ValidationError({
self.field: message
}, code='unique')
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -3,19 +3,39 @@ from __future__ import unicode_literals
from django.test import TestCase
from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import _force_text_recursive
from rest_framework.exceptions import ErrorDetail, _get_error_details
class ExceptionTestCase(TestCase):
def test_force_text_recursive(self):
def test_get_error_details(self):
s = "sfdsfggiuytraetfdlklj"
self.assertEqual(_force_text_recursive(_(s)), s)
self.assertEqual(type(_force_text_recursive(_(s))), type(s))
example = "string"
lazy_example = _(example)
self.assertEqual(_force_text_recursive({'a': _(s)})['a'], s)
self.assertEqual(type(_force_text_recursive({'a': _(s)})['a']), type(s))
self.assertEqual(
_get_error_details(lazy_example),
example
)
assert isinstance(
_get_error_details(lazy_example),
ErrorDetail
)
self.assertEqual(_force_text_recursive([[_(s)]])[0][0], s)
self.assertEqual(type(_force_text_recursive([[_(s)]])[0][0]), type(s))
self.assertEqual(
_get_error_details({'nested': lazy_example})['nested'],
example
)
assert isinstance(
_get_error_details({'nested': lazy_example})['nested'],
ErrorDetail
)
self.assertEqual(
_get_error_details([[lazy_example]])[0][0],
example
)
assert isinstance(
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)

View File

@ -60,7 +60,7 @@ class TestNestedValidationError(TestCase):
}
})
self.assertEqual(serializers.get_validation_error_detail(e), {
self.assertEqual(serializers.as_serializer_error(e), {
'nested': {
'field': ['error'],
}

View File

@ -0,0 +1,101 @@
from django.test import TestCase
from rest_framework import serializers, status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ErrorView(APIView):
def get(self, request, *args, **kwargs):
ExampleSerializer(data={}).is_valid(raise_exception=True)
@api_view(['GET'])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)
class TestValidationErrorWithFullDetails(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc, request):
data = exc.get_full_details()
return Response(data, status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
self.expected_response_data = {
'char': [{
'message': 'This field is required.',
'code': 'required',
}],
'integer': [{
'message': 'This field is required.',
'code': 'required'
}],
}
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)
def test_function_based_view_exception_handler(self):
view = error_view
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)
class TestValidationErrorWithCodes(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc, request):
data = exc.get_codes()
return Response(data, status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
self.expected_response_data = {
'char': ['required'],
'integer': ['required'],
}
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)
def test_function_based_view_exception_handler(self):
view = error_view
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)