diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index 3d2443c5c..64014b56e 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -269,6 +269,7 @@ Corresponds to `django.db.models.fields.DecimalField`. - `max_value` Validate that the number provided is no greater than this value. - `min_value` Validate that the number provided is no less than this value. - `localize` Set to `True` to enable localization of input and output based on the current locale. This will also force `coerce_to_string` to `True`. Defaults to `False`. Note that data formatting is enabled if you have set `USE_L10N=True` in your settings file. +- `rounding` Sets the rounding mode used when quantising to the configured precision. Valid values are [`decimal` module rounding modes][python-decimal-rounding-modes]. Defaults to `None`. #### Example usage @@ -680,3 +681,4 @@ The [django-rest-framework-hstore][django-rest-framework-hstore] package provide [django-rest-framework-gis]: https://github.com/djangonauts/django-rest-framework-gis [django-rest-framework-hstore]: https://github.com/djangonauts/django-rest-framework-hstore [django-hstore]: https://github.com/djangonauts/django-hstore +[python-decimal-rounding-modes]: https://docs.python.org/3/library/decimal.html#rounding-modes \ No newline at end of file diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9cfd39995..dd852f3c6 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -997,7 +997,7 @@ class DecimalField(Field): MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, - localize=False, **kwargs): + localize=False, rounding=None, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places self.localize = localize @@ -1029,6 +1029,12 @@ class DecimalField(Field): self.validators.append( MinValueValidator(self.min_value, message=message)) + if rounding is not None: + valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] + assert rounding in valid_roundings, ( + 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings)) + self.rounding = rounding + def to_internal_value(self, data): """ Validate that the input is a decimal number and return a Decimal @@ -1121,6 +1127,7 @@ class DecimalField(Field): context.prec = self.max_digits return value.quantize( decimal.Decimal('.1') ** self.decimal_places, + rounding=self.rounding, context=context ) diff --git a/tests/test_fields.py b/tests/test_fields.py index 2f642a77c..876dd39c8 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -3,7 +3,7 @@ import os import re import unittest import uuid -from decimal import Decimal +from decimal import ROUND_DOWN, ROUND_UP, Decimal import django import pytest @@ -1092,8 +1092,21 @@ class TestNoDecimalPlaces(FieldValues): field = serializers.DecimalField(max_digits=6, decimal_places=None) -# Date & time serializers... +class TestRoundingDecimalField(TestCase): + def test_valid_rounding(self): + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP) + assert field.to_representation(Decimal('1.234')) == '1.24' + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN) + assert field.to_representation(Decimal('1.234')) == '1.23' + + def test_invalid_rounding(self): + with pytest.raises(AssertionError) as excinfo: + serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN') + assert 'Invalid rounding option' in str(excinfo.value) + + +# Date & time serializers... class TestDateField(FieldValues): """ Valid and invalid values for `DateField`.