diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 71a9f1938..670a21251 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -155,15 +155,41 @@ class Field(object): } default_validators = [] default_empty_html = empty - initial = None - def __init__(self, read_only=False, write_only=False, - required=None, default=empty, initial=empty, source=None, - label=None, help_text=None, style=None, - error_messages=None, validators=None, allow_null=False): + # allows subclasses to change defaults + read_only = False + write_only = False + required = None + default = empty + initial = None + source = None + label = None + help_text = None + style = None + error_messages = None + validators = None + allow_null = False + + def __init__(self, **kwargs): self._creation_counter = Field._creation_counter Field._creation_counter += 1 + # precedence is given to kwargs over class attributes. + # you can create a subclass of a Field for maximum reuse while still + # needing a one-off instance where you need to change an attribute. + read_only = kwargs.pop('read_only', self.read_only) + write_only = kwargs.pop('write_only', self.write_only) + required = kwargs.pop('required', self.required) + default = kwargs.pop('default', self.default) + initial = kwargs.pop('initial', self.initial) + source = kwargs.pop('source', self.source) + label = kwargs.pop('label', self.label) + help_text = kwargs.pop('help_text', self.help_text) + style = kwargs.pop('style', self.style) + error_messages = kwargs.pop('error_messages', self.error_messages) + validators = kwargs.pop('validators', self.validators) + allow_null = kwargs.pop('allow_null', self.allow_null) + # If `required` is unset, then use `True` unless a default is provided. if required is None: required = default is empty and not read_only @@ -179,7 +205,7 @@ class Field(object): self.required = required self.default = default self.source = source - self.initial = self.initial if (initial is empty) else initial + self.initial = initial self.label = label self.help_text = help_text self.style = {} if style is None else style @@ -552,18 +578,21 @@ class CharField(Field): 'min_length': _('Ensure this field has at least {min_length} characters.') } initial = '' + allow_blank = False + max_length = None + min_length = None def __init__(self, **kwargs): - self.allow_blank = kwargs.pop('allow_blank', False) - max_length = kwargs.pop('max_length', None) - min_length = kwargs.pop('min_length', None) + self.allow_blank = kwargs.pop('allow_blank', self.allow_blank) + self.max_length = kwargs.pop('max_length', self.max_length) + self.min_length = kwargs.pop('min_length', self.min_length) super(CharField, self).__init__(**kwargs) - if max_length is not None: - message = self.error_messages['max_length'].format(max_length=max_length) - self.validators.append(MaxLengthValidator(max_length, message=message)) - if min_length is not None: - message = self.error_messages['min_length'].format(min_length=min_length) - self.validators.append(MinLengthValidator(min_length, message=message)) + if self.max_length is not None: + message = self.error_messages['max_length'].format(max_length=self.max_length) + self.validators.append(MaxLengthValidator(self.max_length, message=message)) + if self.min_length is not None: + message = self.error_messages['min_length'].format(min_length=self.min_length) + self.validators.append(MinLengthValidator(self.min_length, message=message)) def run_validation(self, data=empty): # Test for the empty string here so that it does not get validated, @@ -660,17 +689,19 @@ class IntegerField(Field): 'max_string_length': _('String value too large') } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. + max_value = None + min_value = None def __init__(self, **kwargs): - max_value = kwargs.pop('max_value', None) - min_value = kwargs.pop('min_value', None) + self.max_value = kwargs.pop('max_value', self.max_value) + self.min_value = kwargs.pop('min_value', self.min_value) super(IntegerField, self).__init__(**kwargs) - if max_value is not None: - message = self.error_messages['max_value'].format(max_value=max_value) - self.validators.append(MaxValueValidator(max_value, message=message)) - if min_value is not None: - message = self.error_messages['min_value'].format(min_value=min_value) - self.validators.append(MinValueValidator(min_value, message=message)) + if self.max_value is not None: + message = self.error_messages['max_value'].format(max_value=self.max_value) + self.validators.append(MaxValueValidator(self.max_value, message=message)) + if self.min_value is not None: + message = self.error_messages['min_value'].format(min_value=self.min_value) + self.validators.append(MinValueValidator(self.min_value, message=message)) def to_internal_value(self, data): if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: @@ -694,17 +725,19 @@ class FloatField(Field): 'max_string_length': _('String value too large') } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. + max_value = None + min_value = None def __init__(self, **kwargs): - max_value = kwargs.pop('max_value', None) - min_value = kwargs.pop('min_value', None) + self.max_value = kwargs.pop('max_value', self.max_value) + self.min_value = kwargs.pop('min_value', self.min_value) super(FloatField, self).__init__(**kwargs) - if max_value is not None: - message = self.error_messages['max_value'].format(max_value=max_value) - self.validators.append(MaxValueValidator(max_value, message=message)) - if min_value is not None: - message = self.error_messages['min_value'].format(min_value=min_value) - self.validators.append(MinValueValidator(min_value, message=message)) + if self.max_value is not None: + message = self.error_messages['max_value'].format(max_value=self.max_value) + self.validators.append(MaxValueValidator(self.max_value, message=message)) + if self.min_value is not None: + message = self.error_messages['min_value'].format(min_value=self.min_value) + self.validators.append(MinValueValidator(self.min_value, message=message)) def to_internal_value(self, data): if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: @@ -730,20 +763,26 @@ class DecimalField(Field): 'max_string_length': _('String value too large') } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. + max_value = None + min_value = None + # todo: max_digits = None + # todo: decimal_places = None coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING - def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs): + def __init__(self, max_digits, decimal_places, coerce_to_string=None, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string + self.max_value = kwargs.pop('max_value', self.max_value) + self.min_value = kwargs.pop('min_value', self.min_value) super(DecimalField, self).__init__(**kwargs) - if max_value is not None: - message = self.error_messages['max_value'].format(max_value=max_value) - self.validators.append(MaxValueValidator(max_value, message=message)) - if min_value is not None: - message = self.error_messages['min_value'].format(min_value=min_value) - self.validators.append(MinValueValidator(min_value, message=message)) + if self.max_value is not None: + message = self.error_messages['max_value'].format(max_value=self.max_value) + self.validators.append(MaxValueValidator(self.max_value, message=message)) + if self.min_value is not None: + message = self.error_messages['min_value'].format(min_value=self.min_value) + self.validators.append(MinValueValidator(self.min_value, message=message)) def to_internal_value(self, data): """ @@ -817,10 +856,10 @@ class DateTimeField(Field): input_formats = api_settings.DATETIME_INPUT_FORMATS default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None - def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): - self.format = format if format is not empty else self.format - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone + def __init__(self, *args, **kwargs): + self.format = kwargs.pop('format', self.format) + self.default_timezone = kwargs.pop('default_timezone', self.default_timezone) + self.input_formats = kwargs.pop('input_formats', self.input_formats) super(DateTimeField, self).__init__(*args, **kwargs) def enforce_timezone(self, value): @@ -881,9 +920,9 @@ class DateField(Field): format = api_settings.DATE_FORMAT input_formats = api_settings.DATE_INPUT_FORMATS - def __init__(self, format=empty, input_formats=None, *args, **kwargs): - self.format = format if format is not empty else self.format - self.input_formats = input_formats if input_formats is not None else self.input_formats + def __init__(self, *args, **kwargs): + self.format = kwargs.pop('format', self.format) + self.input_formats = kwargs.pop('input_formats', self.input_formats) super(DateField, self).__init__(*args, **kwargs) def to_internal_value(self, value): @@ -938,9 +977,9 @@ class TimeField(Field): format = api_settings.TIME_FORMAT input_formats = api_settings.TIME_INPUT_FORMATS - def __init__(self, format=empty, input_formats=None, *args, **kwargs): - self.format = format if format is not empty else self.format - self.input_formats = input_formats if input_formats is not None else self.input_formats + def __init__(self, *args, **kwargs): + self.format = kwargs.pop('format', self.format) + self.input_formats = kwargs.pop('input_formats', self.input_formats) super(TimeField, self).__init__(*args, **kwargs) def to_internal_value(self, value): @@ -991,8 +1030,17 @@ class ChoiceField(Field): default_error_messages = { 'invalid_choice': _('`{input}` is not a valid choice.') } + allow_blank = False + choices = None + + def __init__(self, *args, **kwargs): + if args: + choices = args[0] + else: + choices = kwargs.pop('choices', self.choices) + # not available on class or as kwarg + assert choices is not None, 'need to specify `choices`.' - def __init__(self, choices, **kwargs): # Allow either single or paired choices style: # choices = [1, 2, 3] # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] @@ -1012,7 +1060,7 @@ class ChoiceField(Field): (six.text_type(key), key) for key in self.choices.keys() ]) - self.allow_blank = kwargs.pop('allow_blank', False) + self.allow_blank = kwargs.pop('allow_blank', self.allow_blank) super(ChoiceField, self).__init__(**kwargs) @@ -1071,10 +1119,12 @@ class FileField(Field): 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'), } use_url = api_settings.UPLOADED_FILES_USE_URL + max_length = None + allow_empty_file = False def __init__(self, *args, **kwargs): - self.max_length = kwargs.pop('max_length', None) - self.allow_empty_file = kwargs.pop('allow_empty_file', False) + self.max_length = kwargs.pop('max_length', self.max_length) + self.allow_empty_file = kwargs.pop('allow_empty_file', self.allow_empty_file) self.use_url = kwargs.pop('use_url', self.use_url) super(FileField, self).__init__(*args, **kwargs) @@ -1324,16 +1374,17 @@ class ModelField(Field): default_error_messages = { 'max_length': _('Ensure this field has no more than {max_length} characters.'), } + max_length = None def __init__(self, model_field, **kwargs): self.model_field = model_field # The `max_length` option is supported by Django's base `Field` class, # so we'd better support it here. - max_length = kwargs.pop('max_length', None) + self.max_length = kwargs.pop('max_length', self.max_length) super(ModelField, self).__init__(**kwargs) - if max_length is not None: - message = self.error_messages['max_length'].format(max_length=max_length) - self.validators.append(MaxLengthValidator(max_length, message=message)) + if self.max_length is not None: + message = self.error_messages['max_length'].format(max_length=self.max_length) + self.validators.append(MaxLengthValidator(self.max_length, message=message)) def to_internal_value(self, data): rel = getattr(self.model_field, 'rel', None)