This commit is contained in:
Tom Christie 2016-10-10 19:49:59 +01:00
parent 7943429dab
commit 2eb62d4b23
5 changed files with 65 additions and 39 deletions

View File

@ -17,21 +17,21 @@ from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
def _force_text_recursive(data, code=None): def _get_error_details(data, default_code=None):
""" """
Descend into a nested data structure, forcing any Descend into a nested data structure, forcing any
lazy translation strings or strings into `ErrorMessage`. lazy translation strings or strings into `ErrorMessage`.
""" """
if isinstance(data, list): if isinstance(data, list):
ret = [ ret = [
_force_text_recursive(item, code) for item in data _get_error_details(item, default_code) for item in data
] ]
if isinstance(data, ReturnList): if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer) return ReturnList(ret, serializer=data.serializer)
return ret return ret
elif isinstance(data, dict): elif isinstance(data, dict):
ret = { ret = {
key: _force_text_recursive(value, code) key: _get_error_details(value, default_code)
for key, value in data.items() for key, value in data.items()
} }
if isinstance(data, ReturnDict): if isinstance(data, ReturnDict):
@ -39,15 +39,18 @@ def _force_text_recursive(data, code=None):
return ret return ret
text = force_text(data) text = force_text(data)
code = getattr(data, 'code', code or 'invalid') code = getattr(data, 'code', default_code)
return ErrorMessage(text, code) return ErrorDetail(text, code)
class ErrorMessage(six.text_type): class ErrorDetail(six.text_type):
"""
A string-like object that can additionally
"""
code = None code = None
def __new__(cls, string, code=None): def __new__(cls, string, code=None):
self = super(ErrorMessage, cls).__new__(cls, string) self = super(ErrorDetail, cls).__new__(cls, string)
self.code = code self.code = code
return self return self
@ -85,7 +88,13 @@ class ValidationError(APIException):
# The details should always be coerced to a list if not already. # The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list): if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail] detail = [detail]
self.detail = _force_text_recursive(detail, code=code)
if code is None:
default_code = 'invalid'
else:
default_code = code
self.detail = _get_error_details(detail, default_code)
def __str__(self): def __str__(self):
return six.text_type(self.detail) return six.text_type(self.detail)

View File

@ -34,7 +34,7 @@ from rest_framework import ISO_8601
from rest_framework.compat import ( from rest_framework.compat import (
get_remote_field, unicode_repr, unicode_to_repr, value_from_object get_remote_field, unicode_repr, unicode_to_repr, value_from_object
) )
from rest_framework.exceptions import ErrorMessage, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation from rest_framework.utils import html, humanize_datetime, representation
@ -224,14 +224,14 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None):
yield Option(value='n/a', display_text=cutoff_text, disabled=True) yield Option(value='n/a', display_text=cutoff_text, disabled=True)
def get_error_messages(exc_info): def get_error_detail(exc_info):
""" """
Given a Django ValidationError, return a list of ErrorMessage, Given a Django ValidationError, return a list of ErrorDetail,
with the `code` populated. with the `code` populated.
""" """
code = getattr(exc_info, 'code', None) or 'invalid' code = getattr(exc_info, 'code', None) or 'invalid'
return [ return [
ErrorMessage(msg, code=code) ErrorDetail(msg, code=code)
for msg in exc_info.messages for msg in exc_info.messages
] ]
@ -537,7 +537,7 @@ class Field(object):
raise raise
errors.extend(exc.detail) errors.extend(exc.detail)
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors.extend(get_error_messages(exc)) errors.extend(get_error_detail(exc))
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)

View File

@ -291,32 +291,29 @@ class SerializerMetaclass(type):
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
def get_validation_error_detail(exc): def as_serializer_error(exc):
assert isinstance(exc, (ValidationError, DjangoValidationError)) assert isinstance(exc, (ValidationError, DjangoValidationError))
if isinstance(exc, DjangoValidationError): if isinstance(exc, DjangoValidationError):
# Normally you should raise `serializers.ValidationError` detail = get_error_detail(exc)
# inside your codebase, but we handle Django's validation else:
# exception class as well for simpler compat. detail = exc.detail
# Eg. Calling Model.clean() explicitly inside Serializer.validate()
return { if isinstance(detail, dict):
api_settings.NON_FIELD_ERRORS_KEY: get_error_messages(exc)
}
elif isinstance(exc.detail, dict):
# If errors may be a dict we use the standard {key: list of values}. # If errors may be a dict we use the standard {key: list of values}.
# Here we ensure that all the values are *lists* of errors. # Here we ensure that all the values are *lists* of errors.
return { return {
key: value if isinstance(value, (list, dict)) else [value] key: value if isinstance(value, (list, dict)) else [value]
for key, value in exc.detail.items() for key, value in detail.items()
} }
elif isinstance(exc.detail, list): elif isinstance(detail, list):
# Errors raised as a list are non-field errors. # Errors raised as a list are non-field errors.
return { return {
api_settings.NON_FIELD_ERRORS_KEY: exc.detail api_settings.NON_FIELD_ERRORS_KEY: detail
} }
# Errors raised as a string are non-field errors. # Errors raised as a string are non-field errors.
return { return {
api_settings.NON_FIELD_ERRORS_KEY: [exc.detail] api_settings.NON_FIELD_ERRORS_KEY: [detail]
} }
@ -410,7 +407,7 @@ class Serializer(BaseSerializer):
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the validated data'
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=get_validation_error_detail(exc)) raise ValidationError(detail=as_serializer_error(exc))
return value return value
@ -440,7 +437,7 @@ class Serializer(BaseSerializer):
except ValidationError as exc: except ValidationError as exc:
errors[field.field_name] = exc.detail errors[field.field_name] = exc.detail
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors[field.field_name] = get_validation_error_detail(exc) errors[field.field_name] = get_error_detail(exc)
except SkipField: except SkipField:
pass pass
else: else:
@ -564,7 +561,7 @@ class ListSerializer(BaseSerializer):
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the validated data'
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=get_validation_error_detail(exc)) raise ValidationError(detail=as_serializer_error(exc))
return value return value

View File

@ -3,19 +3,39 @@ from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import ErrorMessage, _force_text_recursive from rest_framework.exceptions import ErrorDetail, _get_error_details
class ExceptionTestCase(TestCase): class ExceptionTestCase(TestCase):
def test_force_text_recursive(self): def test_get_error_details(self):
s = "sfdsfggiuytraetfdlklj" example = "string"
self.assertEqual(_force_text_recursive(_(s)), s) lazy_example = _(example)
assert isinstance(_force_text_recursive(_(s)), ErrorMessage)
self.assertEqual(_force_text_recursive({'a': _(s)})['a'], s) self.assertEqual(
assert isinstance(_force_text_recursive({'a': _(s)})['a'], ErrorMessage) _get_error_details(lazy_example),
example
)
assert isinstance(
_get_error_details(lazy_example),
ErrorDetail
)
self.assertEqual(_force_text_recursive([[_(s)]])[0][0], s) self.assertEqual(
assert isinstance(_force_text_recursive([[_(s)]])[0][0], ErrorMessage) _get_error_details({'nested': lazy_example})['nested'],
example
)
assert isinstance(
_get_error_details({'nested': lazy_example})['nested'],
ErrorDetail
)
self.assertEqual(
_get_error_details([[lazy_example]])[0][0],
example
)
assert isinstance(
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)

View File

@ -60,7 +60,7 @@ class TestNestedValidationError(TestCase):
} }
}) })
self.assertEqual(serializers.get_validation_error_detail(e), { self.assertEqual(serializers.as_serializer_error(e), {
'nested': { 'nested': {
'field': ['error'], 'field': ['error'],
} }