diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 14b264ff9..2c36df5fc 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -25,7 +25,9 @@ from django.utils.dateparse import ( ) from django.utils.duration import duration_string from django.utils.encoding import is_protected_type, smart_text -from django.utils.formats import localize_input, sanitize_separators +from django.utils.formats import ( + localize_input, number_format, sanitize_separators +) from django.utils.functional import cached_property from django.utils.ipv6 import clean_ipv6_address from django.utils.timezone import utc @@ -882,107 +884,28 @@ class IPAddressField(CharField): # Number types... -class IntegerField(Field): - default_error_messages = { - 'invalid': _('A valid integer is required.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), - 'max_string_length': _('String value too large.') - } - MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2' - def __init__(self, **kwargs): - self.max_value = kwargs.pop('max_value', None) - self.min_value = kwargs.pop('min_value', None) - super(IntegerField, self).__init__(**kwargs) - if self.max_value is not None: - message = self.error_messages['max_value'].format(max_value=self.max_value) - self.validators.append(MaxValueValidator(self.max_value, message=message)) - if self.min_value is not None: - message = self.error_messages['min_value'].format(min_value=self.min_value) - self.validators.append(MinValueValidator(self.min_value, message=message)) +class NumberField(Field): - def to_internal_value(self, data): - if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: - self.fail('max_string_length') - - try: - data = int(self.re_decimal.sub('', str(data))) - except (ValueError, TypeError): - self.fail('invalid') - return data - - def to_representation(self, value): - return int(value) - - -class FloatField(Field): default_error_messages = { 'invalid': _('A valid number is required.'), 'max_value': _('Ensure this value is less than or equal to {max_value}.'), 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), 'max_string_length': _('String value too large.') } + MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. def __init__(self, **kwargs): self.max_value = kwargs.pop('max_value', None) self.min_value = kwargs.pop('min_value', None) - super(FloatField, self).__init__(**kwargs) - if self.max_value is not None: - message = self.error_messages['max_value'].format(max_value=self.max_value) - self.validators.append(MaxValueValidator(self.max_value, message=message)) - if self.min_value is not None: - message = self.error_messages['min_value'].format(min_value=self.min_value) - self.validators.append(MinValueValidator(self.min_value, message=message)) + self.localize = kwargs.pop('localize', api_settings.LOCALIZE_NUMBER_FIELDS) - def to_internal_value(self, data): + super(NumberField, self).__init__(**kwargs) - if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: - self.fail('max_string_length') - - try: - return float(data) - except (TypeError, ValueError): - self.fail('invalid') - - def to_representation(self, value): - return float(value) - - -class DecimalField(Field): - default_error_messages = { - 'invalid': _('A valid number is required.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), - 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), - 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), - 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), - 'max_string_length': _('String value too large.') - } - MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - - def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, - localize=False, **kwargs): - self.max_digits = max_digits - self.decimal_places = decimal_places - self.localize = localize - if coerce_to_string is not None: - self.coerce_to_string = coerce_to_string if self.localize: self.coerce_to_string = True - self.max_value = max_value - self.min_value = min_value - - if self.max_digits is not None and self.decimal_places is not None: - self.max_whole_digits = self.max_digits - self.decimal_places - else: - self.max_whole_digits = None - - super(DecimalField, self).__init__(**kwargs) - if self.max_value is not None: message = self.error_messages['max_value'].format(max_value=self.max_value) self.validators.append(MaxValueValidator(self.max_value, message=message)) @@ -991,10 +914,8 @@ class DecimalField(Field): self.validators.append(MinValueValidator(self.min_value, message=message)) def to_internal_value(self, data): - """ - Validate that the input is a decimal number and return a Decimal - instance. - """ + if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: + self.fail('max_string_length') data = smart_text(data).strip() @@ -1004,6 +925,92 @@ class DecimalField(Field): if len(data) > self.MAX_STRING_LENGTH: self.fail('max_string_length') + return data + + def to_representation(self, value): + return super(NumberField, self).to_representation(value) + + +class IntegerField(NumberField): + default_error_messages = NumberField.default_error_messages.copy() + default_error_messages.update({ + 'invalid': _('A valid integer is required.'), + }) + + re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2' + + def __init__(self, **kwargs): + self.coerce_to_string = kwargs.pop('coerce_to_string', api_settings.COERCE_INTEGER_TO_STRING) + super(IntegerField, self).__init__(**kwargs) + + def to_internal_value(self, data): + + data = super(IntegerField, self).to_internal_value(data) + + try: + data = int(self.re_decimal.sub('', str(data))) + except (ValueError, TypeError): + self.fail('invalid') + return data + + def to_representation(self, value): + if self.localize: + return number_format(value) + if self.coerce_to_string: + return str(int(value)) + return int(value) + + +class FloatField(NumberField): + + def __init__(self, **kwargs): + self.coerce_to_string = kwargs.pop('coerce_to_string', api_settings.COERCE_FLOAT_TO_STRING) + super(FloatField, self).__init__(**kwargs) + + def to_internal_value(self, data): + + data = super(FloatField, self).to_internal_value(data) + + try: + return float(data) + except (TypeError, ValueError): + self.fail('invalid') + + def to_representation(self, value): + if self.localize: + return number_format(value) + if self.coerce_to_string: + return str(float(value)) + return float(value) + + +class DecimalField(NumberField): + default_error_messages = NumberField.default_error_messages.copy() + default_error_messages.update({ + 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), + }) + + def __init__(self, max_digits, decimal_places, **kwargs): + self.coerce_to_string = kwargs.pop('coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING) + super(DecimalField, self).__init__(**kwargs) + + self.max_digits = max_digits + self.decimal_places = decimal_places + + if self.max_digits is not None and self.decimal_places is not None: + self.max_whole_digits = self.max_digits - self.decimal_places + else: + self.max_whole_digits = None + + def to_internal_value(self, data): + """ + Validate that the input is a decimal number and return a Decimal + instance. + """ + + data = super(DecimalField, self).to_internal_value(data) try: value = decimal.Decimal(data) except decimal.DecimalException: @@ -1056,14 +1063,13 @@ class DecimalField(Field): return value def to_representation(self, value): - coerce_to_string = getattr(self, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING) if not isinstance(value, decimal.Decimal): value = decimal.Decimal(six.text_type(value).strip()) quantized = self.quantize(value) - if not coerce_to_string: + if not self.coerce_to_string: return quantized if self.localize: return localize_input(quantized) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index a4b51ae9d..956cc68d6 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -56,8 +56,8 @@ from rest_framework.fields import ( # NOQA # isort:skip BooleanField, CharField, ChoiceField, DateField, DateTimeField, DecimalField, DictField, DurationField, EmailField, Field, FileField, FilePathField, FloatField, HiddenField, IPAddressField, ImageField, IntegerField, JSONField, ListField, - ModelField, MultipleChoiceField, NullBooleanField, ReadOnlyField, RegexField, - SerializerMethodField, SlugField, TimeField, URLField, UUIDField, + ModelField, MultipleChoiceField, NullBooleanField, NumberField, ReadOnlyField, + RegexField, SerializerMethodField, SlugField, TimeField, URLField, UUIDField, ) from rest_framework.relations import ( # NOQA # isort:skip HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField, diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 3f3c9110a..1b61c32cd 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -111,8 +111,13 @@ DEFAULTS = { 'UNICODE_JSON': True, 'COMPACT_JSON': True, 'COERCE_DECIMAL_TO_STRING': True, + 'COERCE_FLOAT_TO_STRING': False, + 'COERCE_INTEGER_TO_STRING': False, 'UPLOADED_FILES_USE_URL': True, + # Number fields localization + 'LOCALIZE_NUMBER_FIELDS': False, + # Browseable API 'HTML_SELECT_CUTOFF': 1000, 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", diff --git a/tests/test_fields.py b/tests/test_fields.py index 38dc5f7a7..92ef02993 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -886,6 +886,37 @@ class TestMinMaxIntegerField(FieldValues): field = serializers.IntegerField(min_value=1, max_value=3) +class TestCoerceToStringIntegerField(FieldValues): + valid_inputs = {} + invalid_inputs = {} + outputs = { + '1': '1', + '0': '0', + 1: '1', + 0: '0', + 1.0: '1', + 0.0: '0', + } + field = serializers.IntegerField(coerce_to_string=True) + + +class TestLocalizedIntegerField(TestCase): + @override_settings(USE_L10N=True, LANGUAGE_CODE='it') + def test_to_internal_value(self): + field = serializers.IntegerField(localize=True) + self.assertEqual(field.to_internal_value('1,0'), 1) + + @override_settings(USE_L10N=True, LANGUAGE_CODE=None, DECIMAL_SEPARATOR=',', THOUSAND_SEPARATOR='\'', + NUMBER_GROUPING=3, USE_THOUSAND_SEPARATOR=True) + def test_to_representation(self): + field = serializers.IntegerField(localize=True) + self.assertEqual(field.to_representation(1000), '1\'000') + + def test_localize_forces_coerce_to_string(self): + field = serializers.IntegerField(localize=True) + self.assertTrue(isinstance(field.to_representation(3), six.string_types)) + + class TestFloatField(FieldValues): """ Valid and invalid values for `FloatField`. @@ -934,6 +965,36 @@ class TestMinMaxFloatField(FieldValues): field = serializers.FloatField(min_value=1, max_value=3) +class TestCoerceToStringFloatField(FieldValues): + valid_inputs = {} + invalid_inputs = {} + outputs = { + '1': str(1.0), + '0': str(0.0), + 1: str(1.0), + 0: str(0.0), + 1.5: str(1.5), + } + field = serializers.FloatField(coerce_to_string=True) + + +class TestLocalizedFloatField(TestCase): + @override_settings(USE_L10N=True, LANGUAGE_CODE='it') + def test_to_internal_value(self): + field = serializers.FloatField(localize=True) + self.assertEqual(field.to_internal_value('1,5'), 1.5) + + @override_settings(USE_L10N=True, LANGUAGE_CODE=None, DECIMAL_SEPARATOR=',', THOUSAND_SEPARATOR='\'', + NUMBER_GROUPING=3, USE_THOUSAND_SEPARATOR=True) + def test_to_representation(self): + field = serializers.FloatField(localize=True) + self.assertEqual(field.to_representation(1000.75), '1\'000,75') + + def test_localize_forces_coerce_to_string(self): + field = serializers.FloatField(localize=True) + self.assertTrue(isinstance(field.to_representation(3), six.string_types)) + + class TestDecimalField(FieldValues): """ Valid and invalid values for `DecimalField`.