Added normalize parameter to DecimalField to be able to strip trailing zeros. Fixes #6151.

This commit is contained in:
Henrik Palmlund Wahlgren 2019-03-19 17:04:10 +01:00
parent d2d1888217
commit 7c1414c45d
2 changed files with 27 additions and 1 deletions

View File

@ -1014,10 +1014,11 @@ class DecimalField(Field):
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None,
localize=False, rounding=None, **kwargs): localize=False, rounding=None, normalize=False, **kwargs):
self.max_digits = max_digits self.max_digits = max_digits
self.decimal_places = decimal_places self.decimal_places = decimal_places
self.localize = localize self.localize = localize
self.normalize = normalize
if coerce_to_string is not None: if coerce_to_string is not None:
self.coerce_to_string = coerce_to_string self.coerce_to_string = coerce_to_string
if self.localize: if self.localize:
@ -1125,6 +1126,11 @@ class DecimalField(Field):
quantized = self.quantize(value) quantized = self.quantize(value)
# TODO: Should maybe name the value to something not bound to
# quantized. Ex: out_value
if self.normalize:
quantized = quantized.normalize()
if not coerce_to_string: if not coerce_to_string:
return quantized return quantized
if self.localize: if self.localize:

View File

@ -1169,6 +1169,26 @@ class TestQuantizedValueForDecimal(TestCase):
expected_digit_tuple = (0, (1, 2, 0, 0), -2) expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple assert value == expected_digit_tuple
class TestNormalizedValueDecimalField(TestCase):
"""
Test that we get the expected behavior of on DecimalField when normalize=True
"""
def test_normalize_output(self):
field = serializers.DecimalField(max_digits=4, decimal_places=3, normalize=True)
output = field.to_representation(Decimal('1.000'))
assert output == '1'
def test_non_normalize_output(self):
field = serializers.DecimalField(max_digits=4, decimal_places=3, normalize=False)
output = field.to_representation(Decimal('1.000'))
assert output == '1.000'
def test_normalize_coeherce_to_string(self):
field = serializers.DecimalField(max_digits=4, decimal_places=3, normalize=True, coerce_to_string=False)
output = field.to_representation(Decimal('1.000'))
assert output == Decimal('1')
class TestNoDecimalPlaces(FieldValues): class TestNoDecimalPlaces(FieldValues):
valid_inputs = { valid_inputs = {