This commit is contained in:
Jonathan Liuti 2016-10-05 12:09:28 +00:00 committed by GitHub
commit a516fa087b
6 changed files with 147 additions and 19 deletions

View File

@ -18,13 +18,20 @@ class AuthTokenSerializer(serializers.Serializer):
if user: if user:
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(
msg,
code='authorization')
else: else:
msg = _('Unable to log in with provided credentials.') msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(
msg,
code='authorization')
else: else:
msg = _('Must include "username" and "password".') msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg) raise serializers.ValidationError(
msg,
code='authorization')
attrs['user'] = user attrs['user'] = user
return attrs return attrs

View File

@ -58,6 +58,14 @@ class APIException(Exception):
return self.detail return self.detail
def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ValidationErrorMessage(msg, code=code)
for msg in exc_info.messages
]
# The recommended style for using `ValidationError` is to keep it namespaced # The recommended style for using `ValidationError` is to keep it namespaced
# under `serializers`, in order to minimize potential confusion with Django's # under `serializers`, in order to minimize potential confusion with Django's
# built in `ValidationError`. For example: # built in `ValidationError`. For example:
@ -65,10 +73,25 @@ class APIException(Exception):
# from rest_framework import serializers # from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid') # raise serializers.ValidationError('Value was invalid')
class ValidationErrorMessage(six.text_type):
code = None
def __new__(cls, string, code=None, *args, **kwargs):
self = super(ValidationErrorMessage, cls).__new__(
cls, string, *args, **kwargs)
self.code = code
return self
class ValidationError(APIException): class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail): def __init__(self, detail, code=None):
# If code is there, this means we are dealing with a message.
if code and not isinstance(detail, ValidationErrorMessage):
detail = ValidationErrorMessage(detail, code=code)
# For validation errors the 'detail' key is always required. # For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already. # The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list): if not isinstance(detail, dict) and not isinstance(detail, list):

View File

@ -34,7 +34,9 @@ from rest_framework import ISO_8601
from rest_framework.compat import ( from rest_framework.compat import (
get_remote_field, unicode_repr, unicode_to_repr, value_from_object get_remote_field, unicode_repr, unicode_to_repr, value_from_object
) )
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import (
ValidationError, build_error_from_django_validation_error
)
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation from rest_framework.utils import html, humanize_datetime, representation
@ -511,7 +513,9 @@ class Field(object):
raise raise
errors.extend(exc.detail) errors.extend(exc.detail)
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors.extend(exc.messages) errors.extend(
build_error_from_django_validation_error(exc)
)
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
@ -549,7 +553,7 @@ class Field(object):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg)
message_string = msg.format(**kwargs) message_string = msg.format(**kwargs)
raise ValidationError(message_string) raise ValidationError(message_string, code=key)
@cached_property @cached_property
def root(self): def root(self):

View File

@ -22,8 +22,10 @@ from django.db.models.fields import FieldDoesNotExist
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import exceptions
from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, unicode_to_repr from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.exceptions import ValidationErrorMessage
from rest_framework.utils import model_meta from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import ( from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
@ -220,7 +222,6 @@ class BaseSerializer(Field):
if self._errors and raise_exception: if self._errors and raise_exception:
raise ValidationError(self.errors) raise ValidationError(self.errors)
return not bool(self._errors) return not bool(self._errors)
@property @property
@ -301,7 +302,8 @@ def get_validation_error_detail(exc):
# exception class as well for simpler compat. # exception class as well for simpler compat.
# Eg. Calling Model.clean() explicitly inside Serializer.validate() # Eg. Calling Model.clean() explicitly inside Serializer.validate()
return { return {
api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) api_settings.NON_FIELD_ERRORS_KEY:
exceptions.build_error_from_django_validation_error(exc)
} }
elif isinstance(exc.detail, dict): elif isinstance(exc.detail, dict):
# If errors may be a dict we use the standard {key: list of values}. # If errors may be a dict we use the standard {key: list of values}.
@ -423,8 +425,9 @@ class Serializer(BaseSerializer):
message = self.error_messages['invalid'].format( message = self.error_messages['invalid'].format(
datatype=type(data).__name__ datatype=type(data).__name__
) )
error = ValidationErrorMessage(message, code='invalid')
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [error]
}) })
ret = OrderedDict() ret = OrderedDict()
@ -441,7 +444,9 @@ class Serializer(BaseSerializer):
except ValidationError as exc: except ValidationError as exc:
errors[field.field_name] = exc.detail errors[field.field_name] = exc.detail
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors[field.field_name] = list(exc.messages) errors[field.field_name] = (
exceptions.build_error_from_django_validation_error(exc)
)
except SkipField: except SkipField:
pass pass
else: else:
@ -580,12 +585,18 @@ class ListSerializer(BaseSerializer):
message = self.error_messages['not_a_list'].format( message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__ input_type=type(data).__name__
) )
error = ValidationErrorMessage(
message,
code='not_a_list'
)
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [error]
}) })
if not self.allow_empty and len(data) == 0: if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty'] message = ValidationErrorMessage(
self.error_messages['empty'],
code='empty_not_allowed')
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [message]
}) })

View File

@ -12,7 +12,7 @@ from django.db import DataError
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError, ValidationErrorMessage
from rest_framework.utils.representation import smart_repr from rest_framework.utils.representation import smart_repr
@ -80,7 +80,7 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset) queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset) queryset = self.exclude_current_instance(queryset)
if qs_exists(queryset): if qs_exists(queryset):
raise ValidationError(self.message) raise ValidationError(self.message, code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % ( return unicode_to_repr('<%s(queryset=%s)>' % (
@ -121,7 +121,10 @@ class UniqueTogetherValidator(object):
return return
missing = { missing = {
field_name: self.missing_message field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in self.fields for field_name in self.fields
if field_name not in attrs if field_name not in attrs
} }
@ -167,7 +170,9 @@ class UniqueTogetherValidator(object):
] ]
if None not in checked_values and qs_exists(queryset): if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields) field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names)) raise ValidationError(
self.message.format(field_names=field_names),
code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -205,7 +210,9 @@ class BaseUniqueForValidator(object):
'required' state on the fields they are applied to. 'required' state on the fields they are applied to.
""" """
missing = { missing = {
field_name: self.missing_message field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in [self.field, self.date_field] for field_name in [self.field, self.date_field]
if field_name not in attrs if field_name not in attrs
} }
@ -231,7 +238,9 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset) queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset): if qs_exists(queryset):
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
raise ValidationError({self.field: message}) raise ValidationError({
self.field: ValidationErrorMessage(message, code='unique'),
})
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -0,0 +1,74 @@
from django.test import TestCase
from rest_framework import serializers, status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ErrorView(APIView):
def get(self, request, *args, **kwargs):
ExampleSerializer(data={}).is_valid(raise_exception=True)
@api_view(['GET'])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)
class TestValidationErrorWithCode(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc, request):
return_errors = {}
for field_name, errors in exc.detail.items():
return_errors[field_name] = []
for error in errors:
return_errors[field_name].append({
'code': error.code,
'message': error
})
return Response(return_errors, status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
self.expected_response_data = {
'char': [{
'message': 'This field is required.',
'code': 'required',
}],
'integer': [{
'message': 'This field is required.',
'code': 'required'
}],
}
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)
def test_function_based_view_exception_handler(self):
view = error_view
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)