mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-23 15:54:16 +03:00
DecimalField
This commit is contained in:
commit
d985aec3c9
|
@ -248,6 +248,12 @@ A floating point representation.
|
|||
|
||||
Corresponds to `django.db.models.fields.FloatField`.
|
||||
|
||||
## DecimalField
|
||||
|
||||
A decimal representation.
|
||||
|
||||
Corresponds to `django.db.models.fields.DecimalField`.
|
||||
|
||||
## FileField
|
||||
|
||||
A file representation. Performs Django's standard FileField validation.
|
||||
|
|
|
@ -118,6 +118,12 @@ And would have the following entry in the urlconf:
|
|||
|
||||
Usage of the old-style attributes continues to be supported, but will raise a `PendingDeprecationWarning`.
|
||||
|
||||
## DecimalField
|
||||
|
||||
2.3 introduces a `DecimalField` serializer field, which returns `Decimal` instances.
|
||||
|
||||
For most cases APIs using model fields will behave as previously, however if you are using a custom renderer, not provided by REST framework, then you may now need to add support for rendering `Decimal` instances to your renderer implmentation.
|
||||
|
||||
---
|
||||
|
||||
# Other notes
|
||||
|
|
|
@ -38,6 +38,20 @@ You can determine your currently installed version using `pip freeze`:
|
|||
|
||||
---
|
||||
|
||||
## 2.3.x series
|
||||
|
||||
### 2.3.0
|
||||
|
||||
* ViewSets and Routers.
|
||||
* ModelSerializers support reverse relations in 'fields' option.
|
||||
* HyperLinkedModelSerializers support 'id' field in 'fields' option.
|
||||
* Cleaner generic views.
|
||||
* DecimalField support.
|
||||
|
||||
**Note**: See the [2.3 announcement][2.3-announcement] for full details.
|
||||
|
||||
---
|
||||
|
||||
## 2.2.x series
|
||||
|
||||
### 2.2.7
|
||||
|
@ -458,6 +472,7 @@ This change will not affect user code, so long as it's following the recommended
|
|||
[django-deprecation-policy]: https://docs.djangoproject.com/en/dev/internals/release-process/#internal-release-deprecation-policy
|
||||
[defusedxml-announce]: http://blog.python.org/2013/02/announcing-defusedxml-fixes-for-xml.html
|
||||
[2.2-announcement]: 2.2-announcement.md
|
||||
[2.3-announcement]: 2.3-announcement.md
|
||||
[743]: https://github.com/tomchristie/django-rest-framework/pull/743
|
||||
[staticfiles14]: https://docs.djangoproject.com/en/1.4/howto/static-files/#with-a-template-tag
|
||||
[staticfiles13]: https://docs.djangoproject.com/en/1.3/howto/static-files/#with-a-template-tag
|
||||
|
|
|
@ -7,6 +7,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import copy
|
||||
import datetime
|
||||
from decimal import Decimal, DecimalException
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
|
@ -726,6 +727,75 @@ class FloatField(WritableField):
|
|||
raise ValidationError(msg)
|
||||
|
||||
|
||||
class DecimalField(WritableField):
|
||||
type_name = 'DecimalField'
|
||||
form_field_class = forms.DecimalField
|
||||
|
||||
default_error_messages = {
|
||||
'invalid': _('Enter a number.'),
|
||||
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
|
||||
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
|
||||
'max_digits': _('Ensure that there are no more than %s digits in total.'),
|
||||
'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
|
||||
'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
|
||||
}
|
||||
|
||||
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
|
||||
self.max_value, self.min_value = max_value, min_value
|
||||
self.max_digits, self.decimal_places = max_digits, decimal_places
|
||||
super(DecimalField, self).__init__(*args, **kwargs)
|
||||
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
if min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(min_value))
|
||||
|
||||
def from_native(self, value):
|
||||
"""
|
||||
Validates that the input is a decimal number. Returns a Decimal
|
||||
instance. Returns None for empty values. Ensures that there are no more
|
||||
than max_digits in the number, and no more than decimal_places digits
|
||||
after the decimal point.
|
||||
"""
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return None
|
||||
value = smart_text(value).strip()
|
||||
try:
|
||||
value = Decimal(value)
|
||||
except DecimalException:
|
||||
raise ValidationError(self.error_messages['invalid'])
|
||||
return value
|
||||
|
||||
def validate(self, value):
|
||||
super(DecimalField, self).validate(value)
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return
|
||||
# Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
|
||||
# since it is never equal to itself. However, NaN is the only value that
|
||||
# isn't equal to itself, so we can use this to identify NaN
|
||||
if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
|
||||
raise ValidationError(self.error_messages['invalid'])
|
||||
sign, digittuple, exponent = value.as_tuple()
|
||||
decimals = abs(exponent)
|
||||
# digittuple doesn't include any leading zeros.
|
||||
digits = len(digittuple)
|
||||
if decimals > digits:
|
||||
# We have leading zeros up to or past the decimal point. Count
|
||||
# everything past the decimal point as a digit. We do not count
|
||||
# 0 before the decimal point as a digit since that would mean
|
||||
# we would not allow max_digits = decimal_places.
|
||||
digits = decimals
|
||||
whole_digits = digits - decimals
|
||||
|
||||
if self.max_digits is not None and digits > self.max_digits:
|
||||
raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
|
||||
if self.decimal_places is not None and decimals > self.decimal_places:
|
||||
raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
|
||||
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
|
||||
raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
|
||||
return value
|
||||
|
||||
|
||||
class FileField(WritableField):
|
||||
use_files = True
|
||||
type_name = 'FileField'
|
||||
|
|
|
@ -560,6 +560,7 @@ class ModelSerializer(Serializer):
|
|||
models.DateTimeField: DateTimeField,
|
||||
models.DateField: DateField,
|
||||
models.TimeField: TimeField,
|
||||
models.DecimalField: DecimalField,
|
||||
models.EmailField: EmailField,
|
||||
models.CharField: CharField,
|
||||
models.URLField: URLField,
|
||||
|
|
|
@ -3,12 +3,14 @@ General serializer field tests.
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
from django.core import validators
|
||||
|
||||
from rest_framework import serializers
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
|
||||
class TimestampedModel(models.Model):
|
||||
|
@ -481,3 +483,166 @@ class TimeFieldTest(TestCase):
|
|||
self.assertEqual('04 - 00 [000000]', result_1)
|
||||
self.assertEqual('04 - 59 [000000]', result_2)
|
||||
self.assertEqual('04 - 59 [000200]', result_3)
|
||||
|
||||
|
||||
class DecimalFieldTest(TestCase):
|
||||
"""
|
||||
Tests for the DecimalField from_native() and to_native() behavior
|
||||
"""
|
||||
|
||||
def test_from_native_string(self):
|
||||
"""
|
||||
Make sure from_native() accepts string values
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result_1 = f.from_native('9000')
|
||||
result_2 = f.from_native('1.00000001')
|
||||
|
||||
self.assertEqual(Decimal('9000'), result_1)
|
||||
self.assertEqual(Decimal('1.00000001'), result_2)
|
||||
|
||||
def test_from_native_invalid_string(self):
|
||||
"""
|
||||
Make sure from_native() raises ValidationError on passing invalid string
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
|
||||
try:
|
||||
f.from_native('123.45.6')
|
||||
except validators.ValidationError as e:
|
||||
self.assertEqual(e.messages, ["Enter a number."])
|
||||
else:
|
||||
self.fail("ValidationError was not properly raised")
|
||||
|
||||
def test_from_native_integer(self):
|
||||
"""
|
||||
Make sure from_native() accepts integer values
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native(9000)
|
||||
|
||||
self.assertEqual(Decimal('9000'), result)
|
||||
|
||||
def test_from_native_float(self):
|
||||
"""
|
||||
Make sure from_native() accepts float values
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native(1.00000001)
|
||||
|
||||
self.assertEqual(Decimal('1.00000001'), result)
|
||||
|
||||
def test_from_native_empty(self):
|
||||
"""
|
||||
Make sure from_native() returns None on empty param.
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native('')
|
||||
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_from_native_none(self):
|
||||
"""
|
||||
Make sure from_native() returns None on None param.
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
result = f.from_native(None)
|
||||
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_to_native(self):
|
||||
"""
|
||||
Make sure to_native() returns Decimal as string.
|
||||
"""
|
||||
f = serializers.DecimalField()
|
||||
|
||||
result_1 = f.to_native(Decimal('9000'))
|
||||
result_2 = f.to_native(Decimal('1.00000001'))
|
||||
|
||||
self.assertEqual(Decimal('9000'), result_1)
|
||||
self.assertEqual(Decimal('1.00000001'), result_2)
|
||||
|
||||
def test_to_native_none(self):
|
||||
"""
|
||||
Make sure from_native() returns None on None param.
|
||||
"""
|
||||
f = serializers.DecimalField(required=False)
|
||||
self.assertEqual(None, f.to_native(None))
|
||||
|
||||
def test_valid_serialization(self):
|
||||
"""
|
||||
Make sure the serializer works correctly
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_value=9010,
|
||||
min_value=9000,
|
||||
max_digits=6,
|
||||
decimal_places=2)
|
||||
|
||||
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
|
||||
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
|
||||
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
|
||||
|
||||
self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
|
||||
self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
|
||||
self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
|
||||
|
||||
def test_raise_max_value(self):
|
||||
"""
|
||||
Make sure max_value violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_value=100)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '123'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
|
||||
|
||||
def test_raise_min_value(self):
|
||||
"""
|
||||
Make sure min_value violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(min_value=100)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '99'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
|
||||
|
||||
def test_raise_max_digits(self):
|
||||
"""
|
||||
Make sure max_digits violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_digits=5)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '123.456'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
|
||||
|
||||
def test_raise_max_decimal_places(self):
|
||||
"""
|
||||
Make sure max_decimal_places violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(decimal_places=3)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '123.4567'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
|
||||
|
||||
def test_raise_max_whole_digits(self):
|
||||
"""
|
||||
Make sure max_whole_digits violations raises ValidationError
|
||||
"""
|
||||
class DecimalSerializer(Serializer):
|
||||
decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
|
||||
|
||||
s = DecimalSerializer(data={'decimal_field': '12345.6'})
|
||||
|
||||
self.assertFalse(s.is_valid())
|
||||
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
|
Loading…
Reference in New Issue
Block a user