From e8b1767629966e771f820c66e5a7d7417e44306f Mon Sep 17 00:00:00 2001 From: Trang Tran Date: Thu, 26 Oct 2017 08:03:35 +0700 Subject: [PATCH 1/3] Adding rounding parameter to DecimalField. --- rest_framework/fields.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 43fed9aee..cbc8586c4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -990,12 +990,13 @@ class DecimalField(Field): 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), - 'max_string_length': _('String value too large.') + 'max_string_length': _('String value too large.'), + 'invalid_rounding': _('Invalid rounding option {rounding}. Valid values for rounding are: {valid_roundings}') } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, - localize=False, **kwargs): + localize=False, rounding=None, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places self.localize = localize @@ -1027,6 +1028,11 @@ class DecimalField(Field): self.validators.append( MinValueValidator(self.min_value, message=message)) + valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] + if rounding is not None and rounding not in valid_roundings: + self.fail('invalid_rounding', rounding=rounding, valid_roundings=valid_roundings) + self.rounding = rounding + def to_internal_value(self, data): """ Validate that the input is a decimal number and return a Decimal @@ -1119,6 +1125,7 @@ class DecimalField(Field): context.prec = self.max_digits return value.quantize( decimal.Decimal('.1') ** self.decimal_places, + rounding=self.rounding, context=context ) From a1b5820b7c52cb2e19924f0966642abf44feb722 Mon Sep 17 00:00:00 2001 From: Trang Tran Date: Thu, 26 Oct 2017 21:17:49 +0700 Subject: [PATCH 2/3] Using standard `assert` instead of `self.fail()`. --- rest_framework/fields.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index cbc8586c4..1e7981ecc 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -990,8 +990,7 @@ class DecimalField(Field): 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), - 'max_string_length': _('String value too large.'), - 'invalid_rounding': _('Invalid rounding option {rounding}. Valid values for rounding are: {valid_roundings}') + 'max_string_length': _('String value too large.') } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. @@ -1028,9 +1027,10 @@ class DecimalField(Field): self.validators.append( MinValueValidator(self.min_value, message=message)) - valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] - if rounding is not None and rounding not in valid_roundings: - self.fail('invalid_rounding', rounding=rounding, valid_roundings=valid_roundings) + if rounding is not None: + valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] + assert rounding in valid_roundings, \ + 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings) self.rounding = rounding def to_internal_value(self, data): From 6d3d82ae544466b8c3c30b1f2977ded30d417a06 Mon Sep 17 00:00:00 2001 From: Trang Tran Date: Fri, 3 Nov 2017 22:37:25 +0700 Subject: [PATCH 3/3] add testcase and PEP8 multilines fix --- rest_framework/fields.py | 4 ++-- tests/test_fields.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1e7981ecc..95c822fb1 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1029,8 +1029,8 @@ class DecimalField(Field): if rounding is not None: valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] - assert rounding in valid_roundings, \ - 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings) + assert rounding in valid_roundings, ( + 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings)) self.rounding = rounding def to_internal_value(self, data): diff --git a/tests/test_fields.py b/tests/test_fields.py index c1b99818a..34372dceb 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1092,6 +1092,22 @@ class TestNoDecimalPlaces(FieldValues): field = serializers.DecimalField(max_digits=6, decimal_places=None) +class TestRoundingDecimalField(TestCase): + def test_valid_rounding(self): + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding='ROUND_UP') + assert field.to_representation(Decimal('1.234')) == '1.24' + + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding='ROUND_DOWN') + assert field.to_representation(Decimal('1.234')) == '1.23' + + def test_invalid_rounding(self): + with pytest.raises(AssertionError) as excinfo: + serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN') + assert 'Invalid rounding option' in str(excinfo.value) + + + + # Date & time serializers... class TestDateField(FieldValues):