diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 597aeed22..57299a5e7 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -18,13 +18,22 @@ class AuthTokenSerializer(serializers.Serializer): if user: if not user.is_active: msg = _('User account is disabled.') - raise exceptions.ValidationError(msg) + raise exceptions.ValidationError( + msg, + error_code='authorization' + ) else: msg = _('Unable to log in with provided credentials.') - raise exceptions.ValidationError(msg) + raise exceptions.ValidationError( + msg, + error_code='authorization' + ) else: msg = _('Must include "username" and "password".') - raise exceptions.ValidationError(msg) + raise exceptions.ValidationError( + msg, + error_code='authorization' + ) attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index f587d10fd..fef09fd18 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -14,6 +14,7 @@ from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ungettext from rest_framework import status +from rest_framework.settings import api_settings def _force_text_recursive(data): @@ -51,6 +52,33 @@ class APIException(Exception): return self.detail +def build_error_from_django_validation_error(exc_info): + code = exc_info.code or 'invalid' + return [ + build_error(msg, error_code=code) for msg in exc_info.messages + ] + + +def build_error(detail, error_code=None): + assert not isinstance(detail, dict) and not isinstance(detail, list), ( + 'Use `build_error` only with single error messages. Dictionaries and ' + 'lists should be passed directly to ValidationError.' + ) + + if api_settings.REQUIRE_ERROR_CODES: + assert error_code is not None, ( + 'The `error_code` argument is required for single errors. Strict ' + 'checking of error_code is enabled with REQUIRE_ERROR_CODES ' + 'settings key.' + ) + + return api_settings.ERROR_BUILDER(detail, error_code) + + +def default_error_builder(detail, error_code=None): + return detail + + # The recommended style for using `ValidationError` is to keep it namespaced # under `serializers`, in order to minimize potential confusion with Django's # built in `ValidationError`. For example: @@ -61,12 +89,21 @@ class APIException(Exception): class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST - def __init__(self, detail): + def __init__(self, detail, error_code=None): # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): - detail = [detail] + detail = [build_error(detail, error_code=error_code)] + else: + if api_settings.REQUIRE_ERROR_CODES: + assert error_code is None, ( + 'The `error_code` argument must not be set for compound ' + 'errors. Strict checking of error_code is enabled with ' + 'REQUIRE_ERROR_CODES settings key.' + ) + self.detail = _force_text_recursive(detail) + self.error_code = error_code def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3ca7d682e..53ca0a2c0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -25,7 +25,9 @@ from rest_framework.compat import ( MinValueValidator, OrderedDict, URLValidator, duration_string, parse_duration, unicode_repr, unicode_to_repr ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ( + ValidationError, build_error_from_django_validation_error +) from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -401,7 +403,9 @@ class Field(object): raise errors.extend(exc.detail) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend( + build_error_from_django_validation_error(exc) + ) if errors: raise ValidationError(errors) @@ -439,7 +443,7 @@ class Field(object): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, error_code=key) @property def root(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 80b7c6d6d..5fdc81209 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,7 @@ from django.db.models.fields import FieldDoesNotExist from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions from rest_framework.compat import DurationField as ModelDurationField from rest_framework.compat import postgres_fields, unicode_to_repr from rest_framework.utils import model_meta @@ -276,7 +277,8 @@ def get_validation_error_detail(exc): # exception class as well for simpler compat. # Eg. Calling Model.clean() explicitly inside Serializer.validate() return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + api_settings.NON_FIELD_ERRORS_KEY: + exceptions.build_error_from_django_validation_error(exc) } elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. @@ -398,8 +400,9 @@ class Serializer(BaseSerializer): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) + error = exceptions.build_error(message, error_code='invalid') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) ret = OrderedDict() @@ -416,7 +419,9 @@ class Serializer(BaseSerializer): except ValidationError as exc: errors[field.field_name] = exc.detail except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + errors[field.field_name] = ( + exceptions.build_error_from_django_validation_error(exc.messages) + ) except SkipField: pass else: @@ -534,7 +539,9 @@ class ListSerializer(BaseSerializer): value = self.validate(value) assert value is not None, '.validate() should return the validated data' except (ValidationError, DjangoValidationError) as exc: - raise ValidationError(detail=get_validation_error_detail(exc)) + raise ValidationError( + detail=get_validation_error_detail(exc) + ) return value @@ -549,8 +556,9 @@ class ListSerializer(BaseSerializer): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) + error = exceptions.build_error(message, error_code='not_a_list') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) ret = [] diff --git a/rest_framework/settings.py b/rest_framework/settings.py index e20e51287..573e39060 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -85,6 +85,8 @@ DEFAULTS = { # Exception handling 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', 'NON_FIELD_ERRORS_KEY': 'non_field_errors', + 'REQUIRE_ERROR_CODES': False, + 'ERROR_BUILDER': 'rest_framework.exceptions.default_error_builder', # Testing 'TEST_REQUEST_RENDERER_CLASSES': ( @@ -138,6 +140,7 @@ IMPORT_STRINGS = ( 'DEFAULT_VERSIONING_CLASS', 'DEFAULT_PAGINATION_CLASS', 'DEFAULT_FILTER_BACKENDS', + 'ERROR_BUILDER', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', diff --git a/rest_framework/validators.py b/rest_framework/validators.py index a1771a92e..9b59c0ac6 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -11,7 +11,7 @@ from __future__ import unicode_literals from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ValidationError, build_error from rest_framework.utils.representation import smart_repr @@ -60,7 +60,7 @@ class UniqueValidator(object): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if queryset.exists(): - raise ValidationError(self.message) + raise ValidationError(self.message, error_code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -101,7 +101,10 @@ class UniqueTogetherValidator(object): return missing = dict([ - (field_name, self.missing_message) + ( + field_name, + build_error(self.missing_message, error_code='required') + ) for field_name in self.fields if field_name not in attrs ]) @@ -147,7 +150,8 @@ class UniqueTogetherValidator(object): ] if None not in checked_values and queryset.exists(): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + raise ValidationError(self.message.format(field_names=field_names), + error_code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -185,7 +189,10 @@ class BaseUniqueForValidator(object): 'required' state on the fields they are applied to. """ missing = dict([ - (field_name, self.missing_message) + ( + field_name, + build_error(self.missing_message, error_code='required') + ) for field_name in [self.field, self.date_field] if field_name not in attrs ]) @@ -211,7 +218,8 @@ class BaseUniqueForValidator(object): queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + error = build_error(message, error_code='unique') + raise ValidationError({self.field: error}) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/tests/test_fields.py b/tests/test_fields.py index 76e6d9d60..ff438a7b2 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1178,7 +1178,10 @@ class TestFieldFieldWithName(FieldValues): # call into it's regular validation, or require PIL for testing. class FailImageValidation(object): def to_python(self, value): - raise serializers.ValidationError(self.error_messages['invalid_image']) + raise serializers.ValidationError( + self.error_messages['invalid_image'], + error_code='invalid_image' + ) class PassImageValidation(object): diff --git a/tests/test_serializer.py b/tests/test_serializer.py index c18cbb584..59b62431a 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -7,6 +7,7 @@ import pytest from rest_framework import serializers from rest_framework.compat import unicode_repr +from rest_framework.fields import DjangoValidationError from .utils import MockObject @@ -59,7 +60,10 @@ class TestValidateMethod: integer = serializers.IntegerField() def validate(self, attrs): - raise serializers.ValidationError('Non field error') + raise serializers.ValidationError( + 'Non field error', + error_code='test' + ) serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) assert not serializer.is_valid() @@ -299,3 +303,10 @@ class TestCacheSerializerData: pickled = pickle.dumps(serializer.data) data = pickle.loads(pickled) assert data == {'field1': 'a', 'field2': 'b'} + + +class TestGetValidationErrorDetail: + def test_get_validation_error_detail_converts_django_errors(self): + exc = DjangoValidationError("Missing field.", code='required') + detail = serializers.get_validation_error_detail(exc) + assert detail == {'non_field_errors': ['Missing field.']} diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index e9234d8f7..9a2df1908 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -280,7 +280,10 @@ class TestListSerializerClass: def test_list_serializer_class_validate(self): class CustomListSerializer(serializers.ListSerializer): def validate(self, attrs): - raise serializers.ValidationError('Non field error') + raise serializers.ValidationError( + 'Non field error', + error_code='test' + ) class TestSerializer(serializers.Serializer): class Meta: diff --git a/tests/test_validation.py b/tests/test_validation.py index 46e36f5d8..78d6e317d 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -41,7 +41,8 @@ class ShouldValidateModelSerializer(serializers.ModelSerializer): def validate_renamed(self, value): if len(value) < 3: - raise serializers.ValidationError('Minimum 3 characters.') + raise serializers.ValidationError('Minimum 3 characters.', + error_code='min_length') return value class Meta: diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 000000000..bb494dc30 --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,56 @@ +import pytest +from django.test import TestCase + +from rest_framework import serializers +from rest_framework.exceptions import build_error +from rest_framework.settings import api_settings + + +class TestMandatoryErrorCodeArgument(TestCase): + """ + If mandatory error code is enabled in settings, it should prevent throwing + ValidationError without the code set. + """ + def setUp(self): + self.DEFAULT_REQUIRE_ERROR_CODES = api_settings.REQUIRE_ERROR_CODES + api_settings.REQUIRE_ERROR_CODES = True + + def tearDown(self): + api_settings.REQUIRE_ERROR_CODES = self.DEFAULT_REQUIRE_ERROR_CODES + + def test_validation_error_requires_code_for_simple_messages(self): + with pytest.raises(AssertionError): + serializers.ValidationError("") + + def test_validation_error_requires_no_code_for_structured_errors(self): + """ + ValidationError can hold a list or dictionary of simple errors, in + which case the code is no longer meaningful and should not be + specified. + """ + with pytest.raises(AssertionError): + serializers.ValidationError([], error_code='min_value') + + with pytest.raises(AssertionError): + serializers.ValidationError({}, error_code='min_value') + + def test_validation_error_stores_error_code(self): + exception = serializers.ValidationError("", error_code='min_value') + assert exception.error_code == 'min_value' + + +class TestCustomErrorBuilder(TestCase): + def setUp(self): + self.DEFAULT_ERROR_BUILDER = api_settings.ERROR_BUILDER + + def error_builder(message, error_code): + return (message, error_code, "customized") + + api_settings.ERROR_BUILDER = error_builder + + def tearDown(self): + api_settings.ERROR_BUILDER = self.DEFAULT_ERROR_BUILDER + + def test_class_based_view_exception_handler(self): + error = build_error("Too many characters", error_code="max_length") + assert error == ("Too many characters", "max_length", "customized")