diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 20436e6b4..be1cadb43 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -3,3 +3,4 @@ markdown==2.6.4 django-guardian==1.4.3 django-filter==0.13.0 coreapi==1.32.0 +psycopg2==2.6.2 diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 917a151e5..66f506a69 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1590,6 +1590,98 @@ class JSONField(Field): return value +class RangeField(Field): + initial = {} + default_error_messages = { + 'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".'), + 'unexpected_keys': _('Got unexpected keys "{unexpected}".'), + 'invalid_bounds': _('Bounds flags "{bounds}" not valid. Valid bounds are "{valid_bounds}".'), + 'empty': _('Range may not be empty.'), + } + + valid_bounds = ('[)', '(]', '()', '[]') + + def __init__(self, range_type, **kwargs): + self.child = kwargs.pop('child') + self.range_type = range_type + self.allow_empty = kwargs.pop('allow_empty', True) + + assert not inspect.isclass(self.child), '`child` has not been instantiated.' + assert self.child.source is None, ( + "The `source` argument is not meaningful when applied to a `child=` field. " + "Remove `source=` from the field declaration." + ) + + super(RangeField, self).__init__(**kwargs) + self.child.bind(field_name='', parent=self) + + def _valid_empty_range(self, data): + if not data.pop('empty', False): + return False + if not self.allow_empty: + self.fail('empty') + return True + + def _validate_bounds(self, data): + try: + bounds = data.pop('bounds') + except KeyError: + return + if bounds not in self.valid_bounds: + self.fail('invalid_bounds', bounds=bounds, valid_bounds=', '.join(self.valid_bounds)) + return bounds + + def _validate_ranges(self, data): + errors, validated_data = {}, {} + for key in ('lower', 'upper'): + try: + value = data.pop(key) + except KeyError: + continue + else: + try: + validated_data[key] = self.child.run_validation(value) + except ValidationError as e: + errors[key] = e.detail + + if errors: + raise ValidationError(errors) + + return validated_data + + def to_internal_value(self, data): + if isinstance(data, self.range_type): + return data + + if not isinstance(data, dict): + self.fail('not_a_dict', input_type=type(data).__name__) + + if self._valid_empty_range(data): + return self.range_type(empty=True) + + validated_data = self._validate_ranges(data) + bounds = self._validate_bounds(data) + if bounds: + validated_data['bounds'] = bounds + + if data: + self.fail('unexpected_keys', unexpected=', '.join(map(str, data.keys()))) + + return self.range_type(**validated_data) + + def to_representation(self, value): + if value is None: + return value + + lower = self.child.to_representation(value.lower) if value.lower is not None else None + upper = self.child.to_representation(value.upper) if value.upper is not None else None + + if value.isempty: + return {'empty': True} + + return {'lower': lower, 'upper': upper, 'bounds': value._bounds} + + # Miscellaneous field types... class ReadOnlyField(Field): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4d1ed63ae..387673a86 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1140,7 +1140,7 @@ class ModelSerializer(Serializer): # `allow_blank` is only valid for textual fields. field_kwargs.pop('allow_blank', None) - if postgres_fields and isinstance(model_field, postgres_fields.ArrayField): + if postgres_fields and isinstance(model_field, (postgres_fields.ArrayField, postgres_fields.RangeField)): # Populate the `child` argument on `ListField` instances generated # for the PostgrSQL specfic `ArrayField`. child_model_field = model_field.base_field @@ -1149,6 +1149,9 @@ class ModelSerializer(Serializer): ) field_kwargs['child'] = child_field_class(**child_field_kwargs) + if isinstance(model_field, postgres_fields.RangeField): + field_kwargs['range_type'] = model_field.range_type + return field_class, field_kwargs def build_relational_field(self, field_name, relation_info): @@ -1469,6 +1472,7 @@ if postgres_fields: ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField + ModelSerializer.serializer_field_mapping[postgres_fields.RangeField] = RangeField class HyperlinkedModelSerializer(ModelSerializer): diff --git a/tests/test_fields.py b/tests/test_fields.py index 4a4b741c5..3de679812 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -10,7 +10,7 @@ from django.test import TestCase, override_settings from django.utils import six, timezone import rest_framework -from rest_framework import serializers +from rest_framework import compat, serializers # Tests for field keyword arguments and core functionality. @@ -1715,6 +1715,37 @@ class TestBinaryJSONField(FieldValues): field = serializers.JSONField(binary=True) +if compat.postgres_fields: + from psycopg2.extras import DateTimeTZRange + + class TestRangeField(FieldValues): + """ + Values for `ListField` with no `child` argument. + """ + valid_inputs = [ + ({'lower': '2016-01-01T00:30:01', 'upper': '2016-01-01T01:00', 'bounds': '[]', 'empty': False}, + DateTimeTZRange(datetime.datetime(2016, 1, 1, 0, 30, 1), datetime.datetime(2016, 1, 1, 1, 0), '[]')), + ({'lower': '2016-01-01T00:30:01'}, DateTimeTZRange(datetime.datetime(2016, 1, 1, 0, 30, 1), None, '[)')), + ({'upper': '2016-01-01T00:00'}, DateTimeTZRange(None, datetime.datetime(2016, 1, 1, 0, 0, 0), '[)')), + ({'empty': True}, DateTimeTZRange(empty=True)), + (DateTimeTZRange(None, datetime.datetime(2016, 1, 1, 0, 0, 0), '[)'), DateTimeTZRange(None, datetime.datetime(2016, 1, 1, 0, 0, 0), '[)')) + ] + invalid_inputs = [ + ('not a dict', ['Expected a dictionary of items but got type "str".']), + (['not a dict'], ['Expected a dictionary of items but got type "list".']), + ({'lower': 0}, {'lower': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].']}), + ({'bounds': '['}, ['Bounds flags "[" not valid. Valid bounds are "[), (], (), []".']), + ({'unexpected': '[]'}, ['Got unexpected keys "unexpected".']), + ] + outputs = [ + (DateTimeTZRange(datetime.datetime(2016, 1, 1, 0, 30, 1), datetime.datetime(2016, 1, 1, 1, 0), '[]'), + {'lower': '2016-01-01T00:30:01', 'upper': '2016-01-01T01:00:00', 'bounds': '[]'}), + (DateTimeTZRange(empty=True), {'empty': True}), + (None, None), + ] + field = serializers.RangeField(range_type=DateTimeTZRange, child=serializers.DateTimeField()) + + # Tests for FieldField. # --------------------- diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 01243ff6e..a66265c33 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -19,7 +19,7 @@ from django.db.models import DurationField as ModelDurationField from django.test import TestCase from django.utils import six -from rest_framework import serializers +from rest_framework import compat, serializers from rest_framework.compat import unicode_repr @@ -382,6 +382,23 @@ class TestGenericIPAddressFieldValidation(TestCase): '{0}'.format(s.errors)) +if compat.postgres_fields: + class TestRangeFieldMapping(TestCase): + def test_date_range_field(self): + class DateRangeFieldModel(models.Model): + timestamps = compat.postgres_fields.DateTimeRangeField() + + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DateRangeFieldModel + fields = serializers.ALL_FIELDS + extra_kwargs = {'timestamps': {'allow_empty': False}} + + s = TestSerializer(data={'timestamps': {'empty': True}}) + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors['timestamps'], ['Range may not be empty.']) + + # Tests for relational field mappings. # ------------------------------------