diff --git a/rest_framework/fields.py b/rest_framework/fields.py index fc0c9a444..dd852f3c6 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1031,8 +1031,8 @@ class DecimalField(Field): if rounding is not None: valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] - assert rounding in valid_roundings, \ - 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings) + assert rounding in valid_roundings, ( + 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings)) self.rounding = rounding def to_internal_value(self, data): diff --git a/tests/test_fields.py b/tests/test_fields.py index 2f642a77c..8c6360c3b 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1092,6 +1092,22 @@ class TestNoDecimalPlaces(FieldValues): field = serializers.DecimalField(max_digits=6, decimal_places=None) +class TestRoundingDecimalField(TestCase): + def test_valid_rounding(self): + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding='ROUND_UP') + assert field.to_representation(Decimal('1.234')) == '1.24' + + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding='ROUND_DOWN') + assert field.to_representation(Decimal('1.234')) == '1.23' + + def test_invalid_rounding(self): + with pytest.raises(AssertionError) as excinfo: + serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN') + assert 'Invalid rounding option' in str(excinfo.value) + + + + # Date & time serializers... class TestDateField(FieldValues):