diff --git a/rest_framework/fields.py b/rest_framework/fields.py index e939b2f29..82b7eb374 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -92,13 +92,35 @@ def set_value(dictionary, keys, value): dictionary[keys[-1]] = value +class CreateOnlyDefault: + """ + This class may be used to provide default values that are only used + for create operations, but that do not return any value for update + operations. + """ + def __init__(self, default): + self.default = default + + def set_context(self, serializer_field): + self.is_update = serializer_field.parent.instance is not None + + def __call__(self): + if self.is_update: + raise SkipField() + if callable(self.default): + return self.default() + return self.default + + def __repr__(self): + return '%s(%s)' % (self.__class__.__name__, repr(self.default)) + + class SkipField(Exception): pass NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' -NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' MISSING_ERROR_MESSAGE = ( @@ -132,7 +154,6 @@ class Field(object): # Some combinations of keyword arguments do not make sense. assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY assert not (read_only and required), NOT_READ_ONLY_REQUIRED - assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT assert not (required and default is not empty), NOT_REQUIRED_DEFAULT assert not (read_only and self.__class__ == Field), USE_READONLYFIELD @@ -230,7 +251,9 @@ class Field(object): """ if self.default is empty: raise SkipField() - if is_simple_callable(self.default): + if callable(self.default): + if hasattr(self.default, 'set_context'): + self.default.set_context(self) return self.default() return self.default @@ -244,6 +267,9 @@ class Field(object): May raise `SkipField` if the field should not be included in the validated data. """ + if self.read_only: + return self.get_default() + if data is empty: if getattr(self.root, 'partial', False): raise SkipField() @@ -1033,6 +1059,28 @@ class ReadOnlyField(Field): return value +class HiddenField(Field): + """ + A hidden field does not take input from the user, or present any output, + but it does populate a field in `validated_data`, based on its default + value. This is particularly useful when we have a `unique_for_date` + constrain on a pair of fields, as we need some way to include the date in + the validated data. + """ + def __init__(self, **kwargs): + assert 'default' in kwargs, 'default is a required argument.' + kwargs['write_only'] = True + super(HiddenField, self).__init__(**kwargs) + + def get_value(self, dictionary): + # We always use the default value for `HiddenField`. + # User input is never provided or accepted. + return empty + + def to_internal_value(self, data): + return data + + class SerializerMethodField(Field): """ A read-only field that get its representation from calling a method on the diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b45f343a4..6aab020ef 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,6 +12,7 @@ response content is handled by parsers and renderers. """ from django.core.exceptions import ImproperlyConfigured from django.db import models +from django.db.models.fields import FieldDoesNotExist from django.utils import six from django.utils.datastructures import SortedDict from rest_framework.exceptions import ValidationError @@ -368,7 +369,10 @@ class Serializer(BaseSerializer): """ ret = {} errors = ReturnDict(serializer=self) - fields = [field for field in self.fields.values() if not field.read_only] + fields = [ + field for field in self.fields.values() + if (not field.read_only) or (field.default is not empty) + ] for field in fields: validate_method = getattr(self, 'validate_' + field.field_name, None) @@ -517,7 +521,7 @@ class ModelSerializer(Serializer): def __init__(self, *args, **kwargs): super(ModelSerializer, self).__init__(*args, **kwargs) if 'validators' not in kwargs: - validators = self.get_unique_together_validators() + validators = self.get_default_validators() if validators: self.validators.extend(validators) self._kwargs['validators'] = validators @@ -572,7 +576,7 @@ class ModelSerializer(Serializer): instance.save() return instance - def get_unique_together_validators(self): + def get_default_validators(self): field_names = set([ field.source for field in self.fields.values() if (field.source != '*') and ('.' not in field.source) @@ -592,6 +596,7 @@ class ModelSerializer(Serializer): ) validators.append(validator) + # Add any unique_for_date/unique_for_month/unique_for_year constraints. info = model_meta.get_field_info(model_class) for field_name, field in info.fields_and_pk.items(): if field.unique_for_date and field_name in field_names: @@ -637,7 +642,7 @@ class ModelSerializer(Serializer): # Retrieve metadata about fields & relationships on the model class. info = model_meta.get_field_info(model) - # Use the default set of fields if none is supplied explicitly. + # Use the default set of field names if none is supplied explicitly. if fields is None: fields = self._get_default_field_names(declared_fields, info) exclude = getattr(self.Meta, 'exclude', None) @@ -645,6 +650,72 @@ class ModelSerializer(Serializer): for field_name in exclude: fields.remove(field_name) + # Determine the set of model fields, and the fields that they map to. + # We actually only need this to deal with the slightly awkward case + # of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`. + model_field_mapping = {} + for field_name in fields: + if field_name in declared_fields: + field = declared_fields[field_name] + source = field.source or field_name + else: + try: + source = extra_kwargs[field_name]['source'] + except KeyError: + source = field_name + # Model fields will always have a simple source mapping, + # they can't be nested attribute lookups. + if '.' not in source and source != '*': + model_field_mapping[source] = field_name + + # Determine if we need any additional `HiddenField` or extra keyword + # arguments to deal with `unique_for` dates that are required to + # be in the input data in order to validate it. + unique_fields = {} + for model_field_name, field_name in model_field_mapping.items(): + try: + model_field = model._meta.get_field(model_field_name) + except FieldDoesNotExist: + continue + + # Deal with each of the `unique_for_*` cases. + for date_field_name in ( + model_field.unique_for_date, + model_field.unique_for_month, + model_field.unique_for_year + ): + if date_field_name is None: + continue + + # Get the model field that is refered too. + date_field = model._meta.get_field(date_field_name) + + if date_field.auto_now_add: + default = CreateOnlyDefault(timezone.now) + elif date_field.auto_now: + default = timezone.now + elif date_field.has_default(): + default = model_field.default + else: + default = empty + + if date_field_name in model_field_mapping: + # The corresponding date field is present in the serializer + if date_field_name not in extra_kwargs: + extra_kwargs[date_field_name] = {} + if default is empty: + if 'required' not in extra_kwargs[date_field_name]: + extra_kwargs[date_field_name]['required'] = True + else: + if 'default' not in extra_kwargs[date_field_name]: + extra_kwargs[date_field_name]['default'] = default + else: + # The corresponding date field is not present in the, + # serializer. We have a default to use for the date, so + # add in a hidden field that populates it. + unique_fields[date_field_name] = HiddenField(default=default) + + # Now determine the fields that should be included on the serializer. for field_name in fields: if field_name in declared_fields: # Field is explicitly declared on the class, use that. @@ -723,6 +794,9 @@ class ModelSerializer(Serializer): # Create the serializer field. ret[field_name] = field_cls(**kwargs) + for field_name, field in unique_fields.items(): + ret[field_name] = field + return ret def _include_additional_options(self, extra_kwargs): diff --git a/tests/test_fields.py b/tests/test_fields.py index 6dc5f87d2..3e102ab5a 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -215,6 +215,33 @@ class TestBooleanHTMLInput: assert serializer.validated_data == {'archived': False} +class TestCreateOnlyDefault: + def setup(self): + default = serializers.CreateOnlyDefault('2001-01-01') + + class TestSerializer(serializers.Serializer): + published = serializers.HiddenField(default=default) + text = serializers.CharField() + self.Serializer = TestSerializer + + def test_create_only_default_is_provided(self): + serializer = self.Serializer(data={'text': 'example'}) + assert serializer.is_valid() + assert serializer.validated_data == { + 'text': 'example', 'published': '2001-01-01' + } + + def test_create_only_default_is_not_provided_on_update(self): + instance = { + 'text': 'example', 'published': '2001-01-01' + } + serializer = self.Serializer(instance, data={'text': 'example'}) + assert serializer.is_valid() + assert serializer.validated_data == { + 'text': 'example', + } + + # Tests for field input and output values. # ---------------------------------------- diff --git a/tests/test_validators.py b/tests/test_validators.py index 5adb76783..6cc52c837 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -176,7 +176,7 @@ class TestUniquenessForDateValidation(TestCase): UniqueForDateSerializer(validators=[]): id = IntegerField(label='ID', read_only=True) slug = CharField(max_length=100) - published = DateField() + published = DateField(required=True) """) assert repr(serializer) == expected @@ -215,3 +215,40 @@ class TestUniquenessForDateValidation(TestCase): 'slug': 'existing', 'published': datetime.date(2000, 1, 1) } + + +class HiddenFieldUniqueForDateModel(models.Model): + slug = models.CharField(max_length=100, unique_for_date='published') + published = models.DateTimeField(auto_now_add=True) + + +class TestHiddenFieldUniquenessForDateValidation(TestCase): + def test_repr_date_field_not_included(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = HiddenFieldUniqueForDateModel + fields = ('id', 'slug') + + serializer = TestSerializer() + expected = dedent(""" + TestSerializer(validators=[]): + id = IntegerField(label='ID', read_only=True) + slug = CharField(max_length=100) + published = HiddenField(default=CreateOnlyDefault()) + """) + assert repr(serializer) == expected + + def test_repr_date_field_included(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = HiddenFieldUniqueForDateModel + fields = ('id', 'slug', 'published') + + serializer = TestSerializer() + expected = dedent(""" + TestSerializer(validators=[]): + id = IntegerField(label='ID', read_only=True) + slug = CharField(max_length=100) + published = DateTimeField(default=CreateOnlyDefault(), read_only=True) + """) + assert repr(serializer) == expected