diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6caae9242..d3ca44a25 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -878,6 +878,10 @@ class DecimalField(WritableField): def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): self.max_value, self.min_value = max_value, min_value self.max_digits, self.decimal_places = max_digits, decimal_places + + if self.decimal_places: + self.empty = Decimal('0').quantize(Decimal('.%s1' % ('0' * (self.decimal_places - 1)))) + super(DecimalField, self).__init__(*args, **kwargs) if max_value is not None: diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 17d12f231..fa0ca1fa3 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -637,7 +637,35 @@ class DecimalFieldTest(TestCase): self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) - + + def test_decimal_is_not_quantized_when_decimal_places_is_none(self): + class LongDecimalFieldModel(models.Model): + decimal_field = models.DecimalField(max_digits=20, decimal_places=10, default=0) + + class DecimalSerializer(serializers.ModelSerializer): + class Meta: + model = LongDecimalFieldModel + decimal_field = serializers.DecimalField(required=False, decimal_places=2) + + serializer = DecimalSerializer(data={}) + + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data['decimal_field'].as_tuple().exponent, 0) + + def test_decimal_is_quantized_when_decimal_places_is_provided(self): + class LongDecimalFieldModel(models.Model): + decimal_field = models.DecimalField(max_digits=20, decimal_places=10, default=0) + + class DecimalSerializer(serializers.ModelSerializer): + class Meta: + model = LongDecimalFieldModel + decimal_field = serializers.DecimalField(required=False, decimal_places=2) + + serializer = DecimalSerializer(data={}) + + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data['decimal_field'].as_tuple().exponent, -2) + def test_raise_max_value(self): """ Make sure max_value violations raises ValidationError