Rejig ErrorMessages, to prefer code= directly on ValidationError

This commit is contained in:
Tom Christie 2016-10-10 18:04:58 +01:00
parent a6756a1dd9
commit 7943429dab
3 changed files with 20 additions and 28 deletions

View File

@ -10,7 +10,6 @@ import math
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import Promise
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ungettext from django.utils.translation import ungettext
@ -18,21 +17,21 @@ from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
def _force_text_recursive(data): def _force_text_recursive(data, code=None):
""" """
Descend into a nested data structure, forcing any Descend into a nested data structure, forcing any
lazy translation strings into plain text. lazy translation strings or strings into `ErrorMessage`.
""" """
if isinstance(data, list): if isinstance(data, list):
ret = [ ret = [
_force_text_recursive(item) for item in data _force_text_recursive(item, code) for item in data
] ]
if isinstance(data, ReturnList): if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer) return ReturnList(ret, serializer=data.serializer)
return ret return ret
elif isinstance(data, dict): elif isinstance(data, dict):
ret = { ret = {
key: _force_text_recursive(value) key: _force_text_recursive(value, code)
for key, value in data.items() for key, value in data.items()
} }
if isinstance(data, ReturnDict): if isinstance(data, ReturnDict):
@ -40,7 +39,7 @@ def _force_text_recursive(data):
return ret return ret
text = force_text(data) text = force_text(data)
code = getattr(data, 'code', 'invalid') code = getattr(data, 'code', code or 'invalid')
return ErrorMessage(text, code) return ErrorMessage(text, code)
@ -82,18 +81,11 @@ class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail, code=None): def __init__(self, detail, code=None):
if code is not None:
assert isinstance(detail, six.string_types + (Promise,)), (
"When providing a 'code', the detail must be a string argument. "
"Use 'ErrorMessage' to set the code for a composite ValidationError"
)
detail = ErrorMessage(detail, code)
# For validation errors the 'detail' key is always required. # For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already. # The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list): if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail] detail = [detail]
self.detail = _force_text_recursive(detail) self.detail = _force_text_recursive(detail, code=code)
def __str__(self): def __str__(self):
return six.text_type(self.detail) return six.text_type(self.detail)

View File

@ -423,8 +423,8 @@ class Serializer(BaseSerializer):
datatype=type(data).__name__ datatype=type(data).__name__
) )
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='invalid')] api_settings.NON_FIELD_ERRORS_KEY: [message]
}) }, code='invalid')
ret = OrderedDict() ret = OrderedDict()
errors = OrderedDict() errors = OrderedDict()
@ -440,7 +440,7 @@ class Serializer(BaseSerializer):
except ValidationError as exc: except ValidationError as exc:
errors[field.field_name] = exc.detail errors[field.field_name] = exc.detail
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors[field.field_name] = list(exc.messages) errors[field.field_name] = get_validation_error_detail(exc)
except SkipField: except SkipField:
pass pass
else: else:
@ -580,14 +580,14 @@ class ListSerializer(BaseSerializer):
input_type=type(data).__name__ input_type=type(data).__name__
) )
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='not_a_list')] api_settings.NON_FIELD_ERRORS_KEY: [message]
}) }, code='not_a_list')
if not self.allow_empty and len(data) == 0: if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty'] message = self.error_messages['empty']
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='empty')] api_settings.NON_FIELD_ERRORS_KEY: [message]
}) }, code='empty')
ret = [] ret = []
errors = [] errors = []

View File

@ -12,7 +12,7 @@ from django.db import DataError
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ErrorMessage, ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.utils.representation import smart_repr from rest_framework.utils.representation import smart_repr
@ -121,12 +121,12 @@ class UniqueTogetherValidator(object):
return return
missing_items = { missing_items = {
field_name: ErrorMessage(self.missing_message, code='required') field_name: self.missing_message
for field_name in self.fields for field_name in self.fields
if field_name not in attrs if field_name not in attrs
} }
if missing_items: if missing_items:
raise ValidationError(missing_items) raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset):
""" """
@ -206,12 +206,12 @@ class BaseUniqueForValidator(object):
'required' state on the fields they are applied to. 'required' state on the fields they are applied to.
""" """
missing_items = { missing_items = {
field_name: ErrorMessage(self.missing_message, code='required') field_name: self.missing_message
for field_name in [self.field, self.date_field] for field_name in [self.field, self.date_field]
if field_name not in attrs if field_name not in attrs
} }
if missing_items: if missing_items:
raise ValidationError(missing_items) raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.') raise NotImplementedError('`filter_queryset` must be implemented.')
@ -233,8 +233,8 @@ class BaseUniqueForValidator(object):
if qs_exists(queryset): if qs_exists(queryset):
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
raise ValidationError({ raise ValidationError({
self.field: ErrorMessage(message, code='unique') self.field: message
}) }, code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (