diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index eae6a0b2e..e27f8a47c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1398,6 +1398,23 @@ class ModelSerializer(Serializer): return extra_kwargs + def get_unique_together_constraints(self, model): + """ + Returns iterator of (fields, queryset), each entry describes an unique together + constraint on `fields` in `queryset`. + """ + for parent_class in [model] + list(model._meta.parents): + for unique_together in parent_class._meta.unique_together: + yield unique_together, model._default_manager + for constraint in parent_class._meta.constraints: + if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: + yield ( + constraint.fields, + model._default_manager + if constraint.condition is None + else model._default_manager.filter(constraint.condition) + ) + def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): """ Return any additional field options that need to be included as a @@ -1426,12 +1443,11 @@ class ModelSerializer(Serializer): unique_constraint_names -= {None} - # Include each of the `unique_together` field names, + # Include each of the `unique_together` and `UniqueConstraint` field names, # so long as all the field names are included on the serializer. - for parent_class in [model] + list(model._meta.parents): - for unique_together_list in parent_class._meta.unique_together: - if set(field_names).issuperset(unique_together_list): - unique_constraint_names |= set(unique_together_list) + for unique_together_list, queryset in self.get_unique_together_constraints(model): + if set(field_names).issuperset(unique_together_list): + unique_constraint_names |= set(unique_together_list) # Now we have all the field names that have uniqueness constraints # applied, we can add the extra 'required=...' or 'default=...' @@ -1526,11 +1542,6 @@ class ModelSerializer(Serializer): """ Determine a default set of validators for any unique_together constraints. """ - model_class_inheritance_tree = ( - [self.Meta.model] + - list(self.Meta.model._meta.parents) - ) - # The field names we're passing though here only include fields # which may map onto a model field. Any dotted field name lookups # cannot map to a field, and must be a traversal, so we're not @@ -1556,34 +1567,33 @@ class ModelSerializer(Serializer): # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] - for parent_class in model_class_inheritance_tree: - for unique_together in parent_class._meta.unique_together: - # Skip if serializer does not map to all unique together sources - if not set(source_map).issuperset(unique_together): - continue + for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model): + # Skip if serializer does not map to all unique together sources + if not set(source_map).issuperset(unique_together): + continue - for source in unique_together: - assert len(source_map[source]) == 1, ( - "Unable to create `UniqueTogetherValidator` for " - "`{model}.{field}` as `{serializer}` has multiple " - "fields ({fields}) that map to this model field. " - "Either remove the extra fields, or override " - "`Meta.validators` with a `UniqueTogetherValidator` " - "using the desired field names." - .format( - model=self.Meta.model.__name__, - serializer=self.__class__.__name__, - field=source, - fields=', '.join(source_map[source]), - ) + for source in unique_together: + assert len(source_map[source]) == 1, ( + "Unable to create `UniqueTogetherValidator` for " + "`{model}.{field}` as `{serializer}` has multiple " + "fields ({fields}) that map to this model field. " + "Either remove the extra fields, or override " + "`Meta.validators` with a `UniqueTogetherValidator` " + "using the desired field names." + .format( + model=self.Meta.model.__name__, + serializer=self.__class__.__name__, + field=source, + fields=', '.join(source_map[source]), ) - - field_names = tuple(source_map[f][0] for f in unique_together) - validator = UniqueTogetherValidator( - queryset=parent_class._default_manager, - fields=field_names ) - validators.append(validator) + + field_names = tuple(source_map[f][0] for f in unique_together) + validator = UniqueTogetherValidator( + queryset=queryset, + fields=field_names + ) + validators.append(validator) return validators def get_unique_for_date_validators(self): diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 7e8e8f046..fc63f96fe 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -62,6 +62,29 @@ def get_detail_view_name(model): } +def get_unique_validators(field_name, model_field): + """ + Returns a list of UniqueValidators that should be applied to the field. + """ + field_set = set([field_name]) + conditions = { + c.condition + for c in model_field.model._meta.constraints + if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set + } + if getattr(model_field, 'unique', False): + conditions.add(None) + if not conditions: + return + unique_error_message = get_unique_error_message(model_field) + queryset = model_field.model._default_manager + for condition in conditions: + yield UniqueValidator( + queryset=queryset if condition is None else queryset.filter(condition), + message=unique_error_message + ) + + def get_field_kwargs(field_name, model_field): """ Creates a default instance of a basic non-relational field. @@ -216,11 +239,7 @@ def get_field_kwargs(field_name, model_field): if not isinstance(validator, validators.MinLengthValidator) ] - if getattr(model_field, 'unique', False): - validator = UniqueValidator( - queryset=model_field.model._default_manager, - message=get_unique_error_message(model_field)) - validator_kwarg.append(validator) + validator_kwarg += get_unique_validators(field_name, model_field) if validator_kwarg: kwargs['validators'] = validator_kwarg diff --git a/tests/test_validators.py b/tests/test_validators.py index 39490ac86..35fef6f26 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -464,6 +464,106 @@ class TestUniquenessTogetherValidation(TestCase): assert queryset.called_with == {'race_name': 'bar', 'position': 1} +class UniqueConstraintModel(models.Model): + race_name = models.CharField(max_length=100) + position = models.IntegerField() + global_id = models.IntegerField() + fancy_conditions = models.IntegerField(null=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + name="unique_constraint_model_global_id_uniq", + fields=('global_id',), + ), + models.UniqueConstraint( + name="unique_constraint_model_fancy_1_uniq", + fields=('fancy_conditions',), + condition=models.Q(global_id__lte=1) + ), + models.UniqueConstraint( + name="unique_constraint_model_fancy_3_uniq", + fields=('fancy_conditions',), + condition=models.Q(global_id__gte=3) + ), + models.UniqueConstraint( + name="unique_constraint_model_together_uniq", + fields=('race_name', 'position'), + condition=models.Q(race_name='example'), + ) + ] + + +class UniqueConstraintSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueConstraintModel + fields = '__all__' + + +class TestUniqueConstraintValidation(TestCase): + def setUp(self): + self.instance = UniqueConstraintModel.objects.create( + race_name='example', + position=1, + global_id=1 + ) + UniqueConstraintModel.objects.create( + race_name='example', + position=2, + global_id=2 + ) + UniqueConstraintModel.objects.create( + race_name='other', + position=1, + global_id=3 + ) + + def test_repr(self): + serializer = UniqueConstraintSerializer() + # the order of validators isn't deterministic so delete + # fancy_conditions field that has two of them + del serializer.fields['fancy_conditions'] + expected = dedent(""" + UniqueConstraintSerializer(): + id = IntegerField(label='ID', read_only=True) + race_name = CharField(max_length=100, required=True) + position = IntegerField(required=True) + global_id = IntegerField(validators=[]) + class Meta: + validators = [, ]>, fields=('race_name', 'position'))>] + """) + assert repr(serializer) == expected + + def test_unique_together_field(self): + """ + UniqueConstraint fields and condition attributes must be passed + to UniqueTogetherValidator as fields and queryset + """ + serializer = UniqueConstraintSerializer() + assert len(serializer.validators) == 1 + validator = serializer.validators[0] + assert validator.fields == ('race_name', 'position') + assert set(validator.queryset.values_list(flat=True)) == set( + UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True) + ) + + def test_single_field_uniq_validators(self): + """ + UniqueConstraint with single field must be transformed into + field's UniqueValidator + """ + serializer = UniqueConstraintSerializer() + assert len(serializer.validators) == 1 + validators = serializer.fields['global_id'].validators + assert len(validators) == 1 + assert validators[0].queryset == UniqueConstraintModel.objects + + validators = serializer.fields['fancy_conditions'].validators + assert len(validators) == 2 + ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators} + assert ids_in_qs == {frozenset([1]), frozenset([3])} + + # Tests for `UniqueForDateValidator` # ----------------------------------