diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index 5cb096f1c..8d25d6c78 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -360,7 +360,10 @@ Corresponds to `django.db.models.fields.DurationField` The `validated_data` for these fields will contain a `datetime.timedelta` instance. The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`. -**Signature:** `DurationField()` +**Signature:** `DurationField(max_value=None, min_value=None)` + +- `max_value` Validate that the duration provided is no greater than this value. +- `min_value` Validate that the duration provided is no less than this value. --- diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c13279675..6b32fd496 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -903,20 +903,16 @@ class IPAddressField(CharField): # Number types... -class IntegerField(Field): +class MaxMinMixin(object): default_error_messages = { - 'invalid': _('A valid integer is required.'), 'max_value': _('Ensure this value is less than or equal to {max_value}.'), 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), - 'max_string_length': _('String value too large.') } - MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2' def __init__(self, **kwargs): self.max_value = kwargs.pop('max_value', None) self.min_value = kwargs.pop('min_value', None) - super(IntegerField, self).__init__(**kwargs) + super(MaxMinMixin, self).__init__(**kwargs) if self.max_value is not None: message = lazy( self.error_messages['max_value'].format, @@ -930,6 +926,15 @@ class IntegerField(Field): self.validators.append( MinValueValidator(self.min_value, message=message)) + +class IntegerField(MaxMinMixin, Field): + default_error_messages = { + 'invalid': _('A valid integer is required.'), + 'max_string_length': _('String value too large.') + } + MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. + re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2' + def to_internal_value(self, data): if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: self.fail('max_string_length') @@ -944,32 +949,13 @@ class IntegerField(Field): return int(value) -class FloatField(Field): +class FloatField(MaxMinMixin, Field): default_error_messages = { 'invalid': _('A valid number is required.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), 'max_string_length': _('String value too large.') } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - def __init__(self, **kwargs): - self.max_value = kwargs.pop('max_value', None) - self.min_value = kwargs.pop('min_value', None) - super(FloatField, self).__init__(**kwargs) - if self.max_value is not None: - message = lazy( - self.error_messages['max_value'].format, - six.text_type)(max_value=self.max_value) - self.validators.append( - MaxValueValidator(self.max_value, message=message)) - if self.min_value is not None: - message = lazy( - self.error_messages['min_value'].format, - six.text_type)(min_value=self.min_value) - self.validators.append( - MinValueValidator(self.min_value, message=message)) - def to_internal_value(self, data): if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: @@ -984,11 +970,9 @@ class FloatField(Field): return float(value) -class DecimalField(Field): +class DecimalField(MaxMinMixin, Field): default_error_messages = { 'invalid': _('A valid number is required.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), @@ -1006,28 +990,12 @@ class DecimalField(Field): if self.localize: self.coerce_to_string = True - self.max_value = max_value - self.min_value = min_value - if self.max_digits is not None and self.decimal_places is not None: self.max_whole_digits = self.max_digits - self.decimal_places else: self.max_whole_digits = None - super(DecimalField, self).__init__(**kwargs) - - if self.max_value is not None: - message = lazy( - self.error_messages['max_value'].format, - six.text_type)(max_value=self.max_value) - self.validators.append( - MaxValueValidator(self.max_value, message=message)) - if self.min_value is not None: - message = lazy( - self.error_messages['min_value'].format, - six.text_type)(min_value=self.min_value) - self.validators.append( - MinValueValidator(self.min_value, message=message)) + super(DecimalField, self).__init__(max_value=max_value, min_value=min_value, **kwargs) if rounding is not None: valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] @@ -1351,7 +1319,7 @@ class TimeField(Field): return value.strftime(output_format) -class DurationField(Field): +class DurationField(MaxMinMixin, Field): default_error_messages = { 'invalid': _('Duration has wrong format. Use one of these formats instead: {format}.'), } diff --git a/tests/test_fields.py b/tests/test_fields.py index 0ee49e9c1..1f099e0e3 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1459,6 +1459,23 @@ class TestNoOutputFormatTimeField(FieldValues): field = serializers.TimeField(format=None) +class TestMinMaxDurationField(FieldValues): + """ + Valid and invalid values for `IntegerField` with min and max limits. + """ + valid_inputs = { + '3 08:32:01.000123': datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123), + 86401: datetime.timedelta(days=1, seconds=1), + } + invalid_inputs = { + 3600: ['Ensure this value is greater than or equal to 1 day, 0:00:00.'], + '4 08:32:01.000123': ['Ensure this value is less than or equal to 4 days, 0:00:00.'], + '3600': ['Ensure this value is greater than or equal to 1 day, 0:00:00.'], + } + outputs = {} + field = serializers.DurationField(min_value=datetime.timedelta(days=1), max_value=datetime.timedelta(days=4)) + + class TestDurationField(FieldValues): """ Valid and invalid values for `DurationField`.