From 040bfcc09c851bb3dadd60558c78a1f7937e9fbd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 21:48:54 +0100 Subject: [PATCH] NotImplemented stubs for Field, and DecimalField improvements --- rest_framework/fields.py | 27 +++++++++++++++++++++------ rest_framework/pagination.py | 4 ++-- rest_framework/utils/encoders.py | 2 +- tests/test_filters.py | 10 +++++----- tests/test_pagination.py | 4 ++-- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7496a629e..20b8ffbff 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -229,13 +229,13 @@ class Field(object): """ Transform the *incoming* primative data into a native value. """ - return data + raise NotImplementedError('to_native() must be implemented.') def to_primative(self, value): """ Transform the *outgoing* native value into primative data. """ - return value + raise NotImplementedError('to_primative() must be implemented.') def fail(self, key, **kwargs): """ @@ -429,9 +429,10 @@ class DecimalField(Field): 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') } - def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs): - self.max_value, self.min_value = max_value, min_value - self.max_digits, self.max_decimal_places = max_digits, decimal_places + def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.coerce_to_string = coerce_to_string super(DecimalField, self).__init__(**kwargs) if max_value is not None: self.validators.append(validators.MaxValueValidator(max_value)) @@ -478,12 +479,26 @@ class DecimalField(Field): if self.max_digits is not None and digits > self.max_digits: self.fail('max_digits', max_digits=self.max_digits) if self.decimal_places is not None and decimals > self.decimal_places: - self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places) + self.fail('max_decimal_places', max_decimal_places=self.decimal_places) if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) return value + def to_primative(self, value): + if not self.coerce_to_string: + return value + + if isinstance(value, decimal.Decimal): + context = decimal.getcontext().copy() + context.prec = self.max_digits + quantized = value.quantize( + decimal.Decimal('.1') ** self.decimal_places, + context=context + ) + return '{0:f}'.format(quantized) + return '%.*f' % (self.max_decimal_places, value) + # Date & time fields... diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 9cf31629f..d82d2d3b3 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field): return replace_query_param(url, self.page_field, page) -class DefaultObjectSerializer(serializers.Field): +class DefaultObjectSerializer(serializers.ReadOnlyField): """ If no object serializer is specified, then this serializer will be applied as the default. @@ -79,6 +79,6 @@ class PaginationSerializer(BasePaginationSerializer): """ A default implementation of a pagination serializer. """ - count = serializers.Field(source='paginator.count') + count = serializers.ReadOnlyField(source='paginator.count') next = NextPageField(source='*') previous = PreviousPageField(source='*') diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 6a2f61266..7992b6b1c 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -43,7 +43,7 @@ class JSONEncoder(json.JSONEncoder): elif isinstance(o, datetime.timedelta): return str(o.total_seconds()) elif isinstance(o, decimal.Decimal): - return str(o) + return float(o) elif isinstance(o, QuerySet): return list(o) elif hasattr(o, 'tolist'): diff --git a/tests/test_filters.py b/tests/test_filters.py index 300e47e45..01668114b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -102,7 +102,7 @@ if django_filters: class CommonFilteringTestCase(TestCase): def _serialize_object(self, obj): - return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date} def setUp(self): """ @@ -145,7 +145,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] + expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal] self.assertEqual(response.data, expected_data) # Tests that the date filter works. @@ -168,7 +168,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] + expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal] self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filter not installed') @@ -201,7 +201,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] < search_decimal] + expected_data = [f for f in self.data if Decimal(f['decimal']) < search_decimal] self.assertEqual(response.data, expected_data) # Tests that the date filter set with 'gt' in the filter class works. @@ -230,7 +230,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal] + Decimal(f['decimal']) < search_decimal] self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filter not installed') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 68983ba20..a7f8e691f 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -135,7 +135,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date} for obj in self.objects.all() ] @@ -381,7 +381,7 @@ class TestMaxPaginateByParam(TestCase): # Tests for context in pagination serializers -class CustomField(serializers.Field): +class CustomField(serializers.ReadOnlyField): def to_native(self, value): if 'view' not in self.context: raise RuntimeError("context isn't getting passed into custom field")