Introduce ValidationErrorMessage

`ValidationErrorMessage` is a string-like object that holds a code
attribute.

The code attribute has been removed from ValidationError to be able
to maintain better backward compatibility.

What this means is that `ValidationError` can accept either a regular
string or a `ValidationErrorMessage` for its `detail` attribute.
This commit is contained in:
Jonathan Liuti 2015-12-16 19:09:03 +01:00
parent c7351b3832
commit 42f4c5549d
8 changed files with 127 additions and 70 deletions

View File

@ -2,6 +2,7 @@ from django.contrib.auth import authenticate
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from rest_framework.exceptions import ValidationErrorMessage
class AuthTokenSerializer(serializers.Serializer):
@ -19,20 +20,23 @@ class AuthTokenSerializer(serializers.Serializer):
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(
msg,
code='authorization'
ValidationErrorMessage(
msg,
code='authorization')
)
else:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(
msg,
code='authorization'
ValidationErrorMessage(
msg,
code='authorization')
)
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(
msg,
code='authorization'
ValidationErrorMessage(
msg,
code='authorization')
)
attrs['user'] = user

View File

@ -61,7 +61,7 @@ class APIException(Exception):
def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ValidationError(msg, code=code)
ValidationErrorMessage(msg, code=code)
for msg in exc_info.messages
]
@ -73,20 +73,26 @@ def build_error_from_django_validation_error(exc_info):
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
class ValidationErrorMessage(six.text_type):
code = None
def __new__(cls, string, code=None, *args, **kwargs):
self = super(ValidationErrorMessage, cls).__new__(
cls, string, *args, **kwargs)
self.code = code
return self
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail, code=None):
def __init__(self, detail):
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)):
assert code is None, (
'The `code` argument must not be set for compound errors.')
self.detail = detail
self.code = code
self.detail = _force_text_recursive(detail)
def __str__(self):
return six.text_type(self.detail)

View File

@ -32,7 +32,8 @@ from rest_framework.compat import (
unicode_to_repr
)
from rest_framework.exceptions import (
ValidationError, build_error_from_django_validation_error
ValidationError, ValidationErrorMessage,
build_error_from_django_validation_error
)
from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation
@ -503,7 +504,7 @@ class Field(object):
# attempting to accumulate a list of errors.
if isinstance(exc.detail, dict):
raise
errors.append(ValidationError(exc.detail, code=exc.code))
errors.extend(exc.detail)
except DjangoValidationError as exc:
errors.extend(
build_error_from_django_validation_error(exc)
@ -545,7 +546,7 @@ class Field(object):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
message_string = msg.format(**kwargs)
raise ValidationError(message_string, code=key)
raise ValidationError(ValidationErrorMessage(message_string, code=key))
@cached_property
def root(self):

View File

@ -38,6 +38,7 @@ class Response(SimpleTemplateResponse):
'`.error`. representation.'
)
raise AssertionError(msg)
self.data = data
self.template_name = template_name
self.exception = exception

View File

@ -219,14 +219,7 @@ class BaseSerializer(Field):
self._errors = {}
if self._errors and raise_exception:
return_errors = None
if isinstance(self._errors, list):
return_errors = ReturnList(self._errors, serializer=self)
elif isinstance(self._errors, dict):
return_errors = ReturnDict(self._errors, serializer=self)
raise ValidationError(return_errors)
raise ValidationError(self.errors)
return not bool(self._errors)
@property
@ -250,42 +243,12 @@ class BaseSerializer(Field):
self._data = self.get_initial()
return self._data
def _transform_to_legacy_errors(self, errors_to_transform):
# Do not mutate `errors_to_transform` here.
errors = ReturnDict(serializer=self)
for field_name, values in errors_to_transform.items():
if isinstance(values, list):
errors[field_name] = values
continue
if isinstance(values.detail, list):
errors[field_name] = []
for value in values.detail:
if isinstance(value, ValidationError):
errors[field_name].extend(value.detail)
elif isinstance(value, list):
errors[field_name].extend(value)
else:
errors[field_name].append(value)
elif isinstance(values.detail, dict):
errors[field_name] = {}
for sub_field_name, value in values.detail.items():
errors[field_name][sub_field_name] = []
for validation_error in value:
errors[field_name][sub_field_name].extend(validation_error.detail)
return errors
@property
def errors(self):
if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg)
if isinstance(self._errors, list):
return map(self._transform_to_legacy_errors, self._errors)
else:
return self._transform_to_legacy_errors(self._errors)
return self._errors
@property
def validated_data(self):
@ -460,7 +423,7 @@ class Serializer(BaseSerializer):
message = self.error_messages['invalid'].format(
datatype=type(data).__name__
)
error = ValidationError(message, code='invalid')
error = ValidationErrorMessage(message, code='invalid')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [error]
})
@ -477,7 +440,7 @@ class Serializer(BaseSerializer):
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
errors[field.field_name] = exc
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = (
exceptions.build_error_from_django_validation_error(exc)
@ -616,7 +579,7 @@ class ListSerializer(BaseSerializer):
message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__
)
error = ValidationError(
error = ValidationErrorMessage(
message,
code='not_a_list'
)
@ -625,9 +588,11 @@ class ListSerializer(BaseSerializer):
})
if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
message = ValidationErrorMessage(
self.error_messages['empty'],
code='empty_not_allowed')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')]
api_settings.NON_FIELD_ERRORS_KEY: [message]
})
ret = []

View File

@ -11,7 +11,7 @@ from __future__ import unicode_literals
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ValidationError, ValidationErrorMessage
from rest_framework.utils.representation import smart_repr
@ -60,7 +60,8 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
if queryset.exists():
raise ValidationError(self.message, code='unique')
raise ValidationError(ValidationErrorMessage(self.message,
code='unique'))
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % (
@ -101,9 +102,10 @@ class UniqueTogetherValidator(object):
return
missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in self.fields
if field_name not in attrs
}
@ -149,8 +151,11 @@ class UniqueTogetherValidator(object):
]
if None not in checked_values and queryset.exists():
field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names),
code='unique')
raise ValidationError(
ValidationErrorMessage(
self.message.format(field_names=field_names),
code='unique')
)
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -188,7 +193,7 @@ class BaseUniqueForValidator(object):
'required' state on the fields they are applied to.
"""
missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in [self.field, self.date_field]
@ -216,8 +221,9 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset)
if queryset.exists():
message = self.message.format(date_field=self.date_field)
error = ValidationError(message, code='unique')
raise ValidationError({self.field: error})
raise ValidationError({
self.field: ValidationErrorMessage(message, code='unique'),
})
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -70,7 +70,7 @@ def exception_handler(exc, context):
headers['Retry-After'] = '%d' % exc.wait
if isinstance(exc.detail, (list, dict)):
data = exc.detail.serializer.errors
data = exc.detail
else:
data = {'detail': exc.detail}

View File

@ -0,0 +1,74 @@
from django.test import TestCase
from rest_framework import serializers, status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ErrorView(APIView):
def get(self, request, *args, **kwargs):
ExampleSerializer(data={}).is_valid(raise_exception=True)
@api_view(['GET'])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)
class TestValidationErrorWithCode(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc, request):
return_errors = {}
for field_name, errors in exc.detail.items():
return_errors[field_name] = []
for error in errors:
return_errors[field_name].append({
'code': error.code,
'message': error
})
return Response(return_errors, status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
self.expected_response_data = {
'char': [{
'message': 'This field is required.',
'code': 'required',
}],
'integer': [{
'message': 'This field is required.',
'code': 'required'
}],
}
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)
def test_function_based_view_exception_handler(self):
view = error_view
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)