This commit is contained in:
Tomasz Rydzyński 2016-02-17 06:56:40 +00:00
commit 262857f149
11 changed files with 161 additions and 22 deletions

View File

@ -18,13 +18,22 @@ class AuthTokenSerializer(serializers.Serializer):
if user: if user:
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(
msg,
code='authorization'
)
else: else:
msg = _('Unable to log in with provided credentials.') msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(
msg,
code='authorization'
)
else: else:
msg = _('Must include "username" and "password".') msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg) raise serializers.ValidationError(
msg,
code='authorization'
)
attrs['user'] = user attrs['user'] = user
return attrs return attrs

View File

@ -14,6 +14,7 @@ from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ungettext from django.utils.translation import ungettext
from rest_framework import status from rest_framework import status
from rest_framework.settings import api_settings
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
@ -58,6 +59,14 @@ class APIException(Exception):
return self.detail return self.detail
def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ValidationError.build_detail(msg, code=code)
for msg in exc_info.messages
]
# The recommended style for using `ValidationError` is to keep it namespaced # The recommended style for using `ValidationError` is to keep it namespaced
# under `serializers`, in order to minimize potential confusion with Django's # under `serializers`, in order to minimize potential confusion with Django's
# built in `ValidationError`. For example: # built in `ValidationError`. For example:
@ -68,12 +77,40 @@ class APIException(Exception):
class ValidationError(APIException): class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail): @staticmethod
def build_detail(detail, code=None):
"""
Create error's representation.
This method is a helper that should be used when building compound
ValidationErrors directly (i.e. a whole list at once). Thanks to that
extra call, users have a customization point where they can tune how
much information about an error they want to see in the final output.
"""
if api_settings.REQUIRE_ERROR_CODES:
assert code is not None, (
'The `code` argument is required for single errors. '
'Strict checking of `code` is enabled with '
'REQUIRE_ERROR_CODES settings key.'
)
return detail
def __init__(self, detail, code=None):
# For validation errors the 'detail' key is always required. # For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already. # The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list): if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail] detail = [self.build_detail(detail, code)]
else:
if api_settings.REQUIRE_ERROR_CODES:
assert code is None, (
'The `code` argument must not be set for compound '
'errors. Strict checking of `code` is enabled with '
'REQUIRE_ERROR_CODES settings key.'
)
self.detail = _force_text_recursive(detail) self.detail = _force_text_recursive(detail)
self.code = code
def __str__(self): def __str__(self):
return six.text_type(self.detail) return six.text_type(self.detail)

View File

@ -31,7 +31,9 @@ from rest_framework.compat import (
MinValueValidator, duration_string, parse_duration, unicode_repr, MinValueValidator, duration_string, parse_duration, unicode_repr,
unicode_to_repr unicode_to_repr
) )
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import (
ValidationError, build_error_from_django_validation_error
)
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation from rest_framework.utils import html, humanize_datetime, representation
@ -503,7 +505,9 @@ class Field(object):
raise raise
errors.extend(exc.detail) errors.extend(exc.detail)
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors.extend(exc.messages) errors.extend(
build_error_from_django_validation_error(exc)
)
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
@ -541,7 +545,7 @@ class Field(object):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg)
message_string = msg.format(**kwargs) message_string = msg.format(**kwargs)
raise ValidationError(message_string) raise ValidationError(message_string, code=key)
@cached_property @cached_property
def root(self): def root(self):

View File

@ -20,6 +20,7 @@ from django.db.models.fields import FieldDoesNotExist
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import exceptions
from rest_framework.compat import DurationField as ModelDurationField from rest_framework.compat import DurationField as ModelDurationField
from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, unicode_to_repr from rest_framework.compat import postgres_fields, unicode_to_repr
@ -300,7 +301,8 @@ def get_validation_error_detail(exc):
# exception class as well for simpler compat. # exception class as well for simpler compat.
# Eg. Calling Model.clean() explicitly inside Serializer.validate() # Eg. Calling Model.clean() explicitly inside Serializer.validate()
return { return {
api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) api_settings.NON_FIELD_ERRORS_KEY:
exceptions.build_error_from_django_validation_error(exc)
} }
elif isinstance(exc.detail, dict): elif isinstance(exc.detail, dict):
# If errors may be a dict we use the standard {key: list of values}. # If errors may be a dict we use the standard {key: list of values}.
@ -422,8 +424,9 @@ class Serializer(BaseSerializer):
message = self.error_messages['invalid'].format( message = self.error_messages['invalid'].format(
datatype=type(data).__name__ datatype=type(data).__name__
) )
error = ValidationError.build_detail(message, code='invalid')
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [error]
}) })
ret = OrderedDict() ret = OrderedDict()
@ -440,7 +443,9 @@ class Serializer(BaseSerializer):
except ValidationError as exc: except ValidationError as exc:
errors[field.field_name] = exc.detail errors[field.field_name] = exc.detail
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors[field.field_name] = list(exc.messages) errors[field.field_name] = (
exceptions.build_error_from_django_validation_error(exc)
)
except SkipField: except SkipField:
pass pass
else: else:
@ -560,7 +565,9 @@ class ListSerializer(BaseSerializer):
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the validated data'
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=get_validation_error_detail(exc)) raise ValidationError(
detail=get_validation_error_detail(exc)
)
return value return value
@ -575,8 +582,12 @@ class ListSerializer(BaseSerializer):
message = self.error_messages['not_a_list'].format( message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__ input_type=type(data).__name__
) )
error = ValidationError.build_detail(
message,
code='not_a_list'
)
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [error]
}) })
if not self.allow_empty and len(data) == 0: if not self.allow_empty and len(data) == 0:

View File

@ -85,6 +85,7 @@ DEFAULTS = {
# Exception handling # Exception handling
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
'NON_FIELD_ERRORS_KEY': 'non_field_errors', 'NON_FIELD_ERRORS_KEY': 'non_field_errors',
'REQUIRE_ERROR_CODES': False,
# Testing # Testing
'TEST_REQUEST_RENDERER_CLASSES': ( 'TEST_REQUEST_RENDERER_CLASSES': (

View File

@ -60,7 +60,7 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset) queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset) queryset = self.exclude_current_instance(queryset)
if queryset.exists(): if queryset.exists():
raise ValidationError(self.message) raise ValidationError(self.message, code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % ( return unicode_to_repr('<%s(queryset=%s)>' % (
@ -101,7 +101,9 @@ class UniqueTogetherValidator(object):
return return
missing = { missing = {
field_name: self.missing_message field_name: ValidationError.build_detail(
self.missing_message,
code='required')
for field_name in self.fields for field_name in self.fields
if field_name not in attrs if field_name not in attrs
} }
@ -147,7 +149,8 @@ class UniqueTogetherValidator(object):
] ]
if None not in checked_values and queryset.exists(): if None not in checked_values and queryset.exists():
field_names = ', '.join(self.fields) field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names)) raise ValidationError(self.message.format(field_names=field_names),
code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -185,7 +188,9 @@ class BaseUniqueForValidator(object):
'required' state on the fields they are applied to. 'required' state on the fields they are applied to.
""" """
missing = { missing = {
field_name: self.missing_message field_name: ValidationError.build_detail(
self.missing_message,
code='required')
for field_name in [self.field, self.date_field] for field_name in [self.field, self.date_field]
if field_name not in attrs if field_name not in attrs
} }
@ -211,7 +216,8 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset) queryset = self.exclude_current_instance(attrs, queryset)
if queryset.exists(): if queryset.exists():
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
raise ValidationError({self.field: message}) error = ValidationError.build_detail(message, code='unique')
raise ValidationError({self.field: error})
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -1399,7 +1399,10 @@ class TestFieldFieldWithName(FieldValues):
# call into it's regular validation, or require PIL for testing. # call into it's regular validation, or require PIL for testing.
class FailImageValidation(object): class FailImageValidation(object):
def to_python(self, value): def to_python(self, value):
raise serializers.ValidationError(self.error_messages['invalid_image']) raise serializers.ValidationError(
self.error_messages['invalid_image'],
code='invalid_image'
)
class PassImageValidation(object): class PassImageValidation(object):

View File

@ -7,6 +7,7 @@ import pytest
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import unicode_repr from rest_framework.compat import unicode_repr
from rest_framework.fields import DjangoValidationError
from .utils import MockObject from .utils import MockObject
@ -69,7 +70,10 @@ class TestValidateMethod:
integer = serializers.IntegerField() integer = serializers.IntegerField()
def validate(self, attrs): def validate(self, attrs):
raise serializers.ValidationError('Non field error') raise serializers.ValidationError(
'Non field error',
code='test'
)
serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
assert not serializer.is_valid() assert not serializer.is_valid()
@ -309,3 +313,25 @@ class TestCacheSerializerData:
pickled = pickle.dumps(serializer.data) pickled = pickle.dumps(serializer.data)
data = pickle.loads(pickled) data = pickle.loads(pickled)
assert data == {'field1': 'a', 'field2': 'b'} assert data == {'field1': 'a', 'field2': 'b'}
class TestGetValidationErrorDetail:
def test_get_validation_error_detail_converts_django_errors(self):
exc = DjangoValidationError("Missing field.", code='required')
detail = serializers.get_validation_error_detail(exc)
assert detail == {'non_field_errors': ['Missing field.']}
class TestCapturingDjangoValidationError:
def test_django_validation_error_on_a_field_is_converted(self):
class ExampleSerializer(serializers.Serializer):
field = serializers.CharField()
def validate_field(self, value):
raise DjangoValidationError(
'validation failed'
)
serializer = ExampleSerializer(data={'field': 'a'})
assert not serializer.is_valid()
assert serializer.errors == {'field': ['validation failed']}

View File

@ -280,7 +280,10 @@ class TestListSerializerClass:
def test_list_serializer_class_validate(self): def test_list_serializer_class_validate(self):
class CustomListSerializer(serializers.ListSerializer): class CustomListSerializer(serializers.ListSerializer):
def validate(self, attrs): def validate(self, attrs):
raise serializers.ValidationError('Non field error') raise serializers.ValidationError(
'Non field error',
code='test'
)
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
class Meta: class Meta:

View File

@ -41,7 +41,8 @@ class ShouldValidateModelSerializer(serializers.ModelSerializer):
def validate_renamed(self, value): def validate_renamed(self, value):
if len(value) < 3: if len(value) < 3:
raise serializers.ValidationError('Minimum 3 characters.') raise serializers.ValidationError('Minimum 3 characters.',
code='min_length')
return value return value
class Meta: class Meta:

View File

@ -0,0 +1,38 @@
import pytest
from django.test import TestCase
from rest_framework import serializers
from rest_framework.settings import api_settings
class TestMandatoryErrorCodeArgument(TestCase):
"""
If mandatory error code is enabled in settings, it should prevent throwing
ValidationError without the code set.
"""
def setUp(self):
self.DEFAULT_REQUIRE_ERROR_CODES = api_settings.REQUIRE_ERROR_CODES
api_settings.REQUIRE_ERROR_CODES = True
def tearDown(self):
api_settings.REQUIRE_ERROR_CODES = self.DEFAULT_REQUIRE_ERROR_CODES
def test_validation_error_requires_code_for_simple_messages(self):
with pytest.raises(AssertionError):
serializers.ValidationError("")
def test_validation_error_requires_no_code_for_structured_errors(self):
"""
ValidationError can hold a list or dictionary of simple errors, in
which case the code is no longer meaningful and should not be
specified.
"""
with pytest.raises(AssertionError):
serializers.ValidationError([], code='min_value')
with pytest.raises(AssertionError):
serializers.ValidationError({}, code='min_value')
def test_validation_error_stores_error_code(self):
exception = serializers.ValidationError("", code='min_value')
assert exception.code == 'min_value'