Rejig ErrorMessages, to prefer code= directly on ValidationError

This commit is contained in:
Tom Christie 2016-10-10 18:04:58 +01:00
parent a6756a1dd9
commit 7943429dab
3 changed files with 20 additions and 28 deletions

View File

@ -10,7 +10,6 @@ import math
from django.utils import six
from django.utils.encoding import force_text
from django.utils.functional import Promise
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ungettext
@ -18,21 +17,21 @@ from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
def _force_text_recursive(data):
def _force_text_recursive(data, code=None):
"""
Descend into a nested data structure, forcing any
lazy translation strings into plain text.
lazy translation strings or strings into `ErrorMessage`.
"""
if isinstance(data, list):
ret = [
_force_text_recursive(item) for item in data
_force_text_recursive(item, code) for item in data
]
if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer)
return ret
elif isinstance(data, dict):
ret = {
key: _force_text_recursive(value)
key: _force_text_recursive(value, code)
for key, value in data.items()
}
if isinstance(data, ReturnDict):
@ -40,7 +39,7 @@ def _force_text_recursive(data):
return ret
text = force_text(data)
code = getattr(data, 'code', 'invalid')
code = getattr(data, 'code', code or 'invalid')
return ErrorMessage(text, code)
@ -82,18 +81,11 @@ class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail, code=None):
if code is not None:
assert isinstance(detail, six.string_types + (Promise,)), (
"When providing a 'code', the detail must be a string argument. "
"Use 'ErrorMessage' to set the code for a composite ValidationError"
)
detail = ErrorMessage(detail, code)
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
self.detail = _force_text_recursive(detail)
self.detail = _force_text_recursive(detail, code=code)
def __str__(self):
return six.text_type(self.detail)

View File

@ -423,8 +423,8 @@ class Serializer(BaseSerializer):
datatype=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='invalid')]
})
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='invalid')
ret = OrderedDict()
errors = OrderedDict()
@ -440,7 +440,7 @@ class Serializer(BaseSerializer):
except ValidationError as exc:
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = list(exc.messages)
errors[field.field_name] = get_validation_error_detail(exc)
except SkipField:
pass
else:
@ -580,14 +580,14 @@ class ListSerializer(BaseSerializer):
input_type=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='not_a_list')]
})
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='not_a_list')
if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='empty')]
})
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='empty')
ret = []
errors = []

View File

@ -12,7 +12,7 @@ from django.db import DataError
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ErrorMessage, ValidationError
from rest_framework.exceptions import ValidationError
from rest_framework.utils.representation import smart_repr
@ -121,12 +121,12 @@ class UniqueTogetherValidator(object):
return
missing_items = {
field_name: ErrorMessage(self.missing_message, code='required')
field_name: self.missing_message
for field_name in self.fields
if field_name not in attrs
}
if missing_items:
raise ValidationError(missing_items)
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset):
"""
@ -206,12 +206,12 @@ class BaseUniqueForValidator(object):
'required' state on the fields they are applied to.
"""
missing_items = {
field_name: ErrorMessage(self.missing_message, code='required')
field_name: self.missing_message
for field_name in [self.field, self.date_field]
if field_name not in attrs
}
if missing_items:
raise ValidationError(missing_items)
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.')
@ -233,8 +233,8 @@ class BaseUniqueForValidator(object):
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({
self.field: ErrorMessage(message, code='unique')
})
self.field: message
}, code='unique')
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (