Add rounding parameter to DecimalField (#5562)

* Adding rounding parameter to DecimalField.

* Using standard `assert` instead of `self.fail()`.

* add testcase and PEP8 multilines fix

* flake8 fixes

* Use decimal module constants in tests.

* Add docs note for `rounding` parameter.
This commit is contained in:
Carlton Gibson 2017-11-06 09:55:09 +01:00 committed by GitHub
parent 565c722762
commit 331c31370f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 3 deletions

View File

@ -269,6 +269,7 @@ Corresponds to `django.db.models.fields.DecimalField`.
- `max_value` Validate that the number provided is no greater than this value. - `max_value` Validate that the number provided is no greater than this value.
- `min_value` Validate that the number provided is no less than this value. - `min_value` Validate that the number provided is no less than this value.
- `localize` Set to `True` to enable localization of input and output based on the current locale. This will also force `coerce_to_string` to `True`. Defaults to `False`. Note that data formatting is enabled if you have set `USE_L10N=True` in your settings file. - `localize` Set to `True` to enable localization of input and output based on the current locale. This will also force `coerce_to_string` to `True`. Defaults to `False`. Note that data formatting is enabled if you have set `USE_L10N=True` in your settings file.
- `rounding` Sets the rounding mode used when quantising to the configured precision. Valid values are [`decimal` module rounding modes][python-decimal-rounding-modes]. Defaults to `None`.
#### Example usage #### Example usage
@ -680,3 +681,4 @@ The [django-rest-framework-hstore][django-rest-framework-hstore] package provide
[django-rest-framework-gis]: https://github.com/djangonauts/django-rest-framework-gis [django-rest-framework-gis]: https://github.com/djangonauts/django-rest-framework-gis
[django-rest-framework-hstore]: https://github.com/djangonauts/django-rest-framework-hstore [django-rest-framework-hstore]: https://github.com/djangonauts/django-rest-framework-hstore
[django-hstore]: https://github.com/djangonauts/django-hstore [django-hstore]: https://github.com/djangonauts/django-hstore
[python-decimal-rounding-modes]: https://docs.python.org/3/library/decimal.html#rounding-modes

View File

@ -997,7 +997,7 @@ class DecimalField(Field):
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None,
localize=False, **kwargs): localize=False, rounding=None, **kwargs):
self.max_digits = max_digits self.max_digits = max_digits
self.decimal_places = decimal_places self.decimal_places = decimal_places
self.localize = localize self.localize = localize
@ -1029,6 +1029,12 @@ class DecimalField(Field):
self.validators.append( self.validators.append(
MinValueValidator(self.min_value, message=message)) MinValueValidator(self.min_value, message=message))
if rounding is not None:
valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')]
assert rounding in valid_roundings, (
'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings))
self.rounding = rounding
def to_internal_value(self, data): def to_internal_value(self, data):
""" """
Validate that the input is a decimal number and return a Decimal Validate that the input is a decimal number and return a Decimal
@ -1121,6 +1127,7 @@ class DecimalField(Field):
context.prec = self.max_digits context.prec = self.max_digits
return value.quantize( return value.quantize(
decimal.Decimal('.1') ** self.decimal_places, decimal.Decimal('.1') ** self.decimal_places,
rounding=self.rounding,
context=context context=context
) )

View File

@ -3,7 +3,7 @@ import os
import re import re
import unittest import unittest
import uuid import uuid
from decimal import Decimal from decimal import ROUND_DOWN, ROUND_UP, Decimal
import django import django
import pytest import pytest
@ -1092,8 +1092,21 @@ class TestNoDecimalPlaces(FieldValues):
field = serializers.DecimalField(max_digits=6, decimal_places=None) field = serializers.DecimalField(max_digits=6, decimal_places=None)
# Date & time serializers... class TestRoundingDecimalField(TestCase):
def test_valid_rounding(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
assert field.to_representation(Decimal('1.234')) == '1.24'
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
assert field.to_representation(Decimal('1.234')) == '1.23'
def test_invalid_rounding(self):
with pytest.raises(AssertionError) as excinfo:
serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
assert 'Invalid rounding option' in str(excinfo.value)
# Date & time serializers...
class TestDateField(FieldValues): class TestDateField(FieldValues):
""" """
Valid and invalid values for `DateField`. Valid and invalid values for `DateField`.