Add 'code' to ValidationError

This commit is contained in:
Tom Christie 2016-10-10 16:45:25 +01:00
parent 0dec36eb41
commit f078a04f7e
6 changed files with 63 additions and 28 deletions

View File

@ -21,13 +21,13 @@ class AuthTokenSerializer(serializers.Serializer):
# (Assuming the default `ModelBackend` authentication backend.) # (Assuming the default `ModelBackend` authentication backend.)
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg, code='authorization')
else: else:
msg = _('Unable to log in with provided credentials.') msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg, code='authorization')
else: else:
msg = _('Must include "username" and "password".') msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg, code='authorization')
attrs['user'] = user attrs['user'] = user
return attrs return attrs

View File

@ -10,6 +10,7 @@ import math
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import Promise
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ungettext from django.utils.translation import ungettext
@ -37,7 +38,19 @@ def _force_text_recursive(data):
if isinstance(data, ReturnDict): if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer) return ReturnDict(ret, serializer=data.serializer)
return ret return ret
return force_text(data)
text = force_text(data)
code = getattr(data, 'code', 'invalid')
return ErrorMessage(text, code)
class ErrorMessage(six.text_type):
code = None
def __new__(cls, string, code=None):
self = super(ErrorMessage, cls).__new__(cls, string)
self.code = code
return self
class APIException(Exception): class APIException(Exception):
@ -68,7 +81,14 @@ class APIException(Exception):
class ValidationError(APIException): class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail): def __init__(self, detail, code=None):
if code is not None:
assert isinstance(detail, six.string_types + (Promise,)), (
"When providing a 'code', the detail must be a string argument. "
"Use 'ErrorMessage' to set the code for a composite ValidationError"
)
detail = ErrorMessage(detail, code)
# For validation errors the 'detail' key is always required. # For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already. # The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list): if not isinstance(detail, dict) and not isinstance(detail, list):

View File

@ -34,7 +34,7 @@ from rest_framework import ISO_8601
from rest_framework.compat import ( from rest_framework.compat import (
get_remote_field, unicode_repr, unicode_to_repr, value_from_object get_remote_field, unicode_repr, unicode_to_repr, value_from_object
) )
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ErrorMessage, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation from rest_framework.utils import html, humanize_datetime, representation
@ -224,6 +224,18 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None):
yield Option(value='n/a', display_text=cutoff_text, disabled=True) yield Option(value='n/a', display_text=cutoff_text, disabled=True)
def get_error_messages(exc_info):
"""
Given a Django ValidationError, return a list of ErrorMessage,
with the `code` populated.
"""
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ErrorMessage(msg, code=code)
for msg in exc_info.messages
]
class CreateOnlyDefault(object): class CreateOnlyDefault(object):
""" """
This class may be used to provide default values that are only used This class may be used to provide default values that are only used
@ -525,7 +537,7 @@ class Field(object):
raise raise
errors.extend(exc.detail) errors.extend(exc.detail)
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors.extend(exc.messages) errors.extend(get_error_messages(exc))
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
@ -563,7 +575,7 @@ class Field(object):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg)
message_string = msg.format(**kwargs) message_string = msg.format(**kwargs)
raise ValidationError(message_string) raise ValidationError(message_string, code=key)
@cached_property @cached_property
def root(self): def root(self):

View File

@ -300,7 +300,7 @@ def get_validation_error_detail(exc):
# exception class as well for simpler compat. # exception class as well for simpler compat.
# Eg. Calling Model.clean() explicitly inside Serializer.validate() # Eg. Calling Model.clean() explicitly inside Serializer.validate()
return { return {
api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) api_settings.NON_FIELD_ERRORS_KEY: get_error_messages(exc)
} }
elif isinstance(exc.detail, dict): elif isinstance(exc.detail, dict):
# If errors may be a dict we use the standard {key: list of values}. # If errors may be a dict we use the standard {key: list of values}.
@ -423,7 +423,7 @@ class Serializer(BaseSerializer):
datatype=type(data).__name__ datatype=type(data).__name__
) )
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='invalid')]
}) })
ret = OrderedDict() ret = OrderedDict()
@ -580,13 +580,13 @@ class ListSerializer(BaseSerializer):
input_type=type(data).__name__ input_type=type(data).__name__
) )
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='not_a_list')]
}) })
if not self.allow_empty and len(data) == 0: if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty'] message = self.error_messages['empty']
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='empty')]
}) })
ret = [] ret = []

View File

@ -12,7 +12,7 @@ from django.db import DataError
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ErrorMessage, ValidationError
from rest_framework.utils.representation import smart_repr from rest_framework.utils.representation import smart_repr
@ -80,7 +80,7 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset) queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset) queryset = self.exclude_current_instance(queryset)
if qs_exists(queryset): if qs_exists(queryset):
raise ValidationError(self.message) raise ValidationError(self.message, code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % ( return unicode_to_repr('<%s(queryset=%s)>' % (
@ -120,13 +120,13 @@ class UniqueTogetherValidator(object):
if self.instance is not None: if self.instance is not None:
return return
missing = { missing_items = {
field_name: self.missing_message field_name: ErrorMessage(self.missing_message, code='required')
for field_name in self.fields for field_name in self.fields
if field_name not in attrs if field_name not in attrs
} }
if missing: if missing_items:
raise ValidationError(missing) raise ValidationError(missing_items)
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset):
""" """
@ -167,7 +167,8 @@ class UniqueTogetherValidator(object):
] ]
if None not in checked_values and qs_exists(queryset): if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields) field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names)) message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -204,13 +205,13 @@ class BaseUniqueForValidator(object):
The `UniqueFor<Range>Validator` classes always force an implied The `UniqueFor<Range>Validator` classes always force an implied
'required' state on the fields they are applied to. 'required' state on the fields they are applied to.
""" """
missing = { missing_items = {
field_name: self.missing_message field_name: ErrorMessage(self.missing_message, code='required')
for field_name in [self.field, self.date_field] for field_name in [self.field, self.date_field]
if field_name not in attrs if field_name not in attrs
} }
if missing: if missing_items:
raise ValidationError(missing) raise ValidationError(missing_items)
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.') raise NotImplementedError('`filter_queryset` must be implemented.')
@ -231,7 +232,9 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset) queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset): if qs_exists(queryset):
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
raise ValidationError({self.field: message}) raise ValidationError({
self.field: ErrorMessage(message, code='unique')
})
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import _force_text_recursive from rest_framework.exceptions import ErrorMessage, _force_text_recursive
class ExceptionTestCase(TestCase): class ExceptionTestCase(TestCase):
@ -12,10 +12,10 @@ class ExceptionTestCase(TestCase):
s = "sfdsfggiuytraetfdlklj" s = "sfdsfggiuytraetfdlklj"
self.assertEqual(_force_text_recursive(_(s)), s) self.assertEqual(_force_text_recursive(_(s)), s)
self.assertEqual(type(_force_text_recursive(_(s))), type(s)) assert isinstance(_force_text_recursive(_(s)), ErrorMessage)
self.assertEqual(_force_text_recursive({'a': _(s)})['a'], s) self.assertEqual(_force_text_recursive({'a': _(s)})['a'], s)
self.assertEqual(type(_force_text_recursive({'a': _(s)})['a']), type(s)) assert isinstance(_force_text_recursive({'a': _(s)})['a'], ErrorMessage)
self.assertEqual(_force_text_recursive([[_(s)]])[0][0], s) self.assertEqual(_force_text_recursive([[_(s)]])[0][0], s)
self.assertEqual(type(_force_text_recursive([[_(s)]])[0][0]), type(s)) assert isinstance(_force_text_recursive([[_(s)]])[0][0], ErrorMessage)