From 38a1b3ec6b62ea0e5bfcb0de2043067d8e333c95 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 6 Aug 2015 09:51:00 +0100 Subject: [PATCH] Rationalize decimal logic. Closes #3222. --- rest_framework/fields.py | 38 ++++++++++++++++++++++++-------------- tests/test_fields.py | 6 ++++-- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index aa264b2aa..37c902954 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -797,6 +797,11 @@ class DecimalField(Field): self.max_value = max_value self.min_value = min_value + if self.max_digits is not None and self.decimal_places is not None: + self.max_whole_digits = self.max_digits - self.decimal_places + else: + self.max_whole_digits = None + super(DecimalField, self).__init__(**kwargs) if self.max_value is not None: @@ -840,24 +845,29 @@ class DecimalField(Field): values or to enhance it in any way you need to. """ sign, digittuple, exponent = value.as_tuple() - decimals = exponent * decimal.Decimal(-1) if exponent < 0 else 0 - # digittuple doesn't include any leading zeros. - digits = len(digittuple) - if decimals > digits: - # We have leading zeros up to or past the decimal point. Count - # everything past the decimal point as a digit. We do not count - # 0 before the decimal point as a digit since that would mean - # we would not allow max_digits = decimal_places. - digits = decimals - whole_digits = digits - decimals + if exponent >= 0: + # 1234500.0 + total_digits = len(digittuple) + exponent + whole_digits = total_digits + decimal_places = 0 + elif len(digittuple) > abs(exponent): + # 123.45 + total_digits = len(digittuple) + whole_digits = total_digits - abs(exponent) + decimal_places = abs(exponent) + else: + # 0.001234 + total_digits = abs(exponent) + whole_digits = 0 + decimal_places = total_digits - if self.max_digits is not None and digits > self.max_digits: + if self.max_digits is not None and total_digits > self.max_digits: self.fail('max_digits', max_digits=self.max_digits) - if self.decimal_places is not None and decimals > self.decimal_places: + if self.decimal_places is not None and decimal_places > self.decimal_places: self.fail('max_decimal_places', max_decimal_places=self.decimal_places) - if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): - self.fail('max_whole_digits', max_whole_digits=self.max_digits - self.decimal_places) + if self.max_whole_digits is not None and whole_digits > self.max_whole_digits: + self.fail('max_whole_digits', max_whole_digits=self.max_whole_digits) return value diff --git a/tests/test_fields.py b/tests/test_fields.py index 042787357..6adba66a7 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -766,15 +766,17 @@ class TestDecimalField(FieldValues): 0: Decimal('0'), 12.3: Decimal('12.3'), 0.1: Decimal('0.1'), - '2E+2': Decimal('200'), + '2E+1': Decimal('20'), } invalid_inputs = ( ('abc', ["A valid number is required."]), (Decimal('Nan'), ["A valid number is required."]), (Decimal('Inf'), ["A valid number is required."]), ('12.345', ["Ensure that there are no more than 3 digits in total."]), + (200000000000.0, ["Ensure that there are no more than 3 digits in total."]), ('0.01', ["Ensure that there are no more than 1 decimal places."]), - (123, ["Ensure that there are no more than 2 digits before the decimal point."]) + (123, ["Ensure that there are no more than 2 digits before the decimal point."]), + ('2E+2', ["Ensure that there are no more than 2 digits before the decimal point."]) ) outputs = { '1': '1.0',