Support UniqueConstraint (#7438)

This commit is contained in:
Konstantin Alekseev 2023-03-03 09:04:47 +02:00 committed by GitHub
parent 9882207c16
commit b7523f4b9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 169 additions and 40 deletions

View File

@ -1398,6 +1398,23 @@ class ModelSerializer(Serializer):
return extra_kwargs return extra_kwargs
def get_unique_together_constraints(self, model):
"""
Returns iterator of (fields, queryset), each entry describes an unique together
constraint on `fields` in `queryset`.
"""
for parent_class in [model] + list(model._meta.parents):
for unique_together in parent_class._meta.unique_together:
yield unique_together, model._default_manager
for constraint in parent_class._meta.constraints:
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
yield (
constraint.fields,
model._default_manager
if constraint.condition is None
else model._default_manager.filter(constraint.condition)
)
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
""" """
Return any additional field options that need to be included as a Return any additional field options that need to be included as a
@ -1426,10 +1443,9 @@ class ModelSerializer(Serializer):
unique_constraint_names -= {None} unique_constraint_names -= {None}
# Include each of the `unique_together` field names, # Include each of the `unique_together` and `UniqueConstraint` field names,
# so long as all the field names are included on the serializer. # so long as all the field names are included on the serializer.
for parent_class in [model] + list(model._meta.parents): for unique_together_list, queryset in self.get_unique_together_constraints(model):
for unique_together_list in parent_class._meta.unique_together:
if set(field_names).issuperset(unique_together_list): if set(field_names).issuperset(unique_together_list):
unique_constraint_names |= set(unique_together_list) unique_constraint_names |= set(unique_together_list)
@ -1526,11 +1542,6 @@ class ModelSerializer(Serializer):
""" """
Determine a default set of validators for any unique_together constraints. Determine a default set of validators for any unique_together constraints.
""" """
model_class_inheritance_tree = (
[self.Meta.model] +
list(self.Meta.model._meta.parents)
)
# The field names we're passing though here only include fields # The field names we're passing though here only include fields
# which may map onto a model field. Any dotted field name lookups # which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not # cannot map to a field, and must be a traversal, so we're not
@ -1556,8 +1567,7 @@ class ModelSerializer(Serializer):
# Note that we make sure to check `unique_together` both on the # Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes. # base model class, but also on any parent classes.
validators = [] validators = []
for parent_class in model_class_inheritance_tree: for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
for unique_together in parent_class._meta.unique_together:
# Skip if serializer does not map to all unique together sources # Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(unique_together): if not set(source_map).issuperset(unique_together):
continue continue
@ -1580,7 +1590,7 @@ class ModelSerializer(Serializer):
field_names = tuple(source_map[f][0] for f in unique_together) field_names = tuple(source_map[f][0] for f in unique_together)
validator = UniqueTogetherValidator( validator = UniqueTogetherValidator(
queryset=parent_class._default_manager, queryset=queryset,
fields=field_names fields=field_names
) )
validators.append(validator) validators.append(validator)

View File

@ -62,6 +62,29 @@ def get_detail_view_name(model):
} }
def get_unique_validators(field_name, model_field):
"""
Returns a list of UniqueValidators that should be applied to the field.
"""
field_set = set([field_name])
conditions = {
c.condition
for c in model_field.model._meta.constraints
if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set
}
if getattr(model_field, 'unique', False):
conditions.add(None)
if not conditions:
return
unique_error_message = get_unique_error_message(model_field)
queryset = model_field.model._default_manager
for condition in conditions:
yield UniqueValidator(
queryset=queryset if condition is None else queryset.filter(condition),
message=unique_error_message
)
def get_field_kwargs(field_name, model_field): def get_field_kwargs(field_name, model_field):
""" """
Creates a default instance of a basic non-relational field. Creates a default instance of a basic non-relational field.
@ -216,11 +239,7 @@ def get_field_kwargs(field_name, model_field):
if not isinstance(validator, validators.MinLengthValidator) if not isinstance(validator, validators.MinLengthValidator)
] ]
if getattr(model_field, 'unique', False): validator_kwarg += get_unique_validators(field_name, model_field)
validator = UniqueValidator(
queryset=model_field.model._default_manager,
message=get_unique_error_message(model_field))
validator_kwarg.append(validator)
if validator_kwarg: if validator_kwarg:
kwargs['validators'] = validator_kwarg kwargs['validators'] = validator_kwarg

View File

@ -464,6 +464,106 @@ class TestUniquenessTogetherValidation(TestCase):
assert queryset.called_with == {'race_name': 'bar', 'position': 1} assert queryset.called_with == {'race_name': 'bar', 'position': 1}
class UniqueConstraintModel(models.Model):
race_name = models.CharField(max_length=100)
position = models.IntegerField()
global_id = models.IntegerField()
fancy_conditions = models.IntegerField(null=True)
class Meta:
constraints = [
models.UniqueConstraint(
name="unique_constraint_model_global_id_uniq",
fields=('global_id',),
),
models.UniqueConstraint(
name="unique_constraint_model_fancy_1_uniq",
fields=('fancy_conditions',),
condition=models.Q(global_id__lte=1)
),
models.UniqueConstraint(
name="unique_constraint_model_fancy_3_uniq",
fields=('fancy_conditions',),
condition=models.Q(global_id__gte=3)
),
models.UniqueConstraint(
name="unique_constraint_model_together_uniq",
fields=('race_name', 'position'),
condition=models.Q(race_name='example'),
)
]
class UniqueConstraintSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintModel
fields = '__all__'
class TestUniqueConstraintValidation(TestCase):
def setUp(self):
self.instance = UniqueConstraintModel.objects.create(
race_name='example',
position=1,
global_id=1
)
UniqueConstraintModel.objects.create(
race_name='example',
position=2,
global_id=2
)
UniqueConstraintModel.objects.create(
race_name='other',
position=1,
global_id=3
)
def test_repr(self):
serializer = UniqueConstraintSerializer()
# the order of validators isn't deterministic so delete
# fancy_conditions field that has two of them
del serializer.fields['fancy_conditions']
expected = dedent("""
UniqueConstraintSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100, required=True)
position = IntegerField(required=True)
global_id = IntegerField(validators=[<UniqueValidator(queryset=UniqueConstraintModel.objects.all())>])
class Meta:
validators = [<UniqueTogetherValidator(queryset=<QuerySet [<UniqueConstraintModel: UniqueConstraintModel object (1)>, <UniqueConstraintModel: UniqueConstraintModel object (2)>]>, fields=('race_name', 'position'))>]
""")
assert repr(serializer) == expected
def test_unique_together_field(self):
"""
UniqueConstraint fields and condition attributes must be passed
to UniqueTogetherValidator as fields and queryset
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
validator = serializer.validators[0]
assert validator.fields == ('race_name', 'position')
assert set(validator.queryset.values_list(flat=True)) == set(
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
)
def test_single_field_uniq_validators(self):
"""
UniqueConstraint with single field must be transformed into
field's UniqueValidator
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
validators = serializer.fields['global_id'].validators
assert len(validators) == 1
assert validators[0].queryset == UniqueConstraintModel.objects
validators = serializer.fields['fancy_conditions'].validators
assert len(validators) == 2
ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators}
assert ids_in_qs == {frozenset([1]), frozenset([3])}
# Tests for `UniqueForDateValidator` # Tests for `UniqueForDateValidator`
# ---------------------------------- # ----------------------------------