Introduce ValidationErrorMessage

`ValidationErrorMessage` is a string-like object that holds a code
attribute.

The code attribute has been removed from ValidationError to be able
to maintain better backward compatibility.

`ValidationErrorMessage` is abstracted in `ValidationError`'s
constructor
This commit is contained in:
Jonathan Liuti 2015-12-16 19:09:03 +01:00
parent df0d814665
commit 2bf6ee47f3
8 changed files with 118 additions and 66 deletions

View File

@ -20,20 +20,18 @@ class AuthTokenSerializer(serializers.Serializer):
msg = _('User account is disabled.')
raise serializers.ValidationError(
msg,
code='authorization'
)
code='authorization')
else:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(
msg,
code='authorization'
)
code='authorization')
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(
msg,
code='authorization'
)
code='authorization')
attrs['user'] = user
return attrs

View File

@ -61,7 +61,7 @@ class APIException(Exception):
def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ValidationError(msg, code=code)
ValidationErrorMessage(msg, code=code)
for msg in exc_info.messages
]
@ -73,20 +73,30 @@ def build_error_from_django_validation_error(exc_info):
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
class ValidationErrorMessage(six.text_type):
code = None
def __new__(cls, string, code=None, *args, **kwargs):
self = super(ValidationErrorMessage, cls).__new__(
cls, string, *args, **kwargs)
self.code = code
return self
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail, code=None):
# If code is there, this means we are dealing with a message.
if code and not isinstance(detail, ValidationErrorMessage):
detail = ValidationErrorMessage(detail, code=code)
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)):
assert code is None, (
'The `code` argument must not be set for compound errors.')
self.detail = detail
self.code = code
self.detail = _force_text_recursive(detail)
def __str__(self):
return six.text_type(self.detail)

View File

@ -509,7 +509,7 @@ class Field(object):
# attempting to accumulate a list of errors.
if isinstance(exc.detail, dict):
raise
errors.append(ValidationError(exc.detail, code=exc.code))
errors.extend(exc.detail)
except DjangoValidationError as exc:
errors.extend(
build_error_from_django_validation_error(exc)

View File

@ -38,6 +38,7 @@ class Response(SimpleTemplateResponse):
'`.error`. representation.'
)
raise AssertionError(msg)
self.data = data
self.template_name = template_name
self.exception = exception

View File

@ -25,6 +25,7 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework import exceptions
from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.exceptions import ValidationErrorMessage
from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
@ -220,14 +221,7 @@ class BaseSerializer(Field):
self._errors = {}
if self._errors and raise_exception:
return_errors = None
if isinstance(self._errors, list):
return_errors = ReturnList(self._errors, serializer=self)
elif isinstance(self._errors, dict):
return_errors = ReturnDict(self._errors, serializer=self)
raise ValidationError(return_errors)
raise ValidationError(self.errors)
return not bool(self._errors)
@property
@ -251,42 +245,12 @@ class BaseSerializer(Field):
self._data = self.get_initial()
return self._data
def _transform_to_legacy_errors(self, errors_to_transform):
# Do not mutate `errors_to_transform` here.
errors = ReturnDict(serializer=self)
for field_name, values in errors_to_transform.items():
if isinstance(values, list):
errors[field_name] = values
continue
if isinstance(values.detail, list):
errors[field_name] = []
for value in values.detail:
if isinstance(value, ValidationError):
errors[field_name].extend(value.detail)
elif isinstance(value, list):
errors[field_name].extend(value)
else:
errors[field_name].append(value)
elif isinstance(values.detail, dict):
errors[field_name] = {}
for sub_field_name, value in values.detail.items():
errors[field_name][sub_field_name] = []
for validation_error in value:
errors[field_name][sub_field_name].extend(validation_error.detail)
return errors
@property
def errors(self):
if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg)
if isinstance(self._errors, list):
return map(self._transform_to_legacy_errors, self._errors)
else:
return self._transform_to_legacy_errors(self._errors)
return self._errors
@property
def validated_data(self):
@ -461,7 +425,7 @@ class Serializer(BaseSerializer):
message = self.error_messages['invalid'].format(
datatype=type(data).__name__
)
error = ValidationError(message, code='invalid')
error = ValidationErrorMessage(message, code='invalid')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [error]
})
@ -478,7 +442,7 @@ class Serializer(BaseSerializer):
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
errors[field.field_name] = exc
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = (
exceptions.build_error_from_django_validation_error(exc)
@ -621,7 +585,7 @@ class ListSerializer(BaseSerializer):
message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__
)
error = ValidationError(
error = ValidationErrorMessage(
message,
code='not_a_list'
)
@ -630,9 +594,11 @@ class ListSerializer(BaseSerializer):
})
if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
message = ValidationErrorMessage(
self.error_messages['empty'],
code='empty_not_allowed')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')]
api_settings.NON_FIELD_ERRORS_KEY: [message]
})
ret = []

View File

@ -12,7 +12,7 @@ from django.db import DataError
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ValidationError, ValidationErrorMessage
from rest_framework.utils.representation import smart_repr
@ -120,9 +120,10 @@ class UniqueTogetherValidator(object):
return
missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in self.fields
if field_name not in attrs
}
@ -168,8 +169,9 @@ class UniqueTogetherValidator(object):
]
if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names),
code='unique')
raise ValidationError(
self.message.format(field_names=field_names),
code='unique')
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -207,7 +209,7 @@ class BaseUniqueForValidator(object):
'required' state on the fields they are applied to.
"""
missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in [self.field, self.date_field]
@ -235,8 +237,9 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
error = ValidationError(message, code='unique')
raise ValidationError({self.field: error})
raise ValidationError({
self.field: ValidationErrorMessage(message, code='unique'),
})
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -71,7 +71,7 @@ def exception_handler(exc, context):
headers['Retry-After'] = '%d' % exc.wait
if isinstance(exc.detail, (list, dict)):
data = exc.detail.serializer.errors
data = exc.detail
else:
data = {'detail': exc.detail}

View File

@ -0,0 +1,74 @@
from django.test import TestCase
from rest_framework import serializers, status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ErrorView(APIView):
def get(self, request, *args, **kwargs):
ExampleSerializer(data={}).is_valid(raise_exception=True)
@api_view(['GET'])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)
class TestValidationErrorWithCode(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc, request):
return_errors = {}
for field_name, errors in exc.detail.items():
return_errors[field_name] = []
for error in errors:
return_errors[field_name].append({
'code': error.code,
'message': error
})
return Response(return_errors, status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
self.expected_response_data = {
'char': [{
'message': 'This field is required.',
'code': 'required',
}],
'integer': [{
'message': 'This field is required.',
'code': 'required'
}],
}
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)
def test_function_based_view_exception_handler(self):
view = error_view
request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)