diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 8a295c03e..caf5b64b3 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -18,13 +18,22 @@ class AuthTokenSerializer(serializers.Serializer): if user: if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 8447a9ded..f8330bf17 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -58,6 +58,14 @@ class APIException(Exception): return self.detail +def build_error_from_django_validation_error(exc_info): + code = getattr(exc_info, 'code', None) or 'invalid' + return [ + ValidationError(msg, code=code) + for msg in exc_info.messages + ] + + # The recommended style for using `ValidationError` is to keep it namespaced # under `serializers`, in order to minimize potential confusion with Django's # built in `ValidationError`. For example: @@ -68,12 +76,17 @@ class APIException(Exception): class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST - def __init__(self, detail): + def __init__(self, detail, code=None): # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] - self.detail = _force_text_recursive(detail) + elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)): + assert code is None, ( + 'The `code` argument must not be set for compound errors.') + + self.detail = detail + self.code = code def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8541bc43a..e9044513b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -31,7 +31,9 @@ from rest_framework.compat import ( MinValueValidator, duration_string, parse_duration, unicode_repr, unicode_to_repr ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ( + ValidationError, build_error_from_django_validation_error +) from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -501,9 +503,11 @@ class Field(object): # attempting to accumulate a list of errors. if isinstance(exc.detail, dict): raise - errors.extend(exc.detail) + errors.append(ValidationError(exc.detail, code=exc.code)) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend( + build_error_from_django_validation_error(exc) + ) if errors: raise ValidationError(errors) @@ -541,7 +545,7 @@ class Field(object): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, code=key) @cached_property def root(self): diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index dafaf7de6..d52ebaea9 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -493,6 +493,21 @@ class BrowsableAPIRenderer(BaseRenderer): if hasattr(serializer, 'initial_data'): serializer.is_valid() + # Convert ValidationError to unicode string + # This is mainly a hack to monkey patch the errors and make the form renderer happy... + errors = OrderedDict() + for field_name, values in serializer.errors.items(): + if isinstance(values, list): + errors[field_name] = values + continue + + errors[field_name] = [] + for value in values.detail: + for message in value.detail: + errors[field_name].append(message) + + serializer._errors = errors + form_renderer = self.form_renderer_class() return form_renderer.render( serializer.data, diff --git a/rest_framework/response.py b/rest_framework/response.py index 0e97668eb..9942dd576 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -38,7 +38,6 @@ class Response(SimpleTemplateResponse): '`.error`. representation.' ) raise AssertionError(msg) - self.data = data self.template_name = template_name self.exception = exception diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 99d36a8a5..e3206793e 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,7 @@ from django.db.models.fields import FieldDoesNotExist from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions from rest_framework.compat import DurationField as ModelDurationField from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr @@ -300,7 +301,8 @@ def get_validation_error_detail(exc): # exception class as well for simpler compat. # Eg. Calling Model.clean() explicitly inside Serializer.validate() return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + api_settings.NON_FIELD_ERRORS_KEY: + exceptions.build_error_from_django_validation_error(exc) } elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. @@ -422,8 +424,9 @@ class Serializer(BaseSerializer): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) + error = ValidationError(message, code='invalid') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) ret = OrderedDict() @@ -438,9 +441,11 @@ class Serializer(BaseSerializer): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - errors[field.field_name] = exc.detail + errors[field.field_name] = exc except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + errors[field.field_name] = ( + exceptions.build_error_from_django_validation_error(exc) + ) except SkipField: pass else: @@ -575,14 +580,18 @@ class ListSerializer(BaseSerializer): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) + error = ValidationError( + message, + code='not_a_list' + ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')] }) ret = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index a21f67e60..27148cedc 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -60,7 +60,7 @@ class UniqueValidator(object): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if queryset.exists(): - raise ValidationError(self.message) + raise ValidationError(self.message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -101,7 +101,9 @@ class UniqueTogetherValidator(object): return missing = { - field_name: self.missing_message + field_name: ValidationError( + self.missing_message, + code='required') for field_name in self.fields if field_name not in attrs } @@ -147,7 +149,8 @@ class UniqueTogetherValidator(object): ] if None not in checked_values and queryset.exists(): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + raise ValidationError(self.message.format(field_names=field_names), + code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -185,7 +188,9 @@ class BaseUniqueForValidator(object): 'required' state on the fields they are applied to. """ missing = { - field_name: self.missing_message + field_name: ValidationError( + self.missing_message, + code='required') for field_name in [self.field, self.date_field] if field_name not in attrs } @@ -211,7 +216,8 @@ class BaseUniqueForValidator(object): queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + error = ValidationError(message, code='unique') + raise ValidationError({self.field: error}) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/rest_framework/views.py b/rest_framework/views.py index 56e3c4e49..926a3d00e 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -15,6 +15,7 @@ from django.views.generic import View from rest_framework import exceptions, status from rest_framework.compat import set_rollback +from rest_framework.exceptions import ValidationError, _force_text_recursive from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -69,7 +70,17 @@ def exception_handler(exc, context): if getattr(exc, 'wait', None): headers['Retry-After'] = '%d' % exc.wait - if isinstance(exc.detail, (list, dict)): + if isinstance(exc.detail, list): + data = _force_text_recursive([item.detail if isinstance(item, ValidationError) else item + for item in exc.detai]) + elif isinstance(exc.detail, dict): + for field_name, e in exc.detail.items(): + if hasattr(e, 'detail') and isinstance(e.detail[0], ValidationError): + exc.detail[field_name] = e.detail[0].detail + elif isinstance(e, ValidationError): + exc.detail[field_name] = e.detail + else: + exc.detail[field_name] = e data = exc.detail else: data = {'detail': exc.detail} diff --git a/tests/test_bound_fields.py b/tests/test_bound_fields.py index f2fac8f0d..0fd8e6f5d 100644 --- a/tests/test_bound_fields.py +++ b/tests/test_bound_fields.py @@ -39,7 +39,8 @@ class TestSimpleBoundField: serializer.is_valid() assert serializer['text'].value == 'x' * 1000 - assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.'] + assert serializer['text'].errors.detail[0].detail == ['Ensure this field has no more than 100 characters.'] + assert serializer['text'].errors.detail[0].code == 'max_length' assert serializer['text'].name == 'text' assert serializer['amount'].value is 123 assert serializer['amount'].errors is None diff --git a/tests/test_fields.py b/tests/test_fields.py index 9cb59f7da..0c1bbfc85 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -10,6 +10,7 @@ from django.utils import timezone import rest_framework from rest_framework import serializers +from rest_framework.exceptions import ValidationError # Tests for field keyword arguments and core functionality. @@ -426,7 +427,13 @@ class FieldValues: for input_value, expected_failure in get_items(self.invalid_inputs): with pytest.raises(serializers.ValidationError) as exc_info: self.field.run_validation(input_value) - assert exc_info.value.detail == expected_failure + + if isinstance(exc_info.value.detail[0], ValidationError): + failure = exc_info.value.detail[0].detail + else: + failure = exc_info.value.detail + + assert failure == expected_failure def test_outputs(self): for output_value, expected_output in get_items(self.outputs): @@ -1393,7 +1400,10 @@ class TestFieldFieldWithName(FieldValues): # call into it's regular validation, or require PIL for testing. class FailImageValidation(object): def to_python(self, value): - raise serializers.ValidationError(self.error_messages['invalid_image']) + raise serializers.ValidationError( + self.error_messages['invalid_image'], + code='invalid_image' + ) class PassImageValidation(object): diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 57e540e7a..fb786b01f 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -374,7 +374,7 @@ class TestGenericIPAddressFieldValidation(TestCase): s = TestSerializer(data={'address': 'not an ip address'}) self.assertFalse(s.is_valid()) - self.assertEquals(1, len(s.errors['address']), + self.assertEquals(1, len(s.errors['address'].detail), 'Unexpected number of validation errors: ' '{0}'.format(s.errors)) diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index c0642eda2..c2fc7904f 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -244,7 +244,8 @@ class HyperlinkedForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected URL string, received int.']}) + self.assertEqual(serializer.errors['target'].detail, ['Incorrect type. Expected URL string, received int.']) + self.assertEqual(serializer.errors['target'].code, 'incorrect_type') def test_reverse_foreign_key_update(self): data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} @@ -315,7 +316,8 @@ class HyperlinkedForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + self.assertEqual(serializer.errors['target'].detail, ['This field may not be null.']) + self.assertEqual(serializer.errors['target'].code, 'null') class HyperlinkedNullableForeignKeyTests(TestCase): diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index 169f7d9c5..0ad31109d 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -235,7 +235,9 @@ class PKForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) + self.assertEqual(serializer.errors['target'].detail, + ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]) + self.assertEqual(serializer.errors['target'].code, 'incorrect_type') def test_reverse_foreign_key_update(self): data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} @@ -306,7 +308,8 @@ class PKForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + self.assertEqual(serializer.errors['target'].detail, ['This field may not be null.']) + self.assertEqual(serializer.errors['target'].code, 'null') def test_foreign_key_with_unsaved(self): source = ForeignKeySource(name='source-unsaved') diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py index 680aee417..fb7910b69 100644 --- a/tests/test_relations_slug.py +++ b/tests/test_relations_slug.py @@ -104,7 +104,8 @@ class SlugForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) + self.assertEqual(serializer.errors['target'].detail, ['Object with name=123 does not exist.']) + self.assertEqual(serializer.errors['target'].code, 'does_not_exist') def test_reverse_foreign_key_update(self): data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} @@ -176,7 +177,8 @@ class SlugForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + self.assertEqual(serializer.errors['target'].detail, ['This field may not be null.']) + self.assertEqual(serializer.errors['target'].code, 'null') class SlugNullableForeignKeyTests(TestCase): diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 741c6ab17..b2436ce4a 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -7,6 +7,7 @@ import pytest from rest_framework import serializers from rest_framework.compat import unicode_repr +from rest_framework.fields import DjangoValidationError from .utils import MockObject @@ -31,7 +32,8 @@ class TestSerializer: serializer = self.Serializer(data={'char': 'abc'}) assert not serializer.is_valid() assert serializer.validated_data == {} - assert serializer.errors == {'integer': ['This field is required.']} + assert serializer.errors['integer'].detail == ['This field is required.'] + assert serializer.errors['integer'].code == 'required' def test_partial_validation(self): serializer = self.Serializer(data={'char': 'abc'}, partial=True) @@ -69,7 +71,10 @@ class TestValidateMethod: integer = serializers.IntegerField() def validate(self, attrs): - raise serializers.ValidationError('Non field error') + raise serializers.ValidationError( + 'Non field error', + code='test' + ) serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) assert not serializer.is_valid() @@ -309,3 +314,27 @@ class TestCacheSerializerData: pickled = pickle.dumps(serializer.data) data = pickle.loads(pickled) assert data == {'field1': 'a', 'field2': 'b'} + + +class TestGetValidationErrorDetail: + def test_get_validation_error_detail_converts_django_errors(self): + exc = DjangoValidationError("Missing field.", code='required') + detail = serializers.get_validation_error_detail(exc) + assert detail['non_field_errors'][0].detail == ['Missing field.'] + assert detail['non_field_errors'][0].code == 'required' + + +class TestCapturingDjangoValidationError: + def test_django_validation_error_on_a_field_is_converted(self): + class ExampleSerializer(serializers.Serializer): + field = serializers.CharField() + + def validate_field(self, value): + raise DjangoValidationError( + 'validation failed' + ) + + serializer = ExampleSerializer(data={'field': 'a'}) + assert not serializer.is_valid() + assert serializer.errors['field'][0].detail == ['validation failed'] + assert serializer.errors['field'][0].code == 'invalid' diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index 8d7240a7b..b18d6200d 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -67,15 +67,16 @@ class BulkCreateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - expected_errors = [ - {}, - {}, - {'id': ['A valid integer is required.']} - ] serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, expected_errors) + + for idx, error in enumerate(serializer.errors): + if idx < 2: + self.assertEqual(error, {}) + else: + self.assertEqual(error['id'].detail, ['A valid integer is required.']) + self.assertEqual(error['id'].code, 'invalid') def test_invalid_list_datatype(self): """ @@ -87,13 +88,10 @@ class BulkCreateSerializerTests(TestCase): text_type_string = six.text_type.__name__ message = 'Invalid data. Expected a dictionary, but got %s.' % text_type_string - expected_errors = [ - {'non_field_errors': [message]}, - {'non_field_errors': [message]}, - {'non_field_errors': [message]} - ] - self.assertEqual(serializer.errors, expected_errors) + for error in serializer.errors: + self.assertEqual(error['non_field_errors'][0].detail, [message]) + self.assertEqual(error['non_field_errors'][0].code, 'invalid') def test_invalid_single_datatype(self): """ @@ -103,9 +101,9 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} - - self.assertEqual(serializer.errors, expected_errors) + self.assertEqual(serializer.errors['non_field_errors'][0].detail, + ['Expected a list of items but got type "int".']) + self.assertEqual(serializer.errors['non_field_errors'][0].code, 'not_a_list') def test_invalid_single_object(self): """ @@ -120,6 +118,7 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} + self.assertEqual(serializer.errors['non_field_errors'][0].detail, + ['Expected a list of items but got type "dict".']) - self.assertEqual(serializer.errors, expected_errors) + self.assertEqual(serializer.errors['non_field_errors'][0].code, 'not_a_list') diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index 607ddba04..81aed3574 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -280,7 +280,10 @@ class TestListSerializerClass: def test_list_serializer_class_validate(self): class CustomListSerializer(serializers.ListSerializer): def validate(self, attrs): - raise serializers.ValidationError('Non field error') + raise serializers.ValidationError( + 'Non field error', + code='test' + ) class TestSerializer(serializers.Serializer): class Meta: diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py index aeb092ee0..41d6724ac 100644 --- a/tests/test_serializer_nested.py +++ b/tests/test_serializer_nested.py @@ -113,8 +113,8 @@ class TestNestedSerializerWithMany: assert not serializer.is_valid() - expected_errors = {'not_allow_null': [serializer.error_messages['null']]} - assert serializer.errors == expected_errors + assert serializer.errors['not_allow_null'].detail == [serializer.error_messages['null']] + assert serializer.errors['not_allow_null'].code == 'null' def test_run_the_field_validation_even_if_the_field_is_null(self): class TestSerializer(self.Serializer): @@ -165,5 +165,7 @@ class TestNestedSerializerWithMany: assert not serializer.is_valid() - expected_errors = {'not_allow_empty': {'non_field_errors': [serializers.ListSerializer.default_error_messages['empty']]}} - assert serializer.errors == expected_errors + assert serializer.errors['not_allow_empty'].detail['non_field_errors'][0].detail == \ + [serializers.ListSerializer.default_error_messages['empty']] + + assert serializer.errors['not_allow_empty'].detail['non_field_errors'][0].code == 'empty_not_allowed' diff --git a/tests/test_validation.py b/tests/test_validation.py index 855ff20e0..a644a3d5f 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -41,7 +41,8 @@ class ShouldValidateModelSerializer(serializers.ModelSerializer): def validate_renamed(self, value): if len(value) < 3: - raise serializers.ValidationError('Minimum 3 characters.') + raise serializers.ValidationError('Minimum 3 characters.', + code='min_length') return value class Meta: @@ -91,11 +92,9 @@ class TestAvoidValidation(TestCase): def test_serializer_errors_has_only_invalid_data_error(self): serializer = ValidationSerializer(data='invalid data') self.assertFalse(serializer.is_valid()) - self.assertDictEqual(serializer.errors, { - 'non_field_errors': [ - 'Invalid data. Expected a dictionary, but got %s.' % type('').__name__ - ] - }) + self.assertEqual(serializer.errors['non_field_errors'][0].detail, + ['Invalid data. Expected a dictionary, but got %s.' % type('').__name__]) + self.assertEqual(serializer.errors['non_field_errors'][0].code, 'invalid') # regression tests for issue: 1493 @@ -123,7 +122,9 @@ class TestMaxValueValidatorValidation(TestCase): def test_max_value_validation_serializer_fails(self): serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) self.assertFalse(serializer.is_valid()) - self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) + self.assertEqual(['Ensure this value is less than or equal to 100.'], serializer.errors['number_value'].detail[0].detail) + self.assertEqual(None, serializer.errors['number_value'].code) + self.assertEqual('max_value', serializer.errors['number_value'].detail[0].code) def test_max_value_validation_success(self): obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 000000000..7e6513127 --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,95 @@ +import pytest +from django.test import TestCase + +from rest_framework import serializers, status +from rest_framework.decorators import api_view +from rest_framework.exceptions import ValidationError +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +@api_view(['GET']) +def error_view(request): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +class TestValidationErrorWithCode(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc, request): + if not exc.code: + errors = { + field_name: { + 'code': e.code, + 'message': e.detail + } for field_name, e in exc.detail.items() + } + else: + errors = { + 'code': exc.code, + 'detail': exc.detail + } + return Response(errors, status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + self.expected_response_data = { + 'char': { + 'message': ['This field is required.'], + 'code': 'required', + }, + 'integer': { + 'message': ['This field is required.'], + 'code': 'required' + }, + } + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_validation_error_requires_no_code_for_structured_errors(self): + """ + ValidationError can hold a list or dictionary of simple errors, in + which case the code is no longer meaningful and should not be + specified. + """ + with pytest.raises(AssertionError): + serializers.ValidationError([ValidationError("test-detail", "test-code")], code='min_value') + + with pytest.raises(AssertionError): + serializers.ValidationError({}, code='min_value') + + def test_validation_error_stores_error_code(self): + exception = serializers.ValidationError("", code='min_value') + assert exception.code == 'min_value' + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) diff --git a/tests/test_validators.py b/tests/test_validators.py index acaaf5743..b86f1ed16 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -47,8 +47,11 @@ class TestUniquenessValidation(TestCase): def test_is_not_unique(self): data = {'username': 'existing'} serializer = UniquenessSerializer(data=data) + assert not serializer.is_valid() - assert serializer.errors == {'username': ['UniquenessModel with this username already exists.']} + assert serializer.errors['username'].code is None + assert serializer.errors['username'].detail[0].code == 'unique' + assert serializer.errors['username'].detail[0].detail == ['UniquenessModel with this username already exists.'] def test_is_unique(self): data = {'username': 'other'} @@ -150,11 +153,9 @@ class TestUniquenessTogetherValidation(TestCase): data = {'race_name': 'example', 'position': 2} serializer = UniquenessTogetherSerializer(data=data) assert not serializer.is_valid() - assert serializer.errors == { - 'non_field_errors': [ - 'The fields race_name, position must make a unique set.' - ] - } + assert serializer.errors['non_field_errors'][0].code == 'unique' + assert serializer.errors['non_field_errors'][0].detail == [ + 'The fields race_name, position must make a unique set.'] def test_is_unique_together(self): """ @@ -189,9 +190,8 @@ class TestUniquenessTogetherValidation(TestCase): data = {'position': 2} serializer = UniquenessTogetherSerializer(data=data, partial=True) assert not serializer.is_valid() - assert serializer.errors == { - 'race_name': ['This field is required.'] - } + assert serializer.errors['race_name'][0].code == 'required' + assert serializer.errors['race_name'][0].detail == ['This field is required.'] def test_ignore_excluded_fields(self): """ @@ -278,9 +278,8 @@ class TestUniquenessForDateValidation(TestCase): data = {'slug': 'existing', 'published': '2000-01-01'} serializer = UniqueForDateSerializer(data=data) assert not serializer.is_valid() - assert serializer.errors == { - 'slug': ['This field must be unique for the "published" date.'] - } + assert serializer.errors['slug'][0].code == 'unique' + assert serializer.errors['slug'][0].detail == ['This field must be unique for the "published" date.'] def test_is_unique_for_date(self): """