mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 13:54:47 +03:00
Introduce ValidationErrorMessage
`ValidationErrorMessage` is a string-like object that holds a code attribute. The code attribute has been removed from ValidationError to be able to maintain better backward compatibility. `ValidationErrorMessage` is abstracted in `ValidationError`'s constructor
This commit is contained in:
parent
df0d814665
commit
2bf6ee47f3
|
@ -20,20 +20,18 @@ class AuthTokenSerializer(serializers.Serializer):
|
|||
msg = _('User account is disabled.')
|
||||
raise serializers.ValidationError(
|
||||
msg,
|
||||
code='authorization'
|
||||
)
|
||||
code='authorization')
|
||||
else:
|
||||
msg = _('Unable to log in with provided credentials.')
|
||||
raise serializers.ValidationError(
|
||||
msg,
|
||||
code='authorization'
|
||||
)
|
||||
code='authorization')
|
||||
|
||||
else:
|
||||
msg = _('Must include "username" and "password".')
|
||||
raise serializers.ValidationError(
|
||||
msg,
|
||||
code='authorization'
|
||||
)
|
||||
code='authorization')
|
||||
|
||||
attrs['user'] = user
|
||||
return attrs
|
||||
|
|
|
@ -61,7 +61,7 @@ class APIException(Exception):
|
|||
def build_error_from_django_validation_error(exc_info):
|
||||
code = getattr(exc_info, 'code', None) or 'invalid'
|
||||
return [
|
||||
ValidationError(msg, code=code)
|
||||
ValidationErrorMessage(msg, code=code)
|
||||
for msg in exc_info.messages
|
||||
]
|
||||
|
||||
|
@ -73,20 +73,30 @@ def build_error_from_django_validation_error(exc_info):
|
|||
# from rest_framework import serializers
|
||||
# raise serializers.ValidationError('Value was invalid')
|
||||
|
||||
class ValidationErrorMessage(six.text_type):
|
||||
code = None
|
||||
|
||||
def __new__(cls, string, code=None, *args, **kwargs):
|
||||
self = super(ValidationErrorMessage, cls).__new__(
|
||||
cls, string, *args, **kwargs)
|
||||
|
||||
self.code = code
|
||||
return self
|
||||
|
||||
|
||||
class ValidationError(APIException):
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def __init__(self, detail, code=None):
|
||||
# If code is there, this means we are dealing with a message.
|
||||
if code and not isinstance(detail, ValidationErrorMessage):
|
||||
detail = ValidationErrorMessage(detail, code=code)
|
||||
|
||||
# For validation errors the 'detail' key is always required.
|
||||
# The details should always be coerced to a list if not already.
|
||||
if not isinstance(detail, dict) and not isinstance(detail, list):
|
||||
detail = [detail]
|
||||
elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)):
|
||||
assert code is None, (
|
||||
'The `code` argument must not be set for compound errors.')
|
||||
|
||||
self.detail = detail
|
||||
self.code = code
|
||||
self.detail = _force_text_recursive(detail)
|
||||
|
||||
def __str__(self):
|
||||
return six.text_type(self.detail)
|
||||
|
|
|
@ -509,7 +509,7 @@ class Field(object):
|
|||
# attempting to accumulate a list of errors.
|
||||
if isinstance(exc.detail, dict):
|
||||
raise
|
||||
errors.append(ValidationError(exc.detail, code=exc.code))
|
||||
errors.extend(exc.detail)
|
||||
except DjangoValidationError as exc:
|
||||
errors.extend(
|
||||
build_error_from_django_validation_error(exc)
|
||||
|
|
|
@ -38,6 +38,7 @@ class Response(SimpleTemplateResponse):
|
|||
'`.error`. representation.'
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
|
||||
self.data = data
|
||||
self.template_name = template_name
|
||||
self.exception = exception
|
||||
|
|
|
@ -25,6 +25,7 @@ from django.utils.translation import ugettext_lazy as _
|
|||
from rest_framework import exceptions
|
||||
from rest_framework.compat import JSONField as ModelJSONField
|
||||
from rest_framework.compat import postgres_fields, unicode_to_repr
|
||||
from rest_framework.exceptions import ValidationErrorMessage
|
||||
from rest_framework.utils import model_meta
|
||||
from rest_framework.utils.field_mapping import (
|
||||
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
|
||||
|
@ -220,14 +221,7 @@ class BaseSerializer(Field):
|
|||
self._errors = {}
|
||||
|
||||
if self._errors and raise_exception:
|
||||
return_errors = None
|
||||
if isinstance(self._errors, list):
|
||||
return_errors = ReturnList(self._errors, serializer=self)
|
||||
elif isinstance(self._errors, dict):
|
||||
return_errors = ReturnDict(self._errors, serializer=self)
|
||||
|
||||
raise ValidationError(return_errors)
|
||||
|
||||
raise ValidationError(self.errors)
|
||||
return not bool(self._errors)
|
||||
|
||||
@property
|
||||
|
@ -251,42 +245,12 @@ class BaseSerializer(Field):
|
|||
self._data = self.get_initial()
|
||||
return self._data
|
||||
|
||||
def _transform_to_legacy_errors(self, errors_to_transform):
|
||||
# Do not mutate `errors_to_transform` here.
|
||||
errors = ReturnDict(serializer=self)
|
||||
for field_name, values in errors_to_transform.items():
|
||||
if isinstance(values, list):
|
||||
errors[field_name] = values
|
||||
continue
|
||||
|
||||
if isinstance(values.detail, list):
|
||||
errors[field_name] = []
|
||||
for value in values.detail:
|
||||
if isinstance(value, ValidationError):
|
||||
errors[field_name].extend(value.detail)
|
||||
elif isinstance(value, list):
|
||||
errors[field_name].extend(value)
|
||||
else:
|
||||
errors[field_name].append(value)
|
||||
|
||||
elif isinstance(values.detail, dict):
|
||||
errors[field_name] = {}
|
||||
for sub_field_name, value in values.detail.items():
|
||||
errors[field_name][sub_field_name] = []
|
||||
for validation_error in value:
|
||||
errors[field_name][sub_field_name].extend(validation_error.detail)
|
||||
return errors
|
||||
|
||||
@property
|
||||
def errors(self):
|
||||
if not hasattr(self, '_errors'):
|
||||
msg = 'You must call `.is_valid()` before accessing `.errors`.'
|
||||
raise AssertionError(msg)
|
||||
|
||||
if isinstance(self._errors, list):
|
||||
return map(self._transform_to_legacy_errors, self._errors)
|
||||
else:
|
||||
return self._transform_to_legacy_errors(self._errors)
|
||||
return self._errors
|
||||
|
||||
@property
|
||||
def validated_data(self):
|
||||
|
@ -461,7 +425,7 @@ class Serializer(BaseSerializer):
|
|||
message = self.error_messages['invalid'].format(
|
||||
datatype=type(data).__name__
|
||||
)
|
||||
error = ValidationError(message, code='invalid')
|
||||
error = ValidationErrorMessage(message, code='invalid')
|
||||
raise ValidationError({
|
||||
api_settings.NON_FIELD_ERRORS_KEY: [error]
|
||||
})
|
||||
|
@ -478,7 +442,7 @@ class Serializer(BaseSerializer):
|
|||
if validate_method is not None:
|
||||
validated_value = validate_method(validated_value)
|
||||
except ValidationError as exc:
|
||||
errors[field.field_name] = exc
|
||||
errors[field.field_name] = exc.detail
|
||||
except DjangoValidationError as exc:
|
||||
errors[field.field_name] = (
|
||||
exceptions.build_error_from_django_validation_error(exc)
|
||||
|
@ -621,7 +585,7 @@ class ListSerializer(BaseSerializer):
|
|||
message = self.error_messages['not_a_list'].format(
|
||||
input_type=type(data).__name__
|
||||
)
|
||||
error = ValidationError(
|
||||
error = ValidationErrorMessage(
|
||||
message,
|
||||
code='not_a_list'
|
||||
)
|
||||
|
@ -630,9 +594,11 @@ class ListSerializer(BaseSerializer):
|
|||
})
|
||||
|
||||
if not self.allow_empty and len(data) == 0:
|
||||
message = self.error_messages['empty']
|
||||
message = ValidationErrorMessage(
|
||||
self.error_messages['empty'],
|
||||
code='empty_not_allowed')
|
||||
raise ValidationError({
|
||||
api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')]
|
||||
api_settings.NON_FIELD_ERRORS_KEY: [message]
|
||||
})
|
||||
|
||||
ret = []
|
||||
|
|
|
@ -12,7 +12,7 @@ from django.db import DataError
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.compat import unicode_to_repr
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.exceptions import ValidationError, ValidationErrorMessage
|
||||
from rest_framework.utils.representation import smart_repr
|
||||
|
||||
|
||||
|
@ -120,9 +120,10 @@ class UniqueTogetherValidator(object):
|
|||
return
|
||||
|
||||
missing = {
|
||||
field_name: ValidationError(
|
||||
field_name: ValidationErrorMessage(
|
||||
self.missing_message,
|
||||
code='required')
|
||||
|
||||
for field_name in self.fields
|
||||
if field_name not in attrs
|
||||
}
|
||||
|
@ -168,7 +169,8 @@ class UniqueTogetherValidator(object):
|
|||
]
|
||||
if None not in checked_values and qs_exists(queryset):
|
||||
field_names = ', '.join(self.fields)
|
||||
raise ValidationError(self.message.format(field_names=field_names),
|
||||
raise ValidationError(
|
||||
self.message.format(field_names=field_names),
|
||||
code='unique')
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -207,7 +209,7 @@ class BaseUniqueForValidator(object):
|
|||
'required' state on the fields they are applied to.
|
||||
"""
|
||||
missing = {
|
||||
field_name: ValidationError(
|
||||
field_name: ValidationErrorMessage(
|
||||
self.missing_message,
|
||||
code='required')
|
||||
for field_name in [self.field, self.date_field]
|
||||
|
@ -235,8 +237,9 @@ class BaseUniqueForValidator(object):
|
|||
queryset = self.exclude_current_instance(attrs, queryset)
|
||||
if qs_exists(queryset):
|
||||
message = self.message.format(date_field=self.date_field)
|
||||
error = ValidationError(message, code='unique')
|
||||
raise ValidationError({self.field: error})
|
||||
raise ValidationError({
|
||||
self.field: ValidationErrorMessage(message, code='unique'),
|
||||
})
|
||||
|
||||
def __repr__(self):
|
||||
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (
|
||||
|
|
|
@ -71,7 +71,7 @@ def exception_handler(exc, context):
|
|||
headers['Retry-After'] = '%d' % exc.wait
|
||||
|
||||
if isinstance(exc.detail, (list, dict)):
|
||||
data = exc.detail.serializer.errors
|
||||
data = exc.detail
|
||||
else:
|
||||
data = {'detail': exc.detail}
|
||||
|
||||
|
|
74
tests/test_validation_error.py
Normal file
74
tests/test_validation_error.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
from django.test import TestCase
|
||||
|
||||
from rest_framework import serializers, status
|
||||
from rest_framework.decorators import api_view
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
class ExampleSerializer(serializers.Serializer):
|
||||
char = serializers.CharField()
|
||||
integer = serializers.IntegerField()
|
||||
|
||||
|
||||
class ErrorView(APIView):
|
||||
def get(self, request, *args, **kwargs):
|
||||
ExampleSerializer(data={}).is_valid(raise_exception=True)
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
def error_view(request):
|
||||
ExampleSerializer(data={}).is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class TestValidationErrorWithCode(TestCase):
|
||||
def setUp(self):
|
||||
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
|
||||
|
||||
def exception_handler(exc, request):
|
||||
return_errors = {}
|
||||
for field_name, errors in exc.detail.items():
|
||||
return_errors[field_name] = []
|
||||
for error in errors:
|
||||
return_errors[field_name].append({
|
||||
'code': error.code,
|
||||
'message': error
|
||||
})
|
||||
|
||||
return Response(return_errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
api_settings.EXCEPTION_HANDLER = exception_handler
|
||||
|
||||
self.expected_response_data = {
|
||||
'char': [{
|
||||
'message': 'This field is required.',
|
||||
'code': 'required',
|
||||
}],
|
||||
'integer': [{
|
||||
'message': 'This field is required.',
|
||||
'code': 'required'
|
||||
}],
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
|
||||
|
||||
def test_class_based_view_exception_handler(self):
|
||||
view = ErrorView.as_view()
|
||||
|
||||
request = factory.get('/', content_type='application/json')
|
||||
response = view(request)
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response.data, self.expected_response_data)
|
||||
|
||||
def test_function_based_view_exception_handler(self):
|
||||
view = error_view
|
||||
|
||||
request = factory.get('/', content_type='application/json')
|
||||
response = view(request)
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response.data, self.expected_response_data)
|
Loading…
Reference in New Issue
Block a user