diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 84282cdb2..2e34dbe75 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -720,49 +720,60 @@ class ModelSerializer(Serializer): # Determine if we need any additional `HiddenField` or extra keyword # arguments to deal with `unique_for` dates that are required to # be in the input data in order to validate it. - unique_fields = {} + hidden_fields = {} + for model_field_name, field_name in model_field_mapping.items(): try: model_field = model._meta.get_field(model_field_name) except FieldDoesNotExist: continue - # Deal with each of the `unique_for_*` cases. - for date_field_name in ( + # Include each of the `unique_for_*` field names. + unique_constraint_names = set([ model_field.unique_for_date, model_field.unique_for_month, model_field.unique_for_year - ): - if date_field_name is None: - continue + ]) + unique_constraint_names -= set([None]) - # Get the model field that is refered too. - date_field = model._meta.get_field(date_field_name) + # Include each of the `unique_together` field names, + # so long as all the field names are included on the serializer. + for parent_class in [model] + list(model._meta.parents.keys()): + for unique_together_list in parent_class._meta.unique_together: + if set(fields).issuperset(set(unique_together_list)): + unique_constraint_names |= set(unique_together_list) - if date_field.auto_now_add: - default = CreateOnlyDefault(timezone.now) - elif date_field.auto_now: - default = timezone.now - elif date_field.has_default(): - default = model_field.default + # Now we have all the field names that have uniqueness constraints + # applied, we can add the extra 'required=...' or 'default=...' + # arguments that are appropriate to these fields, or add a `HiddenField` for it. + for unique_constraint_name in unique_constraint_names: + # Get the model field that is refered too. + unique_constraint_field = model._meta.get_field(unique_constraint_name) + + if getattr(unique_constraint_field, 'auto_now_add', None): + default = CreateOnlyDefault(timezone.now) + elif getattr(unique_constraint_field, 'auto_now', None): + default = timezone.now + elif unique_constraint_field.has_default(): + default = model_field.default + else: + default = empty + + if unique_constraint_name in model_field_mapping: + # The corresponding field is present in the serializer + if unique_constraint_name not in extra_kwargs: + extra_kwargs[unique_constraint_name] = {} + if default is empty: + if 'required' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['required'] = True else: - default = empty - - if date_field_name in model_field_mapping: - # The corresponding date field is present in the serializer - if date_field_name not in extra_kwargs: - extra_kwargs[date_field_name] = {} - if default is empty: - if 'required' not in extra_kwargs[date_field_name]: - extra_kwargs[date_field_name]['required'] = True - else: - if 'default' not in extra_kwargs[date_field_name]: - extra_kwargs[date_field_name]['default'] = default - else: - # The corresponding date field is not present in the, - # serializer. We have a default to use for the date, so - # add in a hidden field that populates it. - unique_fields[date_field_name] = HiddenField(default=default) + if 'default' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['default'] = default + elif default is not empty: + # The corresponding field is not present in the, + # serializer. We have a default to use for it, so + # add in a hidden field that populates it. + hidden_fields[unique_constraint_name] = HiddenField(default=default) # Now determine the fields that should be included on the serializer. for field_name in fields: @@ -838,12 +849,16 @@ class ModelSerializer(Serializer): 'validators', 'queryset' ]: kwargs.pop(attr, None) + + if extras.get('default') and kwargs.get('required') is False: + kwargs.pop('required') + kwargs.update(extras) # Create the serializer field. ret[field_name] = field_cls(**kwargs) - for field_name, field in unique_fields.items(): + for field_name, field in hidden_fields.items(): ret[field_name] = field return ret diff --git a/rest_framework/validators.py b/rest_framework/validators.py index fa4f18474..7ca4e6a9b 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -93,6 +93,9 @@ class UniqueTogetherValidator: The `UniqueTogetherValidator` always forces an implied 'required' state on the fields it applies to. """ + if self.instance is not None: + return + missing = dict([ (field_name, self.missing_message) for field_name in self.fields @@ -105,8 +108,17 @@ class UniqueTogetherValidator: """ Filter the queryset to all instances matching the given attributes. """ + # If this is an update, then any unprovided field should + # have it's value set based on the existing instance attribute. + if self.instance is not None: + for field_name in self.fields: + if field_name not in attrs: + attrs[field_name] = getattr(self.instance, field_name) + + # Determine the filter keyword arguments and filter the queryset. filter_kwargs = dict([ - (field_name, attrs[field_name]) for field_name in self.fields + (field_name, attrs[field_name]) + for field_name in self.fields ]) return queryset.filter(**filter_kwargs) diff --git a/tests/test_validators.py b/tests/test_validators.py index 86614b109..1df0641c5 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -88,8 +88,8 @@ class TestUniquenessTogetherValidation(TestCase): expected = dedent(""" UniquenessTogetherSerializer(): id = IntegerField(label='ID', read_only=True) - race_name = CharField(max_length=100) - position = IntegerField() + race_name = CharField(max_length=100, required=True) + position = IntegerField(required=True) class Meta: validators = [] """)