This commit is contained in:
Jonathan Liuti 2016-02-17 17:32:29 +00:00
commit f600787dbf
6 changed files with 171 additions and 24 deletions

View File

@ -18,13 +18,17 @@ class AuthTokenSerializer(serializers.Serializer):
if user:
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg)
code = 'authorization'
raise serializers.ValidationError(msg, code=code)
else:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg)
code = 'authorization'
raise serializers.ValidationError(msg, code=code)
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg)
code = 'authorization'
raise serializers.ValidationError(msg, code=code)
attrs['user'] = user
return attrs

View File

@ -7,6 +7,7 @@ In addition Django's built in 403 and 404 exceptions are handled.
from __future__ import unicode_literals
import math
from collections import namedtuple
from django.utils import six
from django.utils.encoding import force_text
@ -58,6 +59,13 @@ class APIException(Exception):
return self.detail
def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ErrorDetails(msg, code)
for msg in exc_info.messages
]
# The recommended style for using `ValidationError` is to keep it namespaced
# under `serializers`, in order to minimize potential confusion with Django's
# built in `ValidationError`. For example:
@ -65,15 +73,61 @@ class APIException(Exception):
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
ErrorDetails = namedtuple('ErrorDetails', ['message', 'code'])
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
code = None
def __init__(self, detail):
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
self.detail = _force_text_recursive(detail)
def __init__(self, detail, code=None):
if code:
self.full_details = ErrorDetails(detail, code)
else:
self.full_details = detail
if not isinstance(self.full_details, dict) \
and not isinstance(self.full_details, list):
self.full_details = [self.full_details]
self.full_details = _force_text_recursive(self.full_details)
self.detail = detail
if isinstance(self.full_details, list):
if isinstance(self.full_details, ReturnList):
self.detail = ReturnList(
serializer=self.full_details.serializer)
else:
self.detail = []
for full_detail in self.full_details:
if isinstance(full_detail, ErrorDetails):
self.detail.append(full_detail.message)
elif isinstance(full_detail, dict):
if not full_detail:
self.detail.append(full_detail)
for key, value in full_detail.items():
if isinstance(value, list):
self.detail.append(
{key: [item.message]
if isinstance(item, ErrorDetails)
else [item] for item in value})
elif isinstance(full_detail, list):
self.detail.extend(full_detail)
else:
self.detail.append(full_detail)
elif isinstance(self.full_details, dict):
if isinstance(self.full_details, ReturnDict):
self.detail = ReturnDict(
serializer=self.full_details.serializer)
else:
self.detail = {}
for field_name, full_detail in self.full_details.items():
if isinstance(full_detail, list):
self.detail[field_name] = [
item.message if isinstance(item, ErrorDetails) else item
for item in full_detail
]
else:
self.detail[field_name] = full_detail
def __str__(self):
return six.text_type(self.detail)

View File

@ -31,7 +31,9 @@ from rest_framework.compat import (
MinValueValidator, duration_string, parse_duration, unicode_repr,
unicode_to_repr
)
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import (
ValidationError, build_error_from_django_validation_error
)
from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation
@ -501,9 +503,9 @@ class Field(object):
# attempting to accumulate a list of errors.
if isinstance(exc.detail, dict):
raise
errors.extend(exc.detail)
errors.append(exc.full_details)
except DjangoValidationError as exc:
errors.extend(exc.messages)
errors.extend(build_error_from_django_validation_error(exc))
if errors:
raise ValidationError(errors)
@ -541,7 +543,7 @@ class Field(object):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
message_string = msg.format(**kwargs)
raise ValidationError(message_string)
raise ValidationError(message_string, code=key)
@cached_property
def root(self):

View File

@ -23,6 +23,9 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import DurationField as ModelDurationField
from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.exceptions import (
ErrorDetails, build_error_from_django_validation_error
)
from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
@ -299,8 +302,9 @@ def get_validation_error_detail(exc):
# inside your codebase, but we handle Django's validation
# exception class as well for simpler compat.
# Eg. Calling Model.clean() explicitly inside Serializer.validate()
error = build_error_from_django_validation_error(exc)
return {
api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages)
api_settings.NON_FIELD_ERRORS_KEY: error
}
elif isinstance(exc.detail, dict):
# If errors may be a dict we use the standard {key: list of values}.
@ -422,8 +426,9 @@ class Serializer(BaseSerializer):
message = self.error_messages['invalid'].format(
datatype=type(data).__name__
)
code = 'invalid'
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)]
})
ret = OrderedDict()
@ -440,7 +445,8 @@ class Serializer(BaseSerializer):
except ValidationError as exc:
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = list(exc.messages)
error = build_error_from_django_validation_error(exc)
errors[field.field_name] = error
except SkipField:
pass
else:
@ -575,14 +581,16 @@ class ListSerializer(BaseSerializer):
message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__
)
code = 'not_a_list'
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)]
})
if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
code = 'empty_not_allowed'
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)]
})
ret = []

View File

@ -11,7 +11,7 @@ from __future__ import unicode_literals
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ErrorDetails, ValidationError
from rest_framework.utils.representation import smart_repr
@ -60,7 +60,7 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
if queryset.exists():
raise ValidationError(self.message)
raise ValidationError(self.message, code='unique')
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % (
@ -100,8 +100,9 @@ class UniqueTogetherValidator(object):
if self.instance is not None:
return
code = 'required'
missing = {
field_name: self.missing_message
field_name: ErrorDetails(self.missing_message, code)
for field_name in self.fields
if field_name not in attrs
}
@ -147,7 +148,9 @@ class UniqueTogetherValidator(object):
]
if None not in checked_values and queryset.exists():
field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names))
message = self.message.format(field_names=field_names)
code = 'unique'
raise ValidationError(ErrorDetails(message, code=code))
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -184,8 +187,9 @@ class BaseUniqueForValidator(object):
The `UniqueFor<Range>Validator` classes always force an implied
'required' state on the fields they are applied to.
"""
code = 'required'
missing = {
field_name: self.missing_message
field_name: ErrorDetails(self.missing_message, code)
for field_name in [self.field, self.date_field]
if field_name not in attrs
}
@ -211,7 +215,8 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset)
if queryset.exists():
message = self.message.format(date_field=self.date_field)
raise ValidationError({self.field: message})
code = 'unique'
raise ValidationError({self.field: ErrorDetails(message, code)})
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -0,0 +1,74 @@
from django.test import TestCase
from rest_framework import serializers, status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ErrorView(APIView):
def get(self, request, *args, **kwargs):
ExampleSerializer(data={}).is_valid(raise_exception=True)
@api_view(['GET'])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)
class TestValidationErrorWithCode(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc, request):
return_errors = {}
for field_name, errors in exc.detail.items():
return_errors[field_name] = []
for message, code in errors:
return_errors[field_name].append({
'code': code,
'message': message
})
return Response(return_errors, status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
self.expected_response_data = {
'char': [{
'message': 'This field is required.',
'code': 'required',
}],
'integer': [{
'message': 'This field is required.',
'code': 'required'
}],
}
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)
def test_function_based_view_exception_handler(self):
view = error_view
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)