diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 90d3bd96e..b91a8454f 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -21,13 +21,13 @@ class AuthTokenSerializer(serializers.Serializer): # (Assuming the default `ModelBackend` authentication backend.) if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError(msg, code='authorization') else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError(msg, code='authorization') else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg) + raise serializers.ValidationError(msg, code='authorization') attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 29afaffe0..994a942e4 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -10,6 +10,7 @@ import math from django.utils import six from django.utils.encoding import force_text +from django.utils.functional import Promise from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ungettext @@ -37,7 +38,19 @@ def _force_text_recursive(data): if isinstance(data, ReturnDict): return ReturnDict(ret, serializer=data.serializer) return ret - return force_text(data) + + text = force_text(data) + code = getattr(data, 'code', 'invalid') + return ErrorMessage(text, code) + + +class ErrorMessage(six.text_type): + code = None + + def __new__(cls, string, code=None): + self = super(ErrorMessage, cls).__new__(cls, string) + self.code = code + return self class APIException(Exception): @@ -68,7 +81,14 @@ class APIException(Exception): class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST - def __init__(self, detail): + def __init__(self, detail, code=None): + if code is not None: + assert isinstance(detail, six.string_types + (Promise,)), ( + "When providing a 'code', the detail must be a string argument. " + "Use 'ErrorMessage' to set the code for a composite ValidationError" + ) + detail = ErrorMessage(detail, code) + # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7f8391b8a..85d582ab2 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -34,7 +34,7 @@ from rest_framework import ISO_8601 from rest_framework.compat import ( get_remote_field, unicode_repr, unicode_to_repr, value_from_object ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ErrorMessage, ValidationError from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -224,6 +224,18 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None): yield Option(value='n/a', display_text=cutoff_text, disabled=True) +def get_error_messages(exc_info): + """ + Given a Django ValidationError, return a list of ErrorMessage, + with the `code` populated. + """ + code = getattr(exc_info, 'code', None) or 'invalid' + return [ + ErrorMessage(msg, code=code) + for msg in exc_info.messages + ] + + class CreateOnlyDefault(object): """ This class may be used to provide default values that are only used @@ -525,7 +537,7 @@ class Field(object): raise errors.extend(exc.detail) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend(get_error_messages(exc)) if errors: raise ValidationError(errors) @@ -563,7 +575,7 @@ class Field(object): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, code=key) @cached_property def root(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 7e99d40b3..c568ab265 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -300,7 +300,7 @@ def get_validation_error_detail(exc): # exception class as well for simpler compat. # Eg. Calling Model.clean() explicitly inside Serializer.validate() return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + api_settings.NON_FIELD_ERRORS_KEY: get_error_messages(exc) } elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. @@ -423,7 +423,7 @@ class Serializer(BaseSerializer): datatype=type(data).__name__ ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='invalid')] }) ret = OrderedDict() @@ -580,13 +580,13 @@ class ListSerializer(BaseSerializer): input_type=type(data).__name__ ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='not_a_list')] }) if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='empty')] }) ret = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 84af0b9d5..386aff9c4 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -12,7 +12,7 @@ from django.db import DataError from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ErrorMessage, ValidationError from rest_framework.utils.representation import smart_repr @@ -80,7 +80,7 @@ class UniqueValidator(object): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if qs_exists(queryset): - raise ValidationError(self.message) + raise ValidationError(self.message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -120,13 +120,13 @@ class UniqueTogetherValidator(object): if self.instance is not None: return - missing = { - field_name: self.missing_message + missing_items = { + field_name: ErrorMessage(self.missing_message, code='required') for field_name in self.fields if field_name not in attrs } - if missing: - raise ValidationError(missing) + if missing_items: + raise ValidationError(missing_items) def filter_queryset(self, attrs, queryset): """ @@ -167,7 +167,8 @@ class UniqueTogetherValidator(object): ] if None not in checked_values and qs_exists(queryset): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + message = self.message.format(field_names=field_names) + raise ValidationError(message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -204,13 +205,13 @@ class BaseUniqueForValidator(object): The `UniqueForValidator` classes always force an implied 'required' state on the fields they are applied to. """ - missing = { - field_name: self.missing_message + missing_items = { + field_name: ErrorMessage(self.missing_message, code='required') for field_name in [self.field, self.date_field] if field_name not in attrs } - if missing: - raise ValidationError(missing) + if missing_items: + raise ValidationError(missing_items) def filter_queryset(self, attrs, queryset): raise NotImplementedError('`filter_queryset` must be implemented.') @@ -231,7 +232,9 @@ class BaseUniqueForValidator(object): queryset = self.exclude_current_instance(attrs, queryset) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + raise ValidationError({ + self.field: ErrorMessage(message, code='unique') + }) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index cec050eb8..ffecf241a 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals from django.test import TestCase from django.utils.translation import ugettext_lazy as _ -from rest_framework.exceptions import _force_text_recursive +from rest_framework.exceptions import ErrorMessage, _force_text_recursive class ExceptionTestCase(TestCase): @@ -12,10 +12,10 @@ class ExceptionTestCase(TestCase): s = "sfdsfggiuytraetfdlklj" self.assertEqual(_force_text_recursive(_(s)), s) - self.assertEqual(type(_force_text_recursive(_(s))), type(s)) + assert isinstance(_force_text_recursive(_(s)), ErrorMessage) self.assertEqual(_force_text_recursive({'a': _(s)})['a'], s) - self.assertEqual(type(_force_text_recursive({'a': _(s)})['a']), type(s)) + assert isinstance(_force_text_recursive({'a': _(s)})['a'], ErrorMessage) self.assertEqual(_force_text_recursive([[_(s)]])[0][0], s) - self.assertEqual(type(_force_text_recursive([[_(s)]])[0][0]), type(s)) + assert isinstance(_force_text_recursive([[_(s)]])[0][0], ErrorMessage)