mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 03:23:59 +03:00
NotImplemented stubs for Field, and DecimalField improvements
This commit is contained in:
parent
a751871991
commit
040bfcc09c
|
@ -229,13 +229,13 @@ class Field(object):
|
|||
"""
|
||||
Transform the *incoming* primative data into a native value.
|
||||
"""
|
||||
return data
|
||||
raise NotImplementedError('to_native() must be implemented.')
|
||||
|
||||
def to_primative(self, value):
|
||||
"""
|
||||
Transform the *outgoing* native value into primative data.
|
||||
"""
|
||||
return value
|
||||
raise NotImplementedError('to_primative() must be implemented.')
|
||||
|
||||
def fail(self, key, **kwargs):
|
||||
"""
|
||||
|
@ -429,9 +429,10 @@ class DecimalField(Field):
|
|||
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
|
||||
}
|
||||
|
||||
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
|
||||
self.max_value, self.min_value = max_value, min_value
|
||||
self.max_digits, self.max_decimal_places = max_digits, decimal_places
|
||||
def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, **kwargs):
|
||||
self.max_digits = max_digits
|
||||
self.decimal_places = decimal_places
|
||||
self.coerce_to_string = coerce_to_string
|
||||
super(DecimalField, self).__init__(**kwargs)
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
|
@ -478,12 +479,26 @@ class DecimalField(Field):
|
|||
if self.max_digits is not None and digits > self.max_digits:
|
||||
self.fail('max_digits', max_digits=self.max_digits)
|
||||
if self.decimal_places is not None and decimals > self.decimal_places:
|
||||
self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places)
|
||||
self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
|
||||
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
|
||||
self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
|
||||
|
||||
return value
|
||||
|
||||
def to_primative(self, value):
|
||||
if not self.coerce_to_string:
|
||||
return value
|
||||
|
||||
if isinstance(value, decimal.Decimal):
|
||||
context = decimal.getcontext().copy()
|
||||
context.prec = self.max_digits
|
||||
quantized = value.quantize(
|
||||
decimal.Decimal('.1') ** self.decimal_places,
|
||||
context=context
|
||||
)
|
||||
return '{0:f}'.format(quantized)
|
||||
return '%.*f' % (self.max_decimal_places, value)
|
||||
|
||||
|
||||
# Date & time fields...
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field):
|
|||
return replace_query_param(url, self.page_field, page)
|
||||
|
||||
|
||||
class DefaultObjectSerializer(serializers.Field):
|
||||
class DefaultObjectSerializer(serializers.ReadOnlyField):
|
||||
"""
|
||||
If no object serializer is specified, then this serializer will be applied
|
||||
as the default.
|
||||
|
@ -79,6 +79,6 @@ class PaginationSerializer(BasePaginationSerializer):
|
|||
"""
|
||||
A default implementation of a pagination serializer.
|
||||
"""
|
||||
count = serializers.Field(source='paginator.count')
|
||||
count = serializers.ReadOnlyField(source='paginator.count')
|
||||
next = NextPageField(source='*')
|
||||
previous = PreviousPageField(source='*')
|
||||
|
|
|
@ -43,7 +43,7 @@ class JSONEncoder(json.JSONEncoder):
|
|||
elif isinstance(o, datetime.timedelta):
|
||||
return str(o.total_seconds())
|
||||
elif isinstance(o, decimal.Decimal):
|
||||
return str(o)
|
||||
return float(o)
|
||||
elif isinstance(o, QuerySet):
|
||||
return list(o)
|
||||
elif hasattr(o, 'tolist'):
|
||||
|
|
|
@ -102,7 +102,7 @@ if django_filters:
|
|||
|
||||
class CommonFilteringTestCase(TestCase):
|
||||
def _serialize_object(self, obj):
|
||||
return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
||||
return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date}
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
|
@ -145,7 +145,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
|
|||
request = factory.get('/', {'decimal': '%s' % search_decimal})
|
||||
response = view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
|
||||
expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
|
||||
self.assertEqual(response.data, expected_data)
|
||||
|
||||
# Tests that the date filter works.
|
||||
|
@ -168,7 +168,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
|
|||
request = factory.get('/', {'decimal': '%s' % search_decimal})
|
||||
response = view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
|
||||
expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
|
||||
self.assertEqual(response.data, expected_data)
|
||||
|
||||
@unittest.skipUnless(django_filters, 'django-filter not installed')
|
||||
|
@ -201,7 +201,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
|
|||
request = factory.get('/', {'decimal': '%s' % search_decimal})
|
||||
response = view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
expected_data = [f for f in self.data if f['decimal'] < search_decimal]
|
||||
expected_data = [f for f in self.data if Decimal(f['decimal']) < search_decimal]
|
||||
self.assertEqual(response.data, expected_data)
|
||||
|
||||
# Tests that the date filter set with 'gt' in the filter class works.
|
||||
|
@ -230,7 +230,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
|
|||
response = view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
expected_data = [f for f in self.data if f['date'] > search_date and
|
||||
f['decimal'] < search_decimal]
|
||||
Decimal(f['decimal']) < search_decimal]
|
||||
self.assertEqual(response.data, expected_data)
|
||||
|
||||
@unittest.skipUnless(django_filters, 'django-filter not installed')
|
||||
|
|
|
@ -135,7 +135,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
|
|||
|
||||
self.objects = FilterableItem.objects
|
||||
self.data = [
|
||||
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
||||
{'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date}
|
||||
for obj in self.objects.all()
|
||||
]
|
||||
|
||||
|
@ -381,7 +381,7 @@ class TestMaxPaginateByParam(TestCase):
|
|||
|
||||
# Tests for context in pagination serializers
|
||||
|
||||
class CustomField(serializers.Field):
|
||||
class CustomField(serializers.ReadOnlyField):
|
||||
def to_native(self, value):
|
||||
if 'view' not in self.context:
|
||||
raise RuntimeError("context isn't getting passed into custom field")
|
||||
|
|
Loading…
Reference in New Issue
Block a user