diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index df0c48b86..ce6bb9c79 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -18,13 +18,22 @@ class AuthTokenSerializer(serializers.Serializer): if user: if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 29afaffe0..975a40790 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -14,6 +14,7 @@ from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ungettext from rest_framework import status +from rest_framework.settings import api_settings from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList @@ -58,6 +59,14 @@ class APIException(Exception): return self.detail +def build_error_from_django_validation_error(exc_info): + code = getattr(exc_info, 'code', None) or 'invalid' + return [ + ValidationError.build_detail(msg, code=code) + for msg in exc_info.messages + ] + + # The recommended style for using `ValidationError` is to keep it namespaced # under `serializers`, in order to minimize potential confusion with Django's # built in `ValidationError`. For example: @@ -68,12 +77,40 @@ class APIException(Exception): class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST - def __init__(self, detail): + @staticmethod + def build_detail(detail, code=None): + """ + Create error's representation. + + This method is a helper that should be used when building compound + ValidationErrors directly (i.e. a whole list at once). Thanks to that + extra call, users have a customization point where they can tune how + much information about an error they want to see in the final output. + """ + if api_settings.REQUIRE_ERROR_CODES: + assert code is not None, ( + 'The `code` argument is required for single errors. ' + 'Strict checking of `code` is enabled with ' + 'REQUIRE_ERROR_CODES settings key.' + ) + + return detail + + def __init__(self, detail, code=None): # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): - detail = [detail] + detail = [self.build_detail(detail, code)] + else: + if api_settings.REQUIRE_ERROR_CODES: + assert code is None, ( + 'The `code` argument must not be set for compound ' + 'errors. Strict checking of `code` is enabled with ' + 'REQUIRE_ERROR_CODES settings key.' + ) + self.detail = _force_text_recursive(detail) + self.code = code def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6d5962c8e..a1981a9b4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -31,7 +31,9 @@ from rest_framework.compat import ( MinValueValidator, duration_string, parse_duration, unicode_repr, unicode_to_repr ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ( + ValidationError, build_error_from_django_validation_error +) from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -503,7 +505,9 @@ class Field(object): raise errors.extend(exc.detail) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend( + build_error_from_django_validation_error(exc) + ) if errors: raise ValidationError(errors) @@ -541,7 +545,7 @@ class Field(object): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, code=key) @cached_property def root(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b95bb7fa6..8b53b6037 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,7 @@ from django.db.models.fields import FieldDoesNotExist from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions from rest_framework.compat import DurationField as ModelDurationField from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr @@ -300,7 +301,8 @@ def get_validation_error_detail(exc): # exception class as well for simpler compat. # Eg. Calling Model.clean() explicitly inside Serializer.validate() return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + api_settings.NON_FIELD_ERRORS_KEY: + exceptions.build_error_from_django_validation_error(exc) } elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. @@ -422,8 +424,9 @@ class Serializer(BaseSerializer): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) + error = ValidationError.build_detail(message, code='invalid') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) ret = OrderedDict() @@ -440,7 +443,9 @@ class Serializer(BaseSerializer): except ValidationError as exc: errors[field.field_name] = exc.detail except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + errors[field.field_name] = ( + exceptions.build_error_from_django_validation_error(exc) + ) except SkipField: pass else: @@ -560,7 +565,9 @@ class ListSerializer(BaseSerializer): value = self.validate(value) assert value is not None, '.validate() should return the validated data' except (ValidationError, DjangoValidationError) as exc: - raise ValidationError(detail=get_validation_error_detail(exc)) + raise ValidationError( + detail=get_validation_error_detail(exc) + ) return value @@ -575,8 +582,12 @@ class ListSerializer(BaseSerializer): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) + error = ValidationError.build_detail( + message, + code='not_a_list' + ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) if not self.allow_empty and len(data) == 0: diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 946b905c6..feedf5d4b 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -85,6 +85,7 @@ DEFAULTS = { # Exception handling 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', 'NON_FIELD_ERRORS_KEY': 'non_field_errors', + 'REQUIRE_ERROR_CODES': False, # Testing 'TEST_REQUEST_RENDERER_CLASSES': ( diff --git a/rest_framework/validators.py b/rest_framework/validators.py index a21f67e60..66e3d4bc1 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -60,7 +60,7 @@ class UniqueValidator(object): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if queryset.exists(): - raise ValidationError(self.message) + raise ValidationError(self.message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -101,7 +101,9 @@ class UniqueTogetherValidator(object): return missing = { - field_name: self.missing_message + field_name: ValidationError.build_detail( + self.missing_message, + code='required') for field_name in self.fields if field_name not in attrs } @@ -147,7 +149,8 @@ class UniqueTogetherValidator(object): ] if None not in checked_values and queryset.exists(): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + raise ValidationError(self.message.format(field_names=field_names), + code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -185,7 +188,9 @@ class BaseUniqueForValidator(object): 'required' state on the fields they are applied to. """ missing = { - field_name: self.missing_message + field_name: ValidationError.build_detail( + self.missing_message, + code='required') for field_name in [self.field, self.date_field] if field_name not in attrs } @@ -211,7 +216,8 @@ class BaseUniqueForValidator(object): queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + error = ValidationError.build_detail(message, code='unique') + raise ValidationError({self.field: error}) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/tests/test_fields.py b/tests/test_fields.py index 43441c2e7..cbf3d846f 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1399,7 +1399,10 @@ class TestFieldFieldWithName(FieldValues): # call into it's regular validation, or require PIL for testing. class FailImageValidation(object): def to_python(self, value): - raise serializers.ValidationError(self.error_messages['invalid_image']) + raise serializers.ValidationError( + self.error_messages['invalid_image'], + code='invalid_image' + ) class PassImageValidation(object): diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 741c6ab17..a0219527a 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -7,6 +7,7 @@ import pytest from rest_framework import serializers from rest_framework.compat import unicode_repr +from rest_framework.fields import DjangoValidationError from .utils import MockObject @@ -69,7 +70,10 @@ class TestValidateMethod: integer = serializers.IntegerField() def validate(self, attrs): - raise serializers.ValidationError('Non field error') + raise serializers.ValidationError( + 'Non field error', + code='test' + ) serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) assert not serializer.is_valid() @@ -309,3 +313,25 @@ class TestCacheSerializerData: pickled = pickle.dumps(serializer.data) data = pickle.loads(pickled) assert data == {'field1': 'a', 'field2': 'b'} + + +class TestGetValidationErrorDetail: + def test_get_validation_error_detail_converts_django_errors(self): + exc = DjangoValidationError("Missing field.", code='required') + detail = serializers.get_validation_error_detail(exc) + assert detail == {'non_field_errors': ['Missing field.']} + + +class TestCapturingDjangoValidationError: + def test_django_validation_error_on_a_field_is_converted(self): + class ExampleSerializer(serializers.Serializer): + field = serializers.CharField() + + def validate_field(self, value): + raise DjangoValidationError( + 'validation failed' + ) + + serializer = ExampleSerializer(data={'field': 'a'}) + assert not serializer.is_valid() + assert serializer.errors == {'field': ['validation failed']} diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index 607ddba04..81aed3574 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -280,7 +280,10 @@ class TestListSerializerClass: def test_list_serializer_class_validate(self): class CustomListSerializer(serializers.ListSerializer): def validate(self, attrs): - raise serializers.ValidationError('Non field error') + raise serializers.ValidationError( + 'Non field error', + code='test' + ) class TestSerializer(serializers.Serializer): class Meta: diff --git a/tests/test_validation.py b/tests/test_validation.py index b6f274219..398ca118d 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -41,7 +41,8 @@ class ShouldValidateModelSerializer(serializers.ModelSerializer): def validate_renamed(self, value): if len(value) < 3: - raise serializers.ValidationError('Minimum 3 characters.') + raise serializers.ValidationError('Minimum 3 characters.', + code='min_length') return value class Meta: diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 000000000..5ceb60859 --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,38 @@ +import pytest +from django.test import TestCase + +from rest_framework import serializers +from rest_framework.settings import api_settings + + +class TestMandatoryErrorCodeArgument(TestCase): + """ + If mandatory error code is enabled in settings, it should prevent throwing + ValidationError without the code set. + """ + def setUp(self): + self.DEFAULT_REQUIRE_ERROR_CODES = api_settings.REQUIRE_ERROR_CODES + api_settings.REQUIRE_ERROR_CODES = True + + def tearDown(self): + api_settings.REQUIRE_ERROR_CODES = self.DEFAULT_REQUIRE_ERROR_CODES + + def test_validation_error_requires_code_for_simple_messages(self): + with pytest.raises(AssertionError): + serializers.ValidationError("") + + def test_validation_error_requires_no_code_for_structured_errors(self): + """ + ValidationError can hold a list or dictionary of simple errors, in + which case the code is no longer meaningful and should not be + specified. + """ + with pytest.raises(AssertionError): + serializers.ValidationError([], code='min_value') + + with pytest.raises(AssertionError): + serializers.ValidationError({}, code='min_value') + + def test_validation_error_stores_error_code(self): + exception = serializers.ValidationError("", code='min_value') + assert exception.code == 'min_value'