From 7943429dabe7dbda33798e1ef8e8d19e1670430e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 10 Oct 2016 18:04:58 +0100 Subject: [PATCH] Rejig ErrorMessages, to prefer code= directly on ValidationError --- rest_framework/exceptions.py | 20 ++++++-------------- rest_framework/serializers.py | 14 +++++++------- rest_framework/validators.py | 14 +++++++------- 3 files changed, 20 insertions(+), 28 deletions(-) diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 994a942e4..056af4fa6 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -10,7 +10,6 @@ import math from django.utils import six from django.utils.encoding import force_text -from django.utils.functional import Promise from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ungettext @@ -18,21 +17,21 @@ from rest_framework import status from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList -def _force_text_recursive(data): +def _force_text_recursive(data, code=None): """ Descend into a nested data structure, forcing any - lazy translation strings into plain text. + lazy translation strings or strings into `ErrorMessage`. """ if isinstance(data, list): ret = [ - _force_text_recursive(item) for item in data + _force_text_recursive(item, code) for item in data ] if isinstance(data, ReturnList): return ReturnList(ret, serializer=data.serializer) return ret elif isinstance(data, dict): ret = { - key: _force_text_recursive(value) + key: _force_text_recursive(value, code) for key, value in data.items() } if isinstance(data, ReturnDict): @@ -40,7 +39,7 @@ def _force_text_recursive(data): return ret text = force_text(data) - code = getattr(data, 'code', 'invalid') + code = getattr(data, 'code', code or 'invalid') return ErrorMessage(text, code) @@ -82,18 +81,11 @@ class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST def __init__(self, detail, code=None): - if code is not None: - assert isinstance(detail, six.string_types + (Promise,)), ( - "When providing a 'code', the detail must be a string argument. " - "Use 'ErrorMessage' to set the code for a composite ValidationError" - ) - detail = ErrorMessage(detail, code) - # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] - self.detail = _force_text_recursive(detail) + self.detail = _force_text_recursive(detail, code=code) def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index c568ab265..10d3c706d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -423,8 +423,8 @@ class Serializer(BaseSerializer): datatype=type(data).__name__ ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='invalid')] - }) + api_settings.NON_FIELD_ERRORS_KEY: [message] + }, code='invalid') ret = OrderedDict() errors = OrderedDict() @@ -440,7 +440,7 @@ class Serializer(BaseSerializer): except ValidationError as exc: errors[field.field_name] = exc.detail except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + errors[field.field_name] = get_validation_error_detail(exc) except SkipField: pass else: @@ -580,14 +580,14 @@ class ListSerializer(BaseSerializer): input_type=type(data).__name__ ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='not_a_list')] - }) + api_settings.NON_FIELD_ERRORS_KEY: [message] + }, code='not_a_list') if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='empty')] - }) + api_settings.NON_FIELD_ERRORS_KEY: [message] + }, code='empty') ret = [] errors = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 386aff9c4..57f8bad53 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -12,7 +12,7 @@ from django.db import DataError from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ErrorMessage, ValidationError +from rest_framework.exceptions import ValidationError from rest_framework.utils.representation import smart_repr @@ -121,12 +121,12 @@ class UniqueTogetherValidator(object): return missing_items = { - field_name: ErrorMessage(self.missing_message, code='required') + field_name: self.missing_message for field_name in self.fields if field_name not in attrs } if missing_items: - raise ValidationError(missing_items) + raise ValidationError(missing_items, code='required') def filter_queryset(self, attrs, queryset): """ @@ -206,12 +206,12 @@ class BaseUniqueForValidator(object): 'required' state on the fields they are applied to. """ missing_items = { - field_name: ErrorMessage(self.missing_message, code='required') + field_name: self.missing_message for field_name in [self.field, self.date_field] if field_name not in attrs } if missing_items: - raise ValidationError(missing_items) + raise ValidationError(missing_items, code='required') def filter_queryset(self, attrs, queryset): raise NotImplementedError('`filter_queryset` must be implemented.') @@ -233,8 +233,8 @@ class BaseUniqueForValidator(object): if qs_exists(queryset): message = self.message.format(date_field=self.date_field) raise ValidationError({ - self.field: ErrorMessage(message, code='unique') - }) + self.field: message + }, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (