Deal with ValidationError instantiation

This commit is contained in:
Jonathan Liuti 2015-12-22 16:02:24 +01:00
parent 12e4b8f3b1
commit 2344d2270a
5 changed files with 63 additions and 88 deletions

View File

@ -7,6 +7,7 @@ In addition Django's built in 403 and 404 exceptions are handled.
from __future__ import unicode_literals from __future__ import unicode_literals
import math import math
from collections import namedtuple
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
@ -61,7 +62,7 @@ class APIException(Exception):
def build_error_from_django_validation_error(exc_info): def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid' code = getattr(exc_info, 'code', None) or 'invalid'
return [ return [
(msg, code) ErrorDetails(msg, code)
for msg in exc_info.messages for msg in exc_info.messages
] ]
@ -72,60 +73,61 @@ def build_error_from_django_validation_error(exc_info):
# from rest_framework import serializers # from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid') # raise serializers.ValidationError('Value was invalid')
ErrorDetails = namedtuple('ErrorDetails', ['message', 'code'])
class ValidationError(APIException): class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
code = None code = None
def __init__(self, detail, code=None): def __init__(self, detail, code=None):
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if code: if code:
self.full_details = [(detail, code)] self.full_details = ErrorDetails(detail, code)
else: else:
self.full_details = detail self.full_details = detail
if isinstance(self.full_details, tuple): if not isinstance(self.full_details, dict) \
self.detail, self.code = self.full_details and not isinstance(self.full_details, list):
self.detail = [self.detail] self.full_details = [self.full_details]
self.full_details = _force_text_recursive(self.full_details)
elif isinstance(self.full_details, list): self.detail = detail
if isinstance(self.full_details, list):
if isinstance(self.full_details, ReturnList): if isinstance(self.full_details, ReturnList):
self.detail = ReturnList(serializer=self.full_details.serializer) self.detail = ReturnList(
serializer=self.full_details.serializer)
else: else:
self.detail = [] self.detail = []
for error in self.full_details: for full_detail in self.full_details:
if isinstance(error, tuple): if isinstance(full_detail, ErrorDetails):
message, code = error self.detail.append(full_detail.message)
self.detail.append(message) elif isinstance(full_detail, dict):
elif isinstance(error, dict): if not full_detail:
self.detail = self.full_details self.detail.append(full_detail)
break for key, value in full_detail.items():
if isinstance(value, list):
self.detail.append(
{key: [item.message]
if isinstance(item, ErrorDetails)
else [item] for item in value})
elif isinstance(full_detail, list):
self.detail.extend(full_detail)
else:
self.detail.append(full_detail)
elif isinstance(self.full_details, dict): elif isinstance(self.full_details, dict):
if isinstance(self.full_details, ReturnDict): if isinstance(self.full_details, ReturnDict):
self.detail = ReturnDict(serializer=self.full_details.serializer) self.detail = ReturnDict(
serializer=self.full_details.serializer)
else: else:
self.detail = {} self.detail = {}
for field_name, full_detail in self.full_details.items():
for field_name, errors in self.full_details.items(): if isinstance(full_detail, list):
self.detail[field_name] = [] self.detail[field_name] = [
if isinstance(errors, tuple): item.message if isinstance(item, ErrorDetails) else item
message, code = errors for item in full_detail
self.detail[field_name].append(message) ]
elif isinstance(errors, list): else:
for error in errors: self.detail[field_name] = full_detail
if isinstance(error, tuple):
message, code = error
else:
message = error
if message:
self.detail[field_name].append(message)
else:
self.detail = [self.full_details]
self.detail = _force_text_recursive(self.detail)
def __str__(self): def __str__(self):
return six.text_type(self.detail) return six.text_type(self.detail)

View File

@ -503,7 +503,7 @@ class Field(object):
# attempting to accumulate a list of errors. # attempting to accumulate a list of errors.
if isinstance(exc.detail, dict): if isinstance(exc.detail, dict):
raise raise
errors.append((exc.detail, exc.code)) errors.append(exc.full_details)
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors.extend(build_error_from_django_validation_error(exc)) errors.extend(build_error_from_django_validation_error(exc))
if errors: if errors:

View File

@ -23,7 +23,9 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import DurationField as ModelDurationField from rest_framework.compat import DurationField as ModelDurationField
from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, unicode_to_repr from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.exceptions import build_error_from_django_validation_error from rest_framework.exceptions import (
ErrorDetails, build_error_from_django_validation_error
)
from rest_framework.utils import model_meta from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import ( from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
@ -214,12 +216,12 @@ class BaseSerializer(Field):
self._validated_data = self.run_validation(self.initial_data) self._validated_data = self.run_validation(self.initial_data)
except ValidationError as exc: except ValidationError as exc:
self._validated_data = {} self._validated_data = {}
self._errors = exc.full_details self._errors = exc.detail
else: else:
self._errors = {} self._errors = {}
if self._errors and raise_exception: if self._errors and raise_exception:
raise ValidationError(self._errors) raise ValidationError(self.errors)
return not bool(self._errors) return not bool(self._errors)
@ -249,36 +251,7 @@ class BaseSerializer(Field):
if not hasattr(self, '_errors'): if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.' msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg) raise AssertionError(msg)
return self._errors
if isinstance(self._errors, dict):
errors = ReturnDict(serializer=self)
for key, value in self._errors.items():
if isinstance(value, dict):
errors[key] = {}
for key_, value_ in value.items():
message, code = value_[0]
errors[key][key_] = [message]
elif isinstance(value, list):
if isinstance(value[0], tuple):
message, code = value[0]
else:
message = value[0]
if isinstance(message, list):
errors[key] = message
else:
errors[key] = [message]
elif isinstance(value, tuple):
message, code = value
errors[key] = [message]
else:
errors[key] = [value]
elif isinstance(self._errors, list):
errors = ReturnList(self._errors, serializer=self)
else:
# This shouldn't ever happen.
errors = self._errors
return errors
@property @property
def validated_data(self): def validated_data(self):
@ -333,21 +306,21 @@ def get_validation_error_detail(exc):
return { return {
api_settings.NON_FIELD_ERRORS_KEY: error api_settings.NON_FIELD_ERRORS_KEY: error
} }
elif isinstance(exc.full_details, dict): elif isinstance(exc.detail, dict):
# If errors may be a dict we use the standard {key: list of values}. # If errors may be a dict we use the standard {key: list of values}.
# Here we ensure that all the values are *lists* of errors. # Here we ensure that all the values are *lists* of errors.
return { return {
key: value if isinstance(value, list) else [value] key: value if isinstance(value, list) else [value]
for key, value in exc.full_details.items() for key, value in exc.detail.items()
} }
elif isinstance(exc.full_details, list): elif isinstance(exc.detail, list):
# Errors raised as a list are non-field errors. # Errors raised as a list are non-field errors.
return { return {
api_settings.NON_FIELD_ERRORS_KEY: exc.full_details api_settings.NON_FIELD_ERRORS_KEY: exc.detail
} }
# Errors raised as a string are non-field errors. # Errors raised as a string are non-field errors.
return { return {
api_settings.NON_FIELD_ERRORS_KEY: [exc.full_details] api_settings.NON_FIELD_ERRORS_KEY: [exc.detail]
} }
@ -455,11 +428,11 @@ class Serializer(BaseSerializer):
) )
code = 'invalid' code = 'invalid'
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [(message, code)] api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)]
}) })
ret = ReturnDict(serializer=self) ret = OrderedDict()
errors = ReturnDict(serializer=self) errors = OrderedDict()
fields = self._writable_fields fields = self._writable_fields
for field in fields: for field in fields:
@ -470,7 +443,7 @@ class Serializer(BaseSerializer):
if validate_method is not None: if validate_method is not None:
validated_value = validate_method(validated_value) validated_value = validate_method(validated_value)
except ValidationError as exc: except ValidationError as exc:
errors[field.field_name] = exc.full_details errors[field.field_name] = exc.detail
except DjangoValidationError as exc: except DjangoValidationError as exc:
error = build_error_from_django_validation_error(exc) error = build_error_from_django_validation_error(exc)
errors[field.field_name] = error errors[field.field_name] = error
@ -610,14 +583,14 @@ class ListSerializer(BaseSerializer):
) )
code = 'not_a_list' code = 'not_a_list'
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [(message, code)] api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)]
}) })
if not self.allow_empty and len(data) == 0: if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty'] message = self.error_messages['empty']
code = 'empty_not_allowed' code = 'empty_not_allowed'
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [(message, code)] api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)]
}) })
ret = [] ret = []

View File

@ -11,7 +11,7 @@ from __future__ import unicode_literals
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import unicode_to_repr from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ErrorDetails, ValidationError
from rest_framework.utils.representation import smart_repr from rest_framework.utils.representation import smart_repr
@ -102,7 +102,7 @@ class UniqueTogetherValidator(object):
code = 'required' code = 'required'
missing = { missing = {
field_name: [(self.missing_message, code)] field_name: ErrorDetails(self.missing_message, code)
for field_name in self.fields for field_name in self.fields
if field_name not in attrs if field_name not in attrs
} }
@ -150,7 +150,7 @@ class UniqueTogetherValidator(object):
field_names = ', '.join(self.fields) field_names = ', '.join(self.fields)
message = self.message.format(field_names=field_names) message = self.message.format(field_names=field_names)
code = 'unique' code = 'unique'
raise ValidationError(message, code=code) raise ValidationError(ErrorDetails(message, code=code))
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
@ -189,7 +189,7 @@ class BaseUniqueForValidator(object):
""" """
code = 'required' code = 'required'
missing = { missing = {
field_name: [(self.missing_message, code)] field_name: ErrorDetails(self.missing_message, code)
for field_name in [self.field, self.date_field] for field_name in [self.field, self.date_field]
if field_name not in attrs if field_name not in attrs
} }
@ -216,7 +216,7 @@ class BaseUniqueForValidator(object):
if queryset.exists(): if queryset.exists():
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
code = 'unique' code = 'unique'
raise ValidationError({self.field: [(message, code)]}) raise ValidationError({self.field: ErrorDetails(message, code)})
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (

View File

@ -31,7 +31,7 @@ class TestValidationErrorWithCode(TestCase):
def exception_handler(exc, request): def exception_handler(exc, request):
return_errors = {} return_errors = {}
for field_name, errors in exc.full_details.items(): for field_name, errors in exc.detail.items():
return_errors[field_name] = [] return_errors[field_name] = []
for message, code in errors: for message, code in errors:
return_errors[field_name].append({ return_errors[field_name].append({