diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index ce6bb9c79..abaac0c22 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -20,20 +20,18 @@ class AuthTokenSerializer(serializers.Serializer): msg = _('User account is disabled.') raise serializers.ValidationError( msg, - code='authorization' - ) + code='authorization') else: msg = _('Unable to log in with provided credentials.') raise serializers.ValidationError( msg, - code='authorization' - ) + code='authorization') + else: msg = _('Must include "username" and "password".') raise serializers.ValidationError( msg, - code='authorization' - ) + code='authorization') attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index e23b7cd31..6e30834e6 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -61,7 +61,7 @@ class APIException(Exception): def build_error_from_django_validation_error(exc_info): code = getattr(exc_info, 'code', None) or 'invalid' return [ - ValidationError(msg, code=code) + ValidationErrorMessage(msg, code=code) for msg in exc_info.messages ] @@ -73,20 +73,30 @@ def build_error_from_django_validation_error(exc_info): # from rest_framework import serializers # raise serializers.ValidationError('Value was invalid') +class ValidationErrorMessage(six.text_type): + code = None + + def __new__(cls, string, code=None, *args, **kwargs): + self = super(ValidationErrorMessage, cls).__new__( + cls, string, *args, **kwargs) + + self.code = code + return self + + class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST def __init__(self, detail, code=None): + # If code is there, this means we are dealing with a message. + if code and not isinstance(detail, ValidationErrorMessage): + detail = ValidationErrorMessage(detail, code=code) + # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] - elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)): - assert code is None, ( - 'The `code` argument must not be set for compound errors.') - - self.detail = detail - self.code = code + self.detail = _force_text_recursive(detail) def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 39a5e3395..396259584 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -509,7 +509,7 @@ class Field(object): # attempting to accumulate a list of errors. if isinstance(exc.detail, dict): raise - errors.append(ValidationError(exc.detail, code=exc.code)) + errors.extend(exc.detail) except DjangoValidationError as exc: errors.extend( build_error_from_django_validation_error(exc) diff --git a/rest_framework/response.py b/rest_framework/response.py index e9ceb2741..4b863cb99 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -38,6 +38,7 @@ class Response(SimpleTemplateResponse): '`.error`. representation.' ) raise AssertionError(msg) + self.data = data self.template_name = template_name self.exception = exception diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index a4dbc6449..5b3ef3770 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -25,6 +25,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import exceptions from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.exceptions import ValidationErrorMessage from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -220,14 +221,7 @@ class BaseSerializer(Field): self._errors = {} if self._errors and raise_exception: - return_errors = None - if isinstance(self._errors, list): - return_errors = ReturnList(self._errors, serializer=self) - elif isinstance(self._errors, dict): - return_errors = ReturnDict(self._errors, serializer=self) - - raise ValidationError(return_errors) - + raise ValidationError(self.errors) return not bool(self._errors) @property @@ -251,42 +245,12 @@ class BaseSerializer(Field): self._data = self.get_initial() return self._data - def _transform_to_legacy_errors(self, errors_to_transform): - # Do not mutate `errors_to_transform` here. - errors = ReturnDict(serializer=self) - for field_name, values in errors_to_transform.items(): - if isinstance(values, list): - errors[field_name] = values - continue - - if isinstance(values.detail, list): - errors[field_name] = [] - for value in values.detail: - if isinstance(value, ValidationError): - errors[field_name].extend(value.detail) - elif isinstance(value, list): - errors[field_name].extend(value) - else: - errors[field_name].append(value) - - elif isinstance(values.detail, dict): - errors[field_name] = {} - for sub_field_name, value in values.detail.items(): - errors[field_name][sub_field_name] = [] - for validation_error in value: - errors[field_name][sub_field_name].extend(validation_error.detail) - return errors - @property def errors(self): if not hasattr(self, '_errors'): msg = 'You must call `.is_valid()` before accessing `.errors`.' raise AssertionError(msg) - - if isinstance(self._errors, list): - return map(self._transform_to_legacy_errors, self._errors) - else: - return self._transform_to_legacy_errors(self._errors) + return self._errors @property def validated_data(self): @@ -461,7 +425,7 @@ class Serializer(BaseSerializer): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) - error = ValidationError(message, code='invalid') + error = ValidationErrorMessage(message, code='invalid') raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [error] }) @@ -478,7 +442,7 @@ class Serializer(BaseSerializer): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - errors[field.field_name] = exc + errors[field.field_name] = exc.detail except DjangoValidationError as exc: errors[field.field_name] = ( exceptions.build_error_from_django_validation_error(exc) @@ -621,7 +585,7 @@ class ListSerializer(BaseSerializer): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) - error = ValidationError( + error = ValidationErrorMessage( message, code='not_a_list' ) @@ -630,9 +594,11 @@ class ListSerializer(BaseSerializer): }) if not self.allow_empty and len(data) == 0: - message = self.error_messages['empty'] + message = ValidationErrorMessage( + self.error_messages['empty'], + code='empty_not_allowed') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')] + api_settings.NON_FIELD_ERRORS_KEY: [message] }) ret = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 90483eeeb..3b8678a70 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -12,7 +12,7 @@ from django.db import DataError from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ValidationError, ValidationErrorMessage from rest_framework.utils.representation import smart_repr @@ -120,9 +120,10 @@ class UniqueTogetherValidator(object): return missing = { - field_name: ValidationError( + field_name: ValidationErrorMessage( self.missing_message, code='required') + for field_name in self.fields if field_name not in attrs } @@ -168,8 +169,9 @@ class UniqueTogetherValidator(object): ] if None not in checked_values and qs_exists(queryset): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names), - code='unique') + raise ValidationError( + self.message.format(field_names=field_names), + code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -207,7 +209,7 @@ class BaseUniqueForValidator(object): 'required' state on the fields they are applied to. """ missing = { - field_name: ValidationError( + field_name: ValidationErrorMessage( self.missing_message, code='required') for field_name in [self.field, self.date_field] @@ -235,8 +237,9 @@ class BaseUniqueForValidator(object): queryset = self.exclude_current_instance(attrs, queryset) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) - error = ValidationError(message, code='unique') - raise ValidationError({self.field: error}) + raise ValidationError({ + self.field: ValidationErrorMessage(message, code='unique'), + }) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/rest_framework/views.py b/rest_framework/views.py index 8b6f060d4..15d8c6cde 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -71,7 +71,7 @@ def exception_handler(exc, context): headers['Retry-After'] = '%d' % exc.wait if isinstance(exc.detail, (list, dict)): - data = exc.detail.serializer.errors + data = exc.detail else: data = {'detail': exc.detail} diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 000000000..a9d244176 --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,74 @@ +from django.test import TestCase + +from rest_framework import serializers, status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +@api_view(['GET']) +def error_view(request): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +class TestValidationErrorWithCode(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc, request): + return_errors = {} + for field_name, errors in exc.detail.items(): + return_errors[field_name] = [] + for error in errors: + return_errors[field_name].append({ + 'code': error.code, + 'message': error + }) + + return Response(return_errors, status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + self.expected_response_data = { + 'char': [{ + 'message': 'This field is required.', + 'code': 'required', + }], + 'integer': [{ + 'message': 'This field is required.', + 'code': 'required' + }], + } + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data)