diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c8f65db0e..612568e57 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1014,10 +1014,11 @@ class DecimalField(Field): MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, - localize=False, rounding=None, **kwargs): + localize=False, rounding=None, normalize=False, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places self.localize = localize + self.normalize = normalize if coerce_to_string is not None: self.coerce_to_string = coerce_to_string if self.localize: @@ -1125,6 +1126,11 @@ class DecimalField(Field): quantized = self.quantize(value) + # TODO: Should maybe name the value to something not bound to + # quantized. Ex: out_value + if self.normalize: + quantized = quantized.normalize() + if not coerce_to_string: return quantized if self.localize: diff --git a/tests/test_fields.py b/tests/test_fields.py index 12c936b22..103940ea7 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1169,6 +1169,26 @@ class TestQuantizedValueForDecimal(TestCase): expected_digit_tuple = (0, (1, 2, 0, 0), -2) assert value == expected_digit_tuple +class TestNormalizedValueDecimalField(TestCase): + """ + Test that we get the expected behavior of on DecimalField when normalize=True + """ + + def test_normalize_output(self): + field = serializers.DecimalField(max_digits=4, decimal_places=3, normalize=True) + output = field.to_representation(Decimal('1.000')) + assert output == '1' + + def test_non_normalize_output(self): + field = serializers.DecimalField(max_digits=4, decimal_places=3, normalize=False) + output = field.to_representation(Decimal('1.000')) + assert output == '1.000' + + def test_normalize_coeherce_to_string(self): + field = serializers.DecimalField(max_digits=4, decimal_places=3, normalize=True, coerce_to_string=False) + output = field.to_representation(Decimal('1.000')) + assert output == Decimal('1') + class TestNoDecimalPlaces(FieldValues): valid_inputs = {