mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-10 19:56:59 +03:00
Add DecimalField support
This commit is contained in:
parent
0c1b8b4f76
commit
ad436d966f
|
@ -248,6 +248,12 @@ A floating point representation.
|
|||
|
||||
Corresponds to `django.db.models.fields.FloatField`.
|
||||
|
||||
## DecimalField
|
||||
|
||||
A decimal representation.
|
||||
|
||||
Corresponds to `django.db.models.fields.DecimalField`.
|
||||
|
||||
## FileField
|
||||
|
||||
A file representation. Performs Django's standard FileField validation.
|
||||
|
|
|
@ -44,6 +44,7 @@ You can determine your currently installed version using `pip freeze`:
|
|||
|
||||
**Date**: 4th April 2013
|
||||
|
||||
* DecimalField support.
|
||||
* OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token.
|
||||
* URL hyperlinking in browseable API now handles more cases correctly.
|
||||
* Long HTTP headers in browsable API are broken in multiple lines when possible.
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import copy
|
||||
import datetime
|
||||
from decimal import Decimal, DecimalException
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
|
@ -721,6 +722,80 @@ class FloatField(WritableField):
|
|||
raise ValidationError(msg)
|
||||
|
||||
|
||||
class DecimalField(WritableField):
|
||||
type_name = 'DecimalField'
|
||||
form_field_class = forms.DecimalField
|
||||
|
||||
default_error_messages = {
|
||||
'invalid': _('Enter a number.'),
|
||||
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
|
||||
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
|
||||
'max_digits': _('Ensure that there are no more than %s digits in total.'),
|
||||
'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
|
||||
'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
|
||||
}
|
||||
|
||||
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
|
||||
self.max_value, self.min_value = max_value, min_value
|
||||
self.max_digits, self.decimal_places = max_digits, decimal_places
|
||||
super(DecimalField, self).__init__(self, *args, **kwargs)
|
||||
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
if min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(min_value))
|
||||
|
||||
def from_native(self, value):
|
||||
"""
|
||||
Validates that the input is a decimal number. Returns a Decimal
|
||||
instance. Returns None for empty values. Ensures that there are no more
|
||||
than max_digits in the number, and no more than decimal_places digits
|
||||
after the decimal point.
|
||||
"""
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return None
|
||||
value = smart_text(value).strip()
|
||||
try:
|
||||
value = Decimal(value)
|
||||
except DecimalException:
|
||||
raise ValidationError(self.error_messages['invalid'])
|
||||
return value
|
||||
|
||||
def to_native(self, value):
|
||||
if value is not None:
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
def validate(self, value):
|
||||
super(DecimalField, self).validate(value)
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return
|
||||
# Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
|
||||
# since it is never equal to itself. However, NaN is the only value that
|
||||
# isn't equal to itself, so we can use this to identify NaN
|
||||
if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
|
||||
raise ValidationError(self.error_messages['invalid'])
|
||||
sign, digittuple, exponent = value.as_tuple()
|
||||
decimals = abs(exponent)
|
||||
# digittuple doesn't include any leading zeros.
|
||||
digits = len(digittuple)
|
||||
if decimals > digits:
|
||||
# We have leading zeros up to or past the decimal point. Count
|
||||
# everything past the decimal point as a digit. We do not count
|
||||
# 0 before the decimal point as a digit since that would mean
|
||||
# we would not allow max_digits = decimal_places.
|
||||
digits = decimals
|
||||
whole_digits = digits - decimals
|
||||
|
||||
if self.max_digits is not None and digits > self.max_digits:
|
||||
raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
|
||||
if self.decimal_places is not None and decimals > self.decimal_places:
|
||||
raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
|
||||
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
|
||||
raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
|
||||
return value
|
||||
|
||||
|
||||
class FileField(WritableField):
|
||||
use_files = True
|
||||
type_name = 'FileField'
|
||||
|
|
|
@ -3,12 +3,14 @@ General serializer field tests.
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
from django.core import validators
|
||||
|
||||
from rest_framework import serializers
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
|
||||
class TimestampedModel(models.Model):
|
||||
|
@ -481,3 +483,166 @@ class TimeFieldTest(TestCase):
|
|||
self.assertEqual('04 - 00 [000000]', result_1)
|
||||
self.assertEqual('04 - 59 [000000]', result_2)
|
||||
self.assertEqual('04 - 59 [000200]', result_3)
|
||||
|
||||
|
||||
class DecimalFieldTest(TestCase):
|
||||
"""
|
||||
Tests for the DecimalField from_native() and to_native() behavior
|
||||
"""
|
||||
|
||||
def test_from_native_string(self):
|
||||
"""
|
||||
Make sure from_native() accepts string values
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result_1 = f.from_native('9000')
|
||||
result_2 = f.from_native('1.00000001')
|
||||
|
||||
self.assertEqual(Decimal('9000'), result_1)
|
||||
self.assertEqual(Decimal('1.00000001'), result_2)
|
||||
|
||||
def test_from_native_invalid_string(self):
|
||||
"""
|
||||
Make sure from_native() raises ValidationError on passing invalid string
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
|
||||
try:
|
||||
f.from_native('123.45.6')
|
||||
except validators.ValidationError as e:
|
||||
self.assertEqual(e.messages, ["Enter a number."])
|
||||
else:
|
||||
self.fail("ValidationError was not properly raised")
|
||||
|
||||
def test_from_native_integer(self):
|
||||
"""
|
||||
Make sure from_native() accepts integer values
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native(9000)
|
||||
|
||||
self.assertEqual(Decimal('9000'), result)
|
||||
|
||||
def test_from_native_float(self):
|
||||
"""
|
||||
Make sure from_native() accepts float values
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native(1.00000001)
|
||||
|
||||
self.assertEqual(Decimal('1.00000001'), result)
|
||||
|
||||
def test_from_native_empty(self):
|
||||
"""
|
||||
Make sure from_native() returns None on empty param.
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native('')
|
||||
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_from_native_none(self):
|
||||
"""
|
||||
Make sure from_native() returns None on None param.
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native(None)
|
||||
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_to_native(self):
|
||||
"""
|
||||
Make sure to_native() returns Decimal as string.
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
|
||||
result_1 = f.to_native(Decimal('9000'))
|
||||
result_2 = f.to_native(Decimal('1.00000001'))
|
||||
|
||||
self.assertEqual('9000', result_1)
|
||||
self.assertEqual('1.00000001', result_2)
|
||||
|
||||
def test_to_native_none(self):
|
||||
"""
|
||||
Make sure from_native() returns None on None param.
|
||||
"""
|
||||
f = serializers.DecimalField(required=False)
|
||||
self.assertEqual(None, f.to_native(None))
|
||||
|
||||
def test_valid_serialization(self):
|
||||
"""
|
||||
Make sure the serializer works correctly
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_value=9010,
|
||||
min_value=9000,
|
||||
max_digits=6,
|
||||
decimal_places=2)
|
||||
|
||||
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
|
||||
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
|
||||
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
|
||||
|
||||
self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
|
||||
self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
|
||||
self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
|
||||
|
||||
def test_raise_max_value(self):
|
||||
"""
|
||||
Make sure max_value violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_value=100)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '123'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is less than or equal to 100.']})
|
||||
|
||||
def test_raise_min_value(self):
|
||||
"""
|
||||
Make sure min_value violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(min_value=100)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '99'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is greater than or equal to 100.']})
|
||||
|
||||
def test_raise_max_digits(self):
|
||||
"""
|
||||
Make sure max_digits violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_digits=5)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '123.456'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 5 digits in total.']})
|
||||
|
||||
def test_raise_max_decimal_places(self):
|
||||
"""
|
||||
Make sure max_decimal_places violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(decimal_places=3)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '123.4567'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 3 decimal places.']})
|
||||
|
||||
def test_raise_max_whole_digits(self):
|
||||
"""
|
||||
Make sure max_whole_digits violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '12345.6'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 4 digits in total.']})
|
Loading…
Reference in New Issue
Block a user