exceptions.ValidationFailed, not Django's ValidationError

This commit is contained in:
Tom Christie 2014-10-10 14:16:09 +01:00
parent a0e852a4d5
commit d9a199ca0d
9 changed files with 107 additions and 91 deletions

View File

@ -191,7 +191,7 @@ Using the `depth` option on `ModelSerializer` will now create **read-only nested
def create(self, validated_data): def create(self, validated_data):
profile_data = validated_data.pop['profile'] profile_data = validated_data.pop['profile']
user = User.objects.create(**validated_data) user = User.objects.create(**validated_data)
profile = Profile.objects.create(user=user, **profile_data) Profile.objects.create(user=user, **profile_data)
return user return user
The single-step object creation makes this far simpler and more obvious than the previous `.restore_object()` behavior. The single-step object creation makes this far simpler and more obvious than the previous `.restore_object()` behavior.
@ -223,10 +223,6 @@ We can now inspect the serializer representation in the Django shell, using `pyt
rating = IntegerField() rating = IntegerField()
created_by = PrimaryKeyRelatedField(queryset=User.objects.all()) created_by = PrimaryKeyRelatedField(queryset=User.objects.all())
#### Always use `fields`, not `exclude`.
The `exclude` option on `ModelSerializer` is no longer available. You should use the more explicit `fields` option instead.
#### The `extra_kwargs` option. #### The `extra_kwargs` option.
The `write_only_fields` option on `ModelSerializer` has been moved to `PendingDeprecation` and replaced with a more generic `extra_kwargs`. The `write_only_fields` option on `ModelSerializer` has been moved to `PendingDeprecation` and replaced with a more generic `extra_kwargs`.

View File

@ -1,7 +1,7 @@
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import exceptions, serializers
class AuthTokenSerializer(serializers.Serializer): class AuthTokenSerializer(serializers.Serializer):
@ -18,13 +18,13 @@ class AuthTokenSerializer(serializers.Serializer):
if user: if user:
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise serializers.ValidationError(msg) raise exceptions.ValidationFailed(msg)
else: else:
msg = _('Unable to log in with provided credentials.') msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg) raise exceptions.ValidationFailed(msg)
else: else:
msg = _('Must include "username" and "password"') msg = _('Must include "username" and "password"')
raise serializers.ValidationError(msg) raise exceptions.ValidationFailed(msg)
attrs['user'] = user attrs['user'] = user
return attrs return attrs

View File

@ -24,6 +24,20 @@ class APIException(Exception):
return self.detail return self.detail
class ValidationFailed(APIException):
status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail):
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
self.detail = detail
def __str__(self):
return str(self.detail)
class ParseError(APIException): class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Malformed request.' default_detail = 'Malformed request.'

View File

@ -1,7 +1,8 @@
from django import forms
from django.conf import settings from django.conf import settings
from django.core import validators from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.exceptions import ObjectDoesNotExist
from django.core.exceptions import ValidationError as DjangoValidationError
from django.forms import ImageField as DjangoImageField
from django.utils import six, timezone from django.utils import six, timezone
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.dateparse import parse_date, parse_datetime, parse_time
@ -12,6 +13,7 @@ from rest_framework.compat import (
smart_text, EmailValidator, MinValueValidator, MaxValueValidator, smart_text, EmailValidator, MinValueValidator, MaxValueValidator,
MinLengthValidator, MaxLengthValidator, URLValidator MinLengthValidator, MaxLengthValidator, URLValidator
) )
from rest_framework.exceptions import ValidationFailed
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime from rest_framework.utils import html, representation, humanize_datetime
import copy import copy
@ -98,7 +100,7 @@ NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
MISSING_ERROR_MESSAGE = ( MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does ' 'ValidationFailed raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.' 'not exist in the `error_messages` dictionary.'
) )
@ -263,7 +265,7 @@ class Field(object):
def run_validators(self, value): def run_validators(self, value):
""" """
Test the given value against all the validators on the field, Test the given value against all the validators on the field,
and either raise a `ValidationError` or simply return. and either raise a `ValidationFailed` or simply return.
""" """
errors = [] errors = []
for validator in self.validators: for validator in self.validators:
@ -271,10 +273,12 @@ class Field(object):
validator.serializer_field = self validator.serializer_field = self
try: try:
validator(value) validator(value)
except ValidationError as exc: except ValidationFailed as exc:
errors.extend(exc.detail)
except DjangoValidationError as exc:
errors.extend(exc.messages) errors.extend(exc.messages)
if errors: if errors:
raise ValidationError(errors) raise ValidationFailed(errors)
def validate(self, value): def validate(self, value):
pass pass
@ -301,7 +305,8 @@ class Field(object):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg)
raise ValidationError(msg.format(**kwargs)) message_string = msg.format(**kwargs)
raise ValidationFailed(message_string)
@property @property
def root(self): def root(self):
@ -946,7 +951,7 @@ class ImageField(FileField):
} }
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._DjangoImageField = kwargs.pop('_DjangoImageField', forms.ImageField) self._DjangoImageField = kwargs.pop('_DjangoImageField', DjangoImageField)
super(ImageField, self).__init__(*args, **kwargs) super(ImageField, self).__init__(*args, **kwargs)
def to_internal_value(self, data): def to_internal_value(self, data):

View File

@ -10,10 +10,11 @@ python primitives.
2. The process of marshalling between python primitives and request and 2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers. response content is handled by parsers and renderers.
""" """
from django.core.exceptions import ImproperlyConfigured, ValidationError from django.core.exceptions import ImproperlyConfigured
from django.db import models from django.db import models
from django.utils import six from django.utils import six
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework.exceptions import ValidationFailed
from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation from rest_framework.utils import html, model_meta, representation
@ -100,14 +101,14 @@ class BaseSerializer(Field):
if not hasattr(self, '_validated_data'): if not hasattr(self, '_validated_data'):
try: try:
self._validated_data = self.run_validation(self._initial_data) self._validated_data = self.run_validation(self._initial_data)
except ValidationError as exc: except ValidationFailed as exc:
self._validated_data = {} self._validated_data = {}
self._errors = exc.message_dict self._errors = exc.detail
else: else:
self._errors = {} self._errors = {}
if self._errors and raise_exception: if self._errors and raise_exception:
raise ValidationError(self._errors) raise ValidationFailed(self._errors)
return not bool(self._errors) return not bool(self._errors)
@ -175,23 +176,33 @@ class BoundField(object):
def __getattr__(self, attr_name): def __getattr__(self, attr_name):
return getattr(self._field, attr_name) return getattr(self._field, attr_name)
def __iter__(self):
for field in self.fields.values():
yield self[field.field_name]
def __getitem__(self, key):
assert hasattr(self, 'fields'), '"%s" is not a nested field. Cannot perform indexing.' % self.name
field = self.fields[key]
value = self.value.get(key) if self.value else None
error = self.errors.get(key) if self.errors else None
return BoundField(field, value, error, prefix=self.name + '.')
@property @property
def _proxy_class(self): def _proxy_class(self):
return self._field.__class__ return self._field.__class__
def __repr__(self): def __repr__(self):
return '<%s value=%s errors=%s>' % (self.__class__.__name__, self.value, self.errors) return '<%s value=%s errors=%s>' % (
self.__class__.__name__, self.value, self.errors
)
class NestedBoundField(BoundField):
"""
This BoundField additionally implements __iter__ and __getitem__
in order to support nested bound fields. This class is the type of
BoundField that is used for serializer fields.
"""
def __iter__(self):
for field in self.fields.values():
yield self[field.field_name]
def __getitem__(self, key):
field = self.fields[key]
value = self.value.get(key) if self.value else None
error = self.errors.get(key) if self.errors else None
if isinstance(field, Serializer):
return NestedBoundField(field, value, error, prefix=self.name + '.')
return BoundField(field, value, error, prefix=self.name + '.')
class BindingDict(object): class BindingDict(object):
@ -308,7 +319,7 @@ class Serializer(BaseSerializer):
return None return None
if not isinstance(data, dict): if not isinstance(data, dict):
raise ValidationError({ raise ValidationFailed({
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data'] api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
}) })
@ -317,9 +328,9 @@ class Serializer(BaseSerializer):
self.run_validators(value) self.run_validators(value)
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the validated data'
except ValidationError as exc: except ValidationFailed as exc:
raise ValidationError({ raise ValidationFailed({
api_settings.NON_FIELD_ERRORS_KEY: exc.messages api_settings.NON_FIELD_ERRORS_KEY: exc.detail
}) })
return value return value
@ -338,15 +349,15 @@ class Serializer(BaseSerializer):
validated_value = field.run_validation(primitive_value) validated_value = field.run_validation(primitive_value)
if validate_method is not None: if validate_method is not None:
validated_value = validate_method(validated_value) validated_value = validate_method(validated_value)
except ValidationError as exc: except ValidationFailed as exc:
errors[field.field_name] = exc.messages errors[field.field_name] = exc.detail
except SkipField: except SkipField:
pass pass
else: else:
set_value(ret, field.source_attrs, validated_value) set_value(ret, field.source_attrs, validated_value)
if errors: if errors:
raise ValidationError(errors) raise ValidationFailed(errors)
return ret return ret
@ -385,6 +396,8 @@ class Serializer(BaseSerializer):
field = self.fields[key] field = self.fields[key]
value = self.data.get(key) value = self.data.get(key)
error = self.errors.get(key) if hasattr(self, '_errors') else None error = self.errors.get(key) if hasattr(self, '_errors') else None
if isinstance(field, Serializer):
return NestedBoundField(field, value, error)
return BoundField(field, value, error) return BoundField(field, value, error)
@ -538,9 +551,12 @@ class ModelSerializer(Serializer):
ret = SortedDict() ret = SortedDict()
model = getattr(self.Meta, 'model') model = getattr(self.Meta, 'model')
fields = getattr(self.Meta, 'fields', None) fields = getattr(self.Meta, 'fields', None)
exclude = getattr(self.Meta, 'exclude', None)
depth = getattr(self.Meta, 'depth', 0) depth = getattr(self.Meta, 'depth', 0)
extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})
assert not fields and exclude, "Cannot set both 'fields' and 'exclude'."
extra_kwargs = self._include_additional_options(extra_kwargs) extra_kwargs = self._include_additional_options(extra_kwargs)
# Retrieve metadata about fields & relationships on the model class. # Retrieve metadata about fields & relationships on the model class.
@ -551,12 +567,6 @@ class ModelSerializer(Serializer):
fields = self._get_default_field_names(declared_fields, info) fields = self._get_default_field_names(declared_fields, info)
exclude = getattr(self.Meta, 'exclude', None) exclude = getattr(self.Meta, 'exclude', None)
if exclude is not None: if exclude is not None:
warnings.warn(
"The `Meta.exclude` option is pending deprecation. "
"Use the explicit `Meta.fields` instead.",
PendingDeprecationWarning,
stacklevel=3
)
for field_name in exclude: for field_name in exclude:
fields.remove(field_name) fields.remove(field_name)

View File

@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions from rest_framework import status, exceptions
@ -63,27 +63,20 @@ def exception_handler(exc):
if getattr(exc, 'wait', None): if getattr(exc, 'wait', None):
headers['Retry-After'] = '%d' % exc.wait headers['Retry-After'] = '%d' % exc.wait
return Response({'detail': exc.detail}, if isinstance(exc.detail, (list, dict)):
status=exc.status_code, data = exc.detail
headers=headers) else:
data = {'detail': exc.detail}
elif isinstance(exc, ValidationError): return Response(data, status=exc.status_code, headers=headers)
# ValidationErrors may include the non-field key named '__all__'.
# When returning a response we map this to a key name that can be
# modified in settings.
if NON_FIELD_ERRORS in exc.message_dict:
errors = exc.message_dict.pop(NON_FIELD_ERRORS)
exc.message_dict[api_settings.NON_FIELD_ERRORS_KEY] = errors
return Response(exc.message_dict,
status=status.HTTP_400_BAD_REQUEST)
elif isinstance(exc, Http404): elif isinstance(exc, Http404):
return Response({'detail': 'Not found'}, data = {'detail': 'Not found'}
status=status.HTTP_404_NOT_FOUND) return Response(data, status=status.HTTP_404_NOT_FOUND)
elif isinstance(exc, PermissionDenied): elif isinstance(exc, PermissionDenied):
return Response({'detail': 'Permission denied'}, data = {'detail': 'Permission denied'}
status=status.HTTP_403_FORBIDDEN) return Response(data, status=status.HTTP_403_FORBIDDEN)
# Note: Unhandled exceptions will raise a 500 error. # Note: Unhandled exceptions will raise a 500 error.
return None return None

View File

@ -1,7 +1,6 @@
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import ValidationError
from django.utils import timezone from django.utils import timezone
from rest_framework import fields, serializers from rest_framework import exceptions, fields, serializers
import datetime import datetime
import django import django
import pytest import pytest
@ -19,9 +18,9 @@ class TestEmpty:
By default a field must be included in the input. By default a field must be included in the input.
""" """
field = fields.IntegerField() field = fields.IntegerField()
with pytest.raises(fields.ValidationError) as exc_info: with pytest.raises(exceptions.ValidationFailed) as exc_info:
field.run_validation() field.run_validation()
assert exc_info.value.messages == ['This field is required.'] assert exc_info.value.detail == ['This field is required.']
def test_not_required(self): def test_not_required(self):
""" """
@ -36,9 +35,9 @@ class TestEmpty:
By default `None` is not a valid input. By default `None` is not a valid input.
""" """
field = fields.IntegerField() field = fields.IntegerField()
with pytest.raises(fields.ValidationError) as exc_info: with pytest.raises(exceptions.ValidationFailed) as exc_info:
field.run_validation(None) field.run_validation(None)
assert exc_info.value.messages == ['This field may not be null.'] assert exc_info.value.detail == ['This field may not be null.']
def test_allow_null(self): def test_allow_null(self):
""" """
@ -53,9 +52,9 @@ class TestEmpty:
By default '' is not a valid input. By default '' is not a valid input.
""" """
field = fields.CharField() field = fields.CharField()
with pytest.raises(fields.ValidationError) as exc_info: with pytest.raises(exceptions.ValidationFailed) as exc_info:
field.run_validation('') field.run_validation('')
assert exc_info.value.messages == ['This field may not be blank.'] assert exc_info.value.detail == ['This field may not be blank.']
def test_allow_blank(self): def test_allow_blank(self):
""" """
@ -190,7 +189,7 @@ class TestInvalidErrorKey:
with pytest.raises(AssertionError) as exc_info: with pytest.raises(AssertionError) as exc_info:
self.field.to_native(123) self.field.to_native(123)
expected = ( expected = (
'ValidationError raised by `ExampleField`, but error key ' 'ValidationFailed raised by `ExampleField`, but error key '
'`incorrect` does not exist in the `error_messages` dictionary.' '`incorrect` does not exist in the `error_messages` dictionary.'
) )
assert str(exc_info.value) == expected assert str(exc_info.value) == expected
@ -244,9 +243,9 @@ class FieldValues:
Ensure that invalid values raise the expected validation error. Ensure that invalid values raise the expected validation error.
""" """
for input_value, expected_failure in get_items(self.invalid_inputs): for input_value, expected_failure in get_items(self.invalid_inputs):
with pytest.raises(fields.ValidationError) as exc_info: with pytest.raises(exceptions.ValidationFailed) as exc_info:
self.field.run_validation(input_value) self.field.run_validation(input_value)
assert exc_info.value.messages == expected_failure assert exc_info.value.detail == expected_failure
def test_outputs(self): def test_outputs(self):
for output_value, expected_output in get_items(self.outputs): for output_value, expected_output in get_items(self.outputs):
@ -901,7 +900,7 @@ class TestFieldFieldWithName(FieldValues):
# call into it's regular validation, or require PIL for testing. # call into it's regular validation, or require PIL for testing.
class FailImageValidation(object): class FailImageValidation(object):
def to_python(self, value): def to_python(self, value):
raise ValidationError(self.error_messages['invalid_image']) raise exceptions.ValidationFailed(self.error_messages['invalid_image'])
class PassImageValidation(object): class PassImageValidation(object):

View File

@ -1,6 +1,6 @@
from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset
from django.core.exceptions import ImproperlyConfigured, ValidationError from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers from rest_framework import exceptions, serializers
from rest_framework.test import APISimpleTestCase from rest_framework.test import APISimpleTestCase
import pytest import pytest
@ -30,15 +30,15 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
assert instance is self.instance assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self): def test_pk_related_lookup_does_not_exist(self):
with pytest.raises(ValidationError) as excinfo: with pytest.raises(exceptions.ValidationFailed) as excinfo:
self.field.to_internal_value(4) self.field.to_internal_value(4)
msg = excinfo.value.messages[0] msg = excinfo.value.detail[0]
assert msg == "Invalid pk '4' - object does not exist." assert msg == "Invalid pk '4' - object does not exist."
def test_pk_related_lookup_invalid_type(self): def test_pk_related_lookup_invalid_type(self):
with pytest.raises(ValidationError) as excinfo: with pytest.raises(exceptions.ValidationFailed) as excinfo:
self.field.to_internal_value(BadType()) self.field.to_internal_value(BadType())
msg = excinfo.value.messages[0] msg = excinfo.value.detail[0]
assert msg == 'Incorrect type. Expected pk value, received BadType.' assert msg == 'Incorrect type. Expected pk value, received BadType.'
def test_pk_representation(self): def test_pk_representation(self):
@ -120,15 +120,15 @@ class TestSlugRelatedField(APISimpleTestCase):
assert instance is self.instance assert instance is self.instance
def test_slug_related_lookup_does_not_exist(self): def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(ValidationError) as excinfo: with pytest.raises(exceptions.ValidationFailed) as excinfo:
self.field.to_internal_value('doesnotexist') self.field.to_internal_value('doesnotexist')
msg = excinfo.value.messages[0] msg = excinfo.value.detail[0]
assert msg == 'Object with name=doesnotexist does not exist.' assert msg == 'Object with name=doesnotexist does not exist.'
def test_slug_related_lookup_invalid_type(self): def test_slug_related_lookup_invalid_type(self):
with pytest.raises(ValidationError) as excinfo: with pytest.raises(exceptions.ValidationFailed) as excinfo:
self.field.to_internal_value(BadType()) self.field.to_internal_value(BadType())
msg = excinfo.value.messages[0] msg = excinfo.value.detail[0]
assert msg == 'Invalid value.' assert msg == 'Invalid value.'
def test_representation(self): def test_representation(self):

View File

@ -1,9 +1,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.validators import MaxValueValidator from django.core.validators import MaxValueValidator
from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers, status from rest_framework import exceptions, generics, serializers, status
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
@ -38,7 +37,7 @@ class ShouldValidateModelSerializer(serializers.ModelSerializer):
def validate_renamed(self, value): def validate_renamed(self, value):
if len(value) < 3: if len(value) < 3:
raise serializers.ValidationError('Minimum 3 characters.') raise exceptions.ValidationFailed('Minimum 3 characters.')
return value return value
class Meta: class Meta:
@ -74,10 +73,10 @@ class ValidationSerializer(serializers.Serializer):
foo = serializers.CharField() foo = serializers.CharField()
def validate_foo(self, attrs, source): def validate_foo(self, attrs, source):
raise serializers.ValidationError("foo invalid") raise exceptions.ValidationFailed("foo invalid")
def validate(self, attrs): def validate(self, attrs):
raise serializers.ValidationError("serializer invalid") raise exceptions.ValidationFailed("serializer invalid")
class TestAvoidValidation(TestCase): class TestAvoidValidation(TestCase):
@ -159,7 +158,7 @@ class TestChoiceFieldChoicesValidate(TestCase):
value = self.CHOICES[0][0] value = self.CHOICES[0][0]
try: try:
f.to_internal_value(value) f.to_internal_value(value)
except ValidationError: except exceptions.ValidationFailed:
self.fail("Value %s does not validate" % str(value)) self.fail("Value %s does not validate" % str(value))
# def test_nested_choices(self): # def test_nested_choices(self):