diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 8b782a1c2..4a48a383a 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -60,8 +60,10 @@ def distinct(queryset, base): # contrib.postgres only supported from 1.8 onwards. try: from django.contrib.postgres import fields as postgres_fields + from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange except ImportError: - postgres_fields = None + postgres_fields = DateRange = DateTimeTZRange = NumericRange = None + # JSONField is only supported from 1.9 onwards diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 0b214b872..13a905e04 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -27,9 +27,9 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 from rest_framework.compat import ( - MaxLengthValidator, MaxValueValidator, MinLengthValidator, - MinValueValidator, duration_string, parse_duration, unicode_repr, - unicode_to_repr + DateRange, DateTimeTZRange, MaxLengthValidator, MaxValueValidator, + MinLengthValidator, MinValueValidator, NumericRange, duration_string, + parse_duration, unicode_repr, unicode_to_repr ) from rest_framework.exceptions import ValidationError from rest_framework.settings import api_settings @@ -1523,6 +1523,73 @@ class DictField(Field): } +class RangeField(DictField): + + range_type = None + + default_error_messages = { + 'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".'), + 'too_much_content': _('Extra content not allowed "{extra}".'), + } + + def to_internal_value(self, data): + """ + Range instances <- Dicts of primitive datatypes. + """ + if html.is_html_input(data): + data = html.parse_html_dict(data) + if not isinstance(data, dict): + self.fail('not_a_dict', input_type=type(data).__name__) + validated_dict = {} + for key in ('lower', 'upper'): + try: + value = data.pop(key) + except KeyError: + continue + validated_dict[six.text_type(key)] = self.child.run_validation(value) + for key in ('bounds', 'empty'): + try: + value = data.pop(key) + except KeyError: + continue + validated_dict[six.text_type(key)] = value + if data: + self.fail('too_much_content', extra=', '.join(map(str, data.keys()))) + return self.range_type(**validated_dict) + + def to_representation(self, value): + """ + Range instances -> dicts of primitive datatypes. + """ + if value.isempty: + return {'empty': True} + lower = self.child.to_representation(value.lower) if value.lower is not None else None + upper = self.child.to_representation(value.upper) if value.upper is not None else None + return {'lower': lower, + 'upper': upper, + 'bounds': value._bounds} + + +class IntegerRangeField(RangeField): + child = IntegerField() + range_type = NumericRange + + +class FloatRangeField(RangeField): + child = FloatField() + range_type = NumericRange + + +class DateTimeRangeField(RangeField): + child = DateTimeField() + range_type = DateTimeTZRange + + +class DateRangeField(RangeField): + child = DateField() + range_type = DateRange + + class JSONField(Field): default_error_messages = { 'invalid': _('Value must be valid JSON.') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 99d36a8a5..3b9ba9a0d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1433,6 +1433,10 @@ if postgres_fields: ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField + ModelSerializer.serializer_field_mapping[postgres_fields.DateTimeRangeField] = DateTimeRangeField + ModelSerializer.serializer_field_mapping[postgres_fields.DateRangeField] = DateRangeField + ModelSerializer.serializer_field_mapping[postgres_fields.IntegerRangeField] = IntegerRangeField + ModelSerializer.serializer_field_mapping[postgres_fields.FloatRangeField] = FloatRangeField class HyperlinkedModelSerializer(ModelSerializer): diff --git a/tests/test_fields.py b/tests/test_fields.py index 104337627..9bb40ed30 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -6,10 +6,11 @@ from decimal import Decimal import django import pytest from django.http import QueryDict +from django.test import TestCase, override_settings from django.utils import timezone import rest_framework -from rest_framework import serializers +from rest_framework import compat, serializers # Tests for field keyword arguments and core functionality. @@ -1525,6 +1526,213 @@ class TestUnvalidatedDictField(FieldValues): field = serializers.DictField() +@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None, + reason='RangeField is only available for django1.8+' + ' and with psycopg2.') +class TestIntegerRangeField(FieldValues): + """ + Values for `ListField` with CharField as child. + """ + if compat.NumericRange is not None: + valid_inputs = [ + ({'lower': '1', 'upper': 2, 'bounds': '[)'}, + compat.NumericRange(**{'lower': 1, 'upper': 2, 'bounds': '[)'})), + ({'lower': 1, 'upper': 2}, + compat.NumericRange(**{'lower': 1, 'upper': 2})), + ({'lower': 1}, + compat.NumericRange(**{'lower': 1})), + ({'upper': 1}, + compat.NumericRange(**{'upper': 1})), + ({'empty': True}, + compat.NumericRange(**{'empty': True})), + ({}, compat.NumericRange()), + ] + invalid_inputs = [ + ({'lower': 'a'}, ['A valid integer is required.']), + ('not a dict', ['Expected a dictionary of items but got type "str".']), + ({'foo': 'bar'}, ['Extra content not allowed "foo".']), + ] + outputs = [ + (compat.NumericRange(**{'lower': '1', 'upper': '2'}), + {'lower': 1, 'upper': 2, 'bounds': '[)'}), + (compat.NumericRange(**{'empty': True}), {'empty': True}), + (compat.NumericRange(), {'bounds': '[)', 'lower': None, 'upper': None}), + ] + field = serializers.IntegerRangeField() + + def test_no_source_on_child(self): + with pytest.raises(AssertionError) as exc_info: + serializers.IntegerRangeField(child=serializers.IntegerField(source='other')) + + assert str(exc_info.value) == ( + "The `source` argument is not meaningful when applied to a `child=` field. " + "Remove `source=` from the field declaration." + ) + + +@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None, + reason='RangeField is only available for django1.8+' + ' and with psycopg2.') +class TestFloatRangeField(FieldValues): + """ + Values for `ListField` with CharField as child. + """ + if compat.NumericRange is not None: + valid_inputs = [ + ({'lower': '1', 'upper': 2., 'bounds': '[)'}, + compat.NumericRange(**{'lower': 1., 'upper': 2., 'bounds': '[)'})), + ({'lower': 1., 'upper': 2.}, + compat.NumericRange(**{'lower': 1, 'upper': 2})), + ({'lower': 1}, + compat.NumericRange(**{'lower': 1})), + ({'upper': 1}, + compat.NumericRange(**{'upper': 1})), + ({'empty': True}, + compat.NumericRange(**{'empty': True})), + ({}, compat.NumericRange()), + ] + invalid_inputs = [ + ({'lower': 'a'}, ['A valid number is required.']), + ('not a dict', ['Expected a dictionary of items but got type "str".']), + ] + outputs = [ + (compat.NumericRange(**{'lower': '1.1', 'upper': '2'}), + {'lower': 1.1, 'upper': 2, 'bounds': '[)'}), + (compat.NumericRange(**{'empty': True}), {'empty': True}), + (compat.NumericRange(), {'bounds': '[)', 'lower': None, 'upper': None}), + ] + field = serializers.FloatRangeField() + + def test_no_source_on_child(self): + with pytest.raises(AssertionError) as exc_info: + serializers.FloatRangeField(child=serializers.IntegerField(source='other')) + + assert str(exc_info.value) == ( + "The `source` argument is not meaningful when applied to a `child=` field. " + "Remove `source=` from the field declaration." + ) + + +@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None, + reason='RangeField is only available for django1.8+' + ' and with psycopg2.') +@override_settings(USE_TZ=True) +class TestDateTimeRangeField(TestCase, FieldValues): + """ + Values for `ListField` with CharField as child. + """ + if compat.DateTimeTZRange is not None: + valid_inputs = [ + ({'lower': '2001-01-01T13:00:00Z', + 'upper': '2001-02-02T13:00:00Z', + 'bounds': '[)'}, + compat.DateTimeTZRange( + **{'lower': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + 'upper': datetime.datetime(2001, 2, 2, 13, 00, tzinfo=timezone.UTC()), + 'bounds': '[)'})), + ({'upper': '2001-02-02T13:00:00Z', + 'bounds': '[)'}, + compat.DateTimeTZRange( + **{'upper': datetime.datetime(2001, 2, 2, 13, 00, tzinfo=timezone.UTC()), + 'bounds': '[)'})), + ({'lower': '2001-01-01T13:00:00Z', + 'bounds': '[)'}, + compat.DateTimeTZRange( + **{'lower': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + 'bounds': '[)'})), + ({'empty': True}, + compat.DateTimeTZRange(**{'empty': True})), + ({}, compat.DateTimeTZRange()), + ] + invalid_inputs = [ + ({'lower': 'a'}, ['Datetime has wrong format. Use one of these' + ' formats instead: ' + 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].']), + ('not a dict', ['Expected a dictionary of items but got type "str".']), + ] + outputs = [ + (compat.DateTimeTZRange( + **{'lower': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + 'upper': datetime.datetime(2001, 2, 2, 13, 00, tzinfo=timezone.UTC())}), + {'lower': '2001-01-01T13:00:00Z', + 'upper': '2001-02-02T13:00:00Z', + 'bounds': '[)'}), + (compat.DateTimeTZRange(**{'empty': True}), + {'empty': True}), + (compat.DateTimeTZRange(), + {'bounds': '[)', 'lower': None, 'upper': None}), + ] + field = serializers.DateTimeRangeField() + + def test_no_source_on_child(self): + with pytest.raises(AssertionError) as exc_info: + serializers.DateTimeRangeField(child=serializers.IntegerField(source='other')) + + assert str(exc_info.value) == ( + "The `source` argument is not meaningful when applied to a `child=` field. " + "Remove `source=` from the field declaration." + ) + + +@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None, + reason='RangeField is only available for django1.8+' + ' and with psycopg2.') +class TestDateRangeField(FieldValues): + """ + Values for `ListField` with CharField as child. + """ + if compat.DateRange is not None: + valid_inputs = [ + ({'lower': '2001-01-01', + 'upper': '2001-02-02', + 'bounds': '[)'}, + compat.DateRange( + **{'lower': datetime.date(2001, 1, 1), + 'upper': datetime.date(2001, 2, 2), + 'bounds': '[)'})), + ({'upper': '2001-02-02', + 'bounds': '[)'}, + compat.DateRange( + **{'upper': datetime.date(2001, 2, 2), + 'bounds': '[)'})), + ({'lower': '2001-01-01', + 'bounds': '[)'}, + compat.DateRange( + **{'lower': datetime.date(2001, 1, 1), + 'bounds': '[)'})), + ({'empty': True}, + compat.DateRange(**{'empty': True})), + ({}, compat.DateRange()), + ] + invalid_inputs = [ + ({'lower': 'a'}, ['Date has wrong format. Use one of these' + ' formats instead: ' + 'YYYY[-MM[-DD]].']), + ('not a dict', ['Expected a dictionary of items but got type "str".']), + ] + outputs = [ + (compat.DateRange( + **{'lower': datetime.date(2001, 1, 1), + 'upper': datetime.date(2001, 2, 2)}), + {'lower': '2001-01-01', + 'upper': '2001-02-02', + 'bounds': '[)'}), + (compat.DateRange(**{'empty': True}), + {'empty': True}), + (compat.DateRange(), {'bounds': '[)', 'lower': None, 'upper': None}), + ] + field = serializers.DateRangeField() + + def test_no_source_on_child(self): + with pytest.raises(AssertionError) as exc_info: + serializers.DateRangeField(child=serializers.IntegerField(source='other')) + + assert str(exc_info.value) == ( + "The `source` argument is not meaningful when applied to a `child=` field. " + "Remove `source=` from the field declaration." + ) + + class TestJSONField(FieldValues): """ Values for `JSONField`.