From 2344d2270a851acea6a3089b9d70b92df3fd7ee2 Mon Sep 17 00:00:00 2001 From: Jonathan Liuti Date: Tue, 22 Dec 2015 16:02:24 +0100 Subject: [PATCH] Deal with ValidationError instantiation --- rest_framework/exceptions.py | 76 +++++++++++++++++----------------- rest_framework/fields.py | 2 +- rest_framework/serializers.py | 61 ++++++++------------------- rest_framework/validators.py | 10 ++--- tests/test_validation_error.py | 2 +- 5 files changed, 63 insertions(+), 88 deletions(-) diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 2d1474050..5f1b27c76 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -7,6 +7,7 @@ In addition Django's built in 403 and 404 exceptions are handled. from __future__ import unicode_literals import math +from collections import namedtuple from django.utils import six from django.utils.encoding import force_text @@ -61,7 +62,7 @@ class APIException(Exception): def build_error_from_django_validation_error(exc_info): code = getattr(exc_info, 'code', None) or 'invalid' return [ - (msg, code) + ErrorDetails(msg, code) for msg in exc_info.messages ] @@ -72,60 +73,61 @@ def build_error_from_django_validation_error(exc_info): # from rest_framework import serializers # raise serializers.ValidationError('Value was invalid') +ErrorDetails = namedtuple('ErrorDetails', ['message', 'code']) + class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST code = None def __init__(self, detail, code=None): - # For validation errors the 'detail' key is always required. - # The details should always be coerced to a list if not already. - if code: - self.full_details = [(detail, code)] + self.full_details = ErrorDetails(detail, code) else: self.full_details = detail - if isinstance(self.full_details, tuple): - self.detail, self.code = self.full_details - self.detail = [self.detail] + if not isinstance(self.full_details, dict) \ + and not isinstance(self.full_details, list): + self.full_details = [self.full_details] + self.full_details = _force_text_recursive(self.full_details) - elif isinstance(self.full_details, list): + self.detail = detail + if isinstance(self.full_details, list): if isinstance(self.full_details, ReturnList): - self.detail = ReturnList(serializer=self.full_details.serializer) + self.detail = ReturnList( + serializer=self.full_details.serializer) else: self.detail = [] - for error in self.full_details: - if isinstance(error, tuple): - message, code = error - self.detail.append(message) - elif isinstance(error, dict): - self.detail = self.full_details - break - + for full_detail in self.full_details: + if isinstance(full_detail, ErrorDetails): + self.detail.append(full_detail.message) + elif isinstance(full_detail, dict): + if not full_detail: + self.detail.append(full_detail) + for key, value in full_detail.items(): + if isinstance(value, list): + self.detail.append( + {key: [item.message] + if isinstance(item, ErrorDetails) + else [item] for item in value}) + elif isinstance(full_detail, list): + self.detail.extend(full_detail) + else: + self.detail.append(full_detail) elif isinstance(self.full_details, dict): if isinstance(self.full_details, ReturnDict): - self.detail = ReturnDict(serializer=self.full_details.serializer) + self.detail = ReturnDict( + serializer=self.full_details.serializer) else: self.detail = {} - - for field_name, errors in self.full_details.items(): - self.detail[field_name] = [] - if isinstance(errors, tuple): - message, code = errors - self.detail[field_name].append(message) - elif isinstance(errors, list): - for error in errors: - if isinstance(error, tuple): - message, code = error - else: - message = error - if message: - self.detail[field_name].append(message) - else: - self.detail = [self.full_details] - - self.detail = _force_text_recursive(self.detail) + for field_name, full_detail in self.full_details.items(): + if isinstance(full_detail, list): + self.detail[field_name] = [ + item.message if isinstance(item, ErrorDetails) else item + for item in full_detail + ] + else: + self.detail[field_name] = full_detail def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 35df9225c..1f588e990 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -503,7 +503,7 @@ class Field(object): # attempting to accumulate a list of errors. if isinstance(exc.detail, dict): raise - errors.append((exc.detail, exc.code)) + errors.append(exc.full_details) except DjangoValidationError as exc: errors.extend(build_error_from_django_validation_error(exc)) if errors: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 27e764da3..77c5cd182 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -23,7 +23,9 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import DurationField as ModelDurationField from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr -from rest_framework.exceptions import build_error_from_django_validation_error +from rest_framework.exceptions import ( + ErrorDetails, build_error_from_django_validation_error +) from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -214,12 +216,12 @@ class BaseSerializer(Field): self._validated_data = self.run_validation(self.initial_data) except ValidationError as exc: self._validated_data = {} - self._errors = exc.full_details + self._errors = exc.detail else: self._errors = {} if self._errors and raise_exception: - raise ValidationError(self._errors) + raise ValidationError(self.errors) return not bool(self._errors) @@ -249,36 +251,7 @@ class BaseSerializer(Field): if not hasattr(self, '_errors'): msg = 'You must call `.is_valid()` before accessing `.errors`.' raise AssertionError(msg) - - if isinstance(self._errors, dict): - errors = ReturnDict(serializer=self) - for key, value in self._errors.items(): - if isinstance(value, dict): - errors[key] = {} - for key_, value_ in value.items(): - message, code = value_[0] - errors[key][key_] = [message] - elif isinstance(value, list): - if isinstance(value[0], tuple): - message, code = value[0] - else: - message = value[0] - if isinstance(message, list): - errors[key] = message - else: - errors[key] = [message] - elif isinstance(value, tuple): - message, code = value - errors[key] = [message] - else: - errors[key] = [value] - elif isinstance(self._errors, list): - errors = ReturnList(self._errors, serializer=self) - else: - # This shouldn't ever happen. - errors = self._errors - - return errors + return self._errors @property def validated_data(self): @@ -333,21 +306,21 @@ def get_validation_error_detail(exc): return { api_settings.NON_FIELD_ERRORS_KEY: error } - elif isinstance(exc.full_details, dict): + elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. # Here we ensure that all the values are *lists* of errors. return { key: value if isinstance(value, list) else [value] - for key, value in exc.full_details.items() + for key, value in exc.detail.items() } - elif isinstance(exc.full_details, list): + elif isinstance(exc.detail, list): # Errors raised as a list are non-field errors. return { - api_settings.NON_FIELD_ERRORS_KEY: exc.full_details + api_settings.NON_FIELD_ERRORS_KEY: exc.detail } # Errors raised as a string are non-field errors. return { - api_settings.NON_FIELD_ERRORS_KEY: [exc.full_details] + api_settings.NON_FIELD_ERRORS_KEY: [exc.detail] } @@ -455,11 +428,11 @@ class Serializer(BaseSerializer): ) code = 'invalid' raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [(message, code)] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)] }) - ret = ReturnDict(serializer=self) - errors = ReturnDict(serializer=self) + ret = OrderedDict() + errors = OrderedDict() fields = self._writable_fields for field in fields: @@ -470,7 +443,7 @@ class Serializer(BaseSerializer): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - errors[field.field_name] = exc.full_details + errors[field.field_name] = exc.detail except DjangoValidationError as exc: error = build_error_from_django_validation_error(exc) errors[field.field_name] = error @@ -610,14 +583,14 @@ class ListSerializer(BaseSerializer): ) code = 'not_a_list' raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [(message, code)] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)] }) if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] code = 'empty_not_allowed' raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [(message, code)] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)] }) ret = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 8fe629c72..07c8eb464 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -11,7 +11,7 @@ from __future__ import unicode_literals from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ErrorDetails, ValidationError from rest_framework.utils.representation import smart_repr @@ -102,7 +102,7 @@ class UniqueTogetherValidator(object): code = 'required' missing = { - field_name: [(self.missing_message, code)] + field_name: ErrorDetails(self.missing_message, code) for field_name in self.fields if field_name not in attrs } @@ -150,7 +150,7 @@ class UniqueTogetherValidator(object): field_names = ', '.join(self.fields) message = self.message.format(field_names=field_names) code = 'unique' - raise ValidationError(message, code=code) + raise ValidationError(ErrorDetails(message, code=code)) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -189,7 +189,7 @@ class BaseUniqueForValidator(object): """ code = 'required' missing = { - field_name: [(self.missing_message, code)] + field_name: ErrorDetails(self.missing_message, code) for field_name in [self.field, self.date_field] if field_name not in attrs } @@ -216,7 +216,7 @@ class BaseUniqueForValidator(object): if queryset.exists(): message = self.message.format(date_field=self.date_field) code = 'unique' - raise ValidationError({self.field: [(message, code)]}) + raise ValidationError({self.field: ErrorDetails(message, code)}) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py index fbde19022..7d2ec1f8d 100644 --- a/tests/test_validation_error.py +++ b/tests/test_validation_error.py @@ -31,7 +31,7 @@ class TestValidationErrorWithCode(TestCase): def exception_handler(exc, request): return_errors = {} - for field_name, errors in exc.full_details.items(): + for field_name, errors in exc.detail.items(): return_errors[field_name] = [] for message, code in errors: return_errors[field_name].append({