mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-04 04:20:12 +03:00
Merge 6d3d82ae54
into 565c722762
This commit is contained in:
commit
fe7faf6433
|
@ -997,7 +997,7 @@ class DecimalField(Field):
|
||||||
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
|
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
|
||||||
|
|
||||||
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None,
|
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None,
|
||||||
localize=False, **kwargs):
|
localize=False, rounding=None, **kwargs):
|
||||||
self.max_digits = max_digits
|
self.max_digits = max_digits
|
||||||
self.decimal_places = decimal_places
|
self.decimal_places = decimal_places
|
||||||
self.localize = localize
|
self.localize = localize
|
||||||
|
@ -1029,6 +1029,12 @@ class DecimalField(Field):
|
||||||
self.validators.append(
|
self.validators.append(
|
||||||
MinValueValidator(self.min_value, message=message))
|
MinValueValidator(self.min_value, message=message))
|
||||||
|
|
||||||
|
if rounding is not None:
|
||||||
|
valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')]
|
||||||
|
assert rounding in valid_roundings, (
|
||||||
|
'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings))
|
||||||
|
self.rounding = rounding
|
||||||
|
|
||||||
def to_internal_value(self, data):
|
def to_internal_value(self, data):
|
||||||
"""
|
"""
|
||||||
Validate that the input is a decimal number and return a Decimal
|
Validate that the input is a decimal number and return a Decimal
|
||||||
|
@ -1121,6 +1127,7 @@ class DecimalField(Field):
|
||||||
context.prec = self.max_digits
|
context.prec = self.max_digits
|
||||||
return value.quantize(
|
return value.quantize(
|
||||||
decimal.Decimal('.1') ** self.decimal_places,
|
decimal.Decimal('.1') ** self.decimal_places,
|
||||||
|
rounding=self.rounding,
|
||||||
context=context
|
context=context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1092,6 +1092,22 @@ class TestNoDecimalPlaces(FieldValues):
|
||||||
field = serializers.DecimalField(max_digits=6, decimal_places=None)
|
field = serializers.DecimalField(max_digits=6, decimal_places=None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoundingDecimalField(TestCase):
|
||||||
|
def test_valid_rounding(self):
|
||||||
|
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding='ROUND_UP')
|
||||||
|
assert field.to_representation(Decimal('1.234')) == '1.24'
|
||||||
|
|
||||||
|
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding='ROUND_DOWN')
|
||||||
|
assert field.to_representation(Decimal('1.234')) == '1.23'
|
||||||
|
|
||||||
|
def test_invalid_rounding(self):
|
||||||
|
with pytest.raises(AssertionError) as excinfo:
|
||||||
|
serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
|
||||||
|
assert 'Invalid rounding option' in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Date & time serializers...
|
# Date & time serializers...
|
||||||
|
|
||||||
class TestDateField(FieldValues):
|
class TestDateField(FieldValues):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user