This commit is contained in:
kgeorgy 2017-07-12 16:25:01 +00:00 committed by GitHub
commit 52ef43b0d3
4 changed files with 164 additions and 92 deletions

View File

@ -25,7 +25,9 @@ from django.utils.dateparse import (
) )
from django.utils.duration import duration_string from django.utils.duration import duration_string
from django.utils.encoding import is_protected_type, smart_text from django.utils.encoding import is_protected_type, smart_text
from django.utils.formats import localize_input, sanitize_separators from django.utils.formats import (
localize_input, number_format, sanitize_separators
)
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.ipv6 import clean_ipv6_address from django.utils.ipv6 import clean_ipv6_address
from django.utils.timezone import utc from django.utils.timezone import utc
@ -882,107 +884,28 @@ class IPAddressField(CharField):
# Number types... # Number types...
class IntegerField(Field):
default_error_messages = {
'invalid': _('A valid integer is required.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_string_length': _('String value too large.')
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2'
def __init__(self, **kwargs): class NumberField(Field):
self.max_value = kwargs.pop('max_value', None)
self.min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if self.max_value is not None:
message = self.error_messages['max_value'].format(max_value=self.max_value)
self.validators.append(MaxValueValidator(self.max_value, message=message))
if self.min_value is not None:
message = self.error_messages['min_value'].format(min_value=self.min_value)
self.validators.append(MinValueValidator(self.min_value, message=message))
def to_internal_value(self, data):
if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH:
self.fail('max_string_length')
try:
data = int(self.re_decimal.sub('', str(data)))
except (ValueError, TypeError):
self.fail('invalid')
return data
def to_representation(self, value):
return int(value)
class FloatField(Field):
default_error_messages = { default_error_messages = {
'invalid': _('A valid number is required.'), 'invalid': _('A valid number is required.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'), 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'), 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_string_length': _('String value too large.') 'max_string_length': _('String value too large.')
} }
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.max_value = kwargs.pop('max_value', None) self.max_value = kwargs.pop('max_value', None)
self.min_value = kwargs.pop('min_value', None) self.min_value = kwargs.pop('min_value', None)
super(FloatField, self).__init__(**kwargs) self.localize = kwargs.pop('localize', api_settings.LOCALIZE_NUMBER_FIELDS)
if self.max_value is not None:
message = self.error_messages['max_value'].format(max_value=self.max_value)
self.validators.append(MaxValueValidator(self.max_value, message=message))
if self.min_value is not None:
message = self.error_messages['min_value'].format(min_value=self.min_value)
self.validators.append(MinValueValidator(self.min_value, message=message))
def to_internal_value(self, data): super(NumberField, self).__init__(**kwargs)
if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH:
self.fail('max_string_length')
try:
return float(data)
except (TypeError, ValueError):
self.fail('invalid')
def to_representation(self, value):
return float(value)
class DecimalField(Field):
default_error_messages = {
'invalid': _('A valid number is required.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'),
'max_string_length': _('String value too large.')
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None,
localize=False, **kwargs):
self.max_digits = max_digits
self.decimal_places = decimal_places
self.localize = localize
if coerce_to_string is not None:
self.coerce_to_string = coerce_to_string
if self.localize: if self.localize:
self.coerce_to_string = True self.coerce_to_string = True
self.max_value = max_value
self.min_value = min_value
if self.max_digits is not None and self.decimal_places is not None:
self.max_whole_digits = self.max_digits - self.decimal_places
else:
self.max_whole_digits = None
super(DecimalField, self).__init__(**kwargs)
if self.max_value is not None: if self.max_value is not None:
message = self.error_messages['max_value'].format(max_value=self.max_value) message = self.error_messages['max_value'].format(max_value=self.max_value)
self.validators.append(MaxValueValidator(self.max_value, message=message)) self.validators.append(MaxValueValidator(self.max_value, message=message))
@ -991,10 +914,8 @@ class DecimalField(Field):
self.validators.append(MinValueValidator(self.min_value, message=message)) self.validators.append(MinValueValidator(self.min_value, message=message))
def to_internal_value(self, data): def to_internal_value(self, data):
""" if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH:
Validate that the input is a decimal number and return a Decimal self.fail('max_string_length')
instance.
"""
data = smart_text(data).strip() data = smart_text(data).strip()
@ -1004,6 +925,92 @@ class DecimalField(Field):
if len(data) > self.MAX_STRING_LENGTH: if len(data) > self.MAX_STRING_LENGTH:
self.fail('max_string_length') self.fail('max_string_length')
return data
def to_representation(self, value):
return super(NumberField, self).to_representation(value)
class IntegerField(NumberField):
default_error_messages = NumberField.default_error_messages.copy()
default_error_messages.update({
'invalid': _('A valid integer is required.'),
})
re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2'
def __init__(self, **kwargs):
self.coerce_to_string = kwargs.pop('coerce_to_string', api_settings.COERCE_INTEGER_TO_STRING)
super(IntegerField, self).__init__(**kwargs)
def to_internal_value(self, data):
data = super(IntegerField, self).to_internal_value(data)
try:
data = int(self.re_decimal.sub('', str(data)))
except (ValueError, TypeError):
self.fail('invalid')
return data
def to_representation(self, value):
if self.localize:
return number_format(value)
if self.coerce_to_string:
return str(int(value))
return int(value)
class FloatField(NumberField):
def __init__(self, **kwargs):
self.coerce_to_string = kwargs.pop('coerce_to_string', api_settings.COERCE_FLOAT_TO_STRING)
super(FloatField, self).__init__(**kwargs)
def to_internal_value(self, data):
data = super(FloatField, self).to_internal_value(data)
try:
return float(data)
except (TypeError, ValueError):
self.fail('invalid')
def to_representation(self, value):
if self.localize:
return number_format(value)
if self.coerce_to_string:
return str(float(value))
return float(value)
class DecimalField(NumberField):
default_error_messages = NumberField.default_error_messages.copy()
default_error_messages.update({
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'),
})
def __init__(self, max_digits, decimal_places, **kwargs):
self.coerce_to_string = kwargs.pop('coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING)
super(DecimalField, self).__init__(**kwargs)
self.max_digits = max_digits
self.decimal_places = decimal_places
if self.max_digits is not None and self.decimal_places is not None:
self.max_whole_digits = self.max_digits - self.decimal_places
else:
self.max_whole_digits = None
def to_internal_value(self, data):
"""
Validate that the input is a decimal number and return a Decimal
instance.
"""
data = super(DecimalField, self).to_internal_value(data)
try: try:
value = decimal.Decimal(data) value = decimal.Decimal(data)
except decimal.DecimalException: except decimal.DecimalException:
@ -1056,14 +1063,13 @@ class DecimalField(Field):
return value return value
def to_representation(self, value): def to_representation(self, value):
coerce_to_string = getattr(self, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING)
if not isinstance(value, decimal.Decimal): if not isinstance(value, decimal.Decimal):
value = decimal.Decimal(six.text_type(value).strip()) value = decimal.Decimal(six.text_type(value).strip())
quantized = self.quantize(value) quantized = self.quantize(value)
if not coerce_to_string: if not self.coerce_to_string:
return quantized return quantized
if self.localize: if self.localize:
return localize_input(quantized) return localize_input(quantized)

View File

@ -56,8 +56,8 @@ from rest_framework.fields import ( # NOQA # isort:skip
BooleanField, CharField, ChoiceField, DateField, DateTimeField, DecimalField, BooleanField, CharField, ChoiceField, DateField, DateTimeField, DecimalField,
DictField, DurationField, EmailField, Field, FileField, FilePathField, FloatField, DictField, DurationField, EmailField, Field, FileField, FilePathField, FloatField,
HiddenField, IPAddressField, ImageField, IntegerField, JSONField, ListField, HiddenField, IPAddressField, ImageField, IntegerField, JSONField, ListField,
ModelField, MultipleChoiceField, NullBooleanField, ReadOnlyField, RegexField, ModelField, MultipleChoiceField, NullBooleanField, NumberField, ReadOnlyField,
SerializerMethodField, SlugField, TimeField, URLField, UUIDField, RegexField, SerializerMethodField, SlugField, TimeField, URLField, UUIDField,
) )
from rest_framework.relations import ( # NOQA # isort:skip from rest_framework.relations import ( # NOQA # isort:skip
HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField, HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField,

View File

@ -111,8 +111,13 @@ DEFAULTS = {
'UNICODE_JSON': True, 'UNICODE_JSON': True,
'COMPACT_JSON': True, 'COMPACT_JSON': True,
'COERCE_DECIMAL_TO_STRING': True, 'COERCE_DECIMAL_TO_STRING': True,
'COERCE_FLOAT_TO_STRING': False,
'COERCE_INTEGER_TO_STRING': False,
'UPLOADED_FILES_USE_URL': True, 'UPLOADED_FILES_USE_URL': True,
# Number fields localization
'LOCALIZE_NUMBER_FIELDS': False,
# Browseable API # Browseable API
'HTML_SELECT_CUTOFF': 1000, 'HTML_SELECT_CUTOFF': 1000,
'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",

View File

@ -886,6 +886,37 @@ class TestMinMaxIntegerField(FieldValues):
field = serializers.IntegerField(min_value=1, max_value=3) field = serializers.IntegerField(min_value=1, max_value=3)
class TestCoerceToStringIntegerField(FieldValues):
valid_inputs = {}
invalid_inputs = {}
outputs = {
'1': '1',
'0': '0',
1: '1',
0: '0',
1.0: '1',
0.0: '0',
}
field = serializers.IntegerField(coerce_to_string=True)
class TestLocalizedIntegerField(TestCase):
@override_settings(USE_L10N=True, LANGUAGE_CODE='it')
def test_to_internal_value(self):
field = serializers.IntegerField(localize=True)
self.assertEqual(field.to_internal_value('1,0'), 1)
@override_settings(USE_L10N=True, LANGUAGE_CODE=None, DECIMAL_SEPARATOR=',', THOUSAND_SEPARATOR='\'',
NUMBER_GROUPING=3, USE_THOUSAND_SEPARATOR=True)
def test_to_representation(self):
field = serializers.IntegerField(localize=True)
self.assertEqual(field.to_representation(1000), '1\'000')
def test_localize_forces_coerce_to_string(self):
field = serializers.IntegerField(localize=True)
self.assertTrue(isinstance(field.to_representation(3), six.string_types))
class TestFloatField(FieldValues): class TestFloatField(FieldValues):
""" """
Valid and invalid values for `FloatField`. Valid and invalid values for `FloatField`.
@ -934,6 +965,36 @@ class TestMinMaxFloatField(FieldValues):
field = serializers.FloatField(min_value=1, max_value=3) field = serializers.FloatField(min_value=1, max_value=3)
class TestCoerceToStringFloatField(FieldValues):
valid_inputs = {}
invalid_inputs = {}
outputs = {
'1': str(1.0),
'0': str(0.0),
1: str(1.0),
0: str(0.0),
1.5: str(1.5),
}
field = serializers.FloatField(coerce_to_string=True)
class TestLocalizedFloatField(TestCase):
@override_settings(USE_L10N=True, LANGUAGE_CODE='it')
def test_to_internal_value(self):
field = serializers.FloatField(localize=True)
self.assertEqual(field.to_internal_value('1,5'), 1.5)
@override_settings(USE_L10N=True, LANGUAGE_CODE=None, DECIMAL_SEPARATOR=',', THOUSAND_SEPARATOR='\'',
NUMBER_GROUPING=3, USE_THOUSAND_SEPARATOR=True)
def test_to_representation(self):
field = serializers.FloatField(localize=True)
self.assertEqual(field.to_representation(1000.75), '1\'000,75')
def test_localize_forces_coerce_to_string(self):
field = serializers.FloatField(localize=True)
self.assertTrue(isinstance(field.to_representation(3), six.string_types))
class TestDecimalField(FieldValues): class TestDecimalField(FieldValues):
""" """
Valid and invalid values for `DecimalField`. Valid and invalid values for `DecimalField`.