mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-22 01:26:53 +03:00
Support UniqueConstraint (#7438)
This commit is contained in:
parent
9882207c16
commit
b7523f4b9f
|
@ -1398,6 +1398,23 @@ class ModelSerializer(Serializer):
|
||||||
|
|
||||||
return extra_kwargs
|
return extra_kwargs
|
||||||
|
|
||||||
|
def get_unique_together_constraints(self, model):
|
||||||
|
"""
|
||||||
|
Returns iterator of (fields, queryset), each entry describes an unique together
|
||||||
|
constraint on `fields` in `queryset`.
|
||||||
|
"""
|
||||||
|
for parent_class in [model] + list(model._meta.parents):
|
||||||
|
for unique_together in parent_class._meta.unique_together:
|
||||||
|
yield unique_together, model._default_manager
|
||||||
|
for constraint in parent_class._meta.constraints:
|
||||||
|
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
|
||||||
|
yield (
|
||||||
|
constraint.fields,
|
||||||
|
model._default_manager
|
||||||
|
if constraint.condition is None
|
||||||
|
else model._default_manager.filter(constraint.condition)
|
||||||
|
)
|
||||||
|
|
||||||
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
|
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
|
||||||
"""
|
"""
|
||||||
Return any additional field options that need to be included as a
|
Return any additional field options that need to be included as a
|
||||||
|
@ -1426,12 +1443,11 @@ class ModelSerializer(Serializer):
|
||||||
|
|
||||||
unique_constraint_names -= {None}
|
unique_constraint_names -= {None}
|
||||||
|
|
||||||
# Include each of the `unique_together` field names,
|
# Include each of the `unique_together` and `UniqueConstraint` field names,
|
||||||
# so long as all the field names are included on the serializer.
|
# so long as all the field names are included on the serializer.
|
||||||
for parent_class in [model] + list(model._meta.parents):
|
for unique_together_list, queryset in self.get_unique_together_constraints(model):
|
||||||
for unique_together_list in parent_class._meta.unique_together:
|
if set(field_names).issuperset(unique_together_list):
|
||||||
if set(field_names).issuperset(unique_together_list):
|
unique_constraint_names |= set(unique_together_list)
|
||||||
unique_constraint_names |= set(unique_together_list)
|
|
||||||
|
|
||||||
# Now we have all the field names that have uniqueness constraints
|
# Now we have all the field names that have uniqueness constraints
|
||||||
# applied, we can add the extra 'required=...' or 'default=...'
|
# applied, we can add the extra 'required=...' or 'default=...'
|
||||||
|
@ -1526,11 +1542,6 @@ class ModelSerializer(Serializer):
|
||||||
"""
|
"""
|
||||||
Determine a default set of validators for any unique_together constraints.
|
Determine a default set of validators for any unique_together constraints.
|
||||||
"""
|
"""
|
||||||
model_class_inheritance_tree = (
|
|
||||||
[self.Meta.model] +
|
|
||||||
list(self.Meta.model._meta.parents)
|
|
||||||
)
|
|
||||||
|
|
||||||
# The field names we're passing though here only include fields
|
# The field names we're passing though here only include fields
|
||||||
# which may map onto a model field. Any dotted field name lookups
|
# which may map onto a model field. Any dotted field name lookups
|
||||||
# cannot map to a field, and must be a traversal, so we're not
|
# cannot map to a field, and must be a traversal, so we're not
|
||||||
|
@ -1556,34 +1567,33 @@ class ModelSerializer(Serializer):
|
||||||
# Note that we make sure to check `unique_together` both on the
|
# Note that we make sure to check `unique_together` both on the
|
||||||
# base model class, but also on any parent classes.
|
# base model class, but also on any parent classes.
|
||||||
validators = []
|
validators = []
|
||||||
for parent_class in model_class_inheritance_tree:
|
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
|
||||||
for unique_together in parent_class._meta.unique_together:
|
# Skip if serializer does not map to all unique together sources
|
||||||
# Skip if serializer does not map to all unique together sources
|
if not set(source_map).issuperset(unique_together):
|
||||||
if not set(source_map).issuperset(unique_together):
|
continue
|
||||||
continue
|
|
||||||
|
|
||||||
for source in unique_together:
|
for source in unique_together:
|
||||||
assert len(source_map[source]) == 1, (
|
assert len(source_map[source]) == 1, (
|
||||||
"Unable to create `UniqueTogetherValidator` for "
|
"Unable to create `UniqueTogetherValidator` for "
|
||||||
"`{model}.{field}` as `{serializer}` has multiple "
|
"`{model}.{field}` as `{serializer}` has multiple "
|
||||||
"fields ({fields}) that map to this model field. "
|
"fields ({fields}) that map to this model field. "
|
||||||
"Either remove the extra fields, or override "
|
"Either remove the extra fields, or override "
|
||||||
"`Meta.validators` with a `UniqueTogetherValidator` "
|
"`Meta.validators` with a `UniqueTogetherValidator` "
|
||||||
"using the desired field names."
|
"using the desired field names."
|
||||||
.format(
|
.format(
|
||||||
model=self.Meta.model.__name__,
|
model=self.Meta.model.__name__,
|
||||||
serializer=self.__class__.__name__,
|
serializer=self.__class__.__name__,
|
||||||
field=source,
|
field=source,
|
||||||
fields=', '.join(source_map[source]),
|
fields=', '.join(source_map[source]),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
field_names = tuple(source_map[f][0] for f in unique_together)
|
|
||||||
validator = UniqueTogetherValidator(
|
|
||||||
queryset=parent_class._default_manager,
|
|
||||||
fields=field_names
|
|
||||||
)
|
)
|
||||||
validators.append(validator)
|
|
||||||
|
field_names = tuple(source_map[f][0] for f in unique_together)
|
||||||
|
validator = UniqueTogetherValidator(
|
||||||
|
queryset=queryset,
|
||||||
|
fields=field_names
|
||||||
|
)
|
||||||
|
validators.append(validator)
|
||||||
return validators
|
return validators
|
||||||
|
|
||||||
def get_unique_for_date_validators(self):
|
def get_unique_for_date_validators(self):
|
||||||
|
|
|
@ -62,6 +62,29 @@ def get_detail_view_name(model):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_unique_validators(field_name, model_field):
|
||||||
|
"""
|
||||||
|
Returns a list of UniqueValidators that should be applied to the field.
|
||||||
|
"""
|
||||||
|
field_set = set([field_name])
|
||||||
|
conditions = {
|
||||||
|
c.condition
|
||||||
|
for c in model_field.model._meta.constraints
|
||||||
|
if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set
|
||||||
|
}
|
||||||
|
if getattr(model_field, 'unique', False):
|
||||||
|
conditions.add(None)
|
||||||
|
if not conditions:
|
||||||
|
return
|
||||||
|
unique_error_message = get_unique_error_message(model_field)
|
||||||
|
queryset = model_field.model._default_manager
|
||||||
|
for condition in conditions:
|
||||||
|
yield UniqueValidator(
|
||||||
|
queryset=queryset if condition is None else queryset.filter(condition),
|
||||||
|
message=unique_error_message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_field_kwargs(field_name, model_field):
|
def get_field_kwargs(field_name, model_field):
|
||||||
"""
|
"""
|
||||||
Creates a default instance of a basic non-relational field.
|
Creates a default instance of a basic non-relational field.
|
||||||
|
@ -216,11 +239,7 @@ def get_field_kwargs(field_name, model_field):
|
||||||
if not isinstance(validator, validators.MinLengthValidator)
|
if not isinstance(validator, validators.MinLengthValidator)
|
||||||
]
|
]
|
||||||
|
|
||||||
if getattr(model_field, 'unique', False):
|
validator_kwarg += get_unique_validators(field_name, model_field)
|
||||||
validator = UniqueValidator(
|
|
||||||
queryset=model_field.model._default_manager,
|
|
||||||
message=get_unique_error_message(model_field))
|
|
||||||
validator_kwarg.append(validator)
|
|
||||||
|
|
||||||
if validator_kwarg:
|
if validator_kwarg:
|
||||||
kwargs['validators'] = validator_kwarg
|
kwargs['validators'] = validator_kwarg
|
||||||
|
|
|
@ -464,6 +464,106 @@ class TestUniquenessTogetherValidation(TestCase):
|
||||||
assert queryset.called_with == {'race_name': 'bar', 'position': 1}
|
assert queryset.called_with == {'race_name': 'bar', 'position': 1}
|
||||||
|
|
||||||
|
|
||||||
|
class UniqueConstraintModel(models.Model):
|
||||||
|
race_name = models.CharField(max_length=100)
|
||||||
|
position = models.IntegerField()
|
||||||
|
global_id = models.IntegerField()
|
||||||
|
fancy_conditions = models.IntegerField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.UniqueConstraint(
|
||||||
|
name="unique_constraint_model_global_id_uniq",
|
||||||
|
fields=('global_id',),
|
||||||
|
),
|
||||||
|
models.UniqueConstraint(
|
||||||
|
name="unique_constraint_model_fancy_1_uniq",
|
||||||
|
fields=('fancy_conditions',),
|
||||||
|
condition=models.Q(global_id__lte=1)
|
||||||
|
),
|
||||||
|
models.UniqueConstraint(
|
||||||
|
name="unique_constraint_model_fancy_3_uniq",
|
||||||
|
fields=('fancy_conditions',),
|
||||||
|
condition=models.Q(global_id__gte=3)
|
||||||
|
),
|
||||||
|
models.UniqueConstraint(
|
||||||
|
name="unique_constraint_model_together_uniq",
|
||||||
|
fields=('race_name', 'position'),
|
||||||
|
condition=models.Q(race_name='example'),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class UniqueConstraintSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = UniqueConstraintModel
|
||||||
|
fields = '__all__'
|
||||||
|
|
||||||
|
|
||||||
|
class TestUniqueConstraintValidation(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.instance = UniqueConstraintModel.objects.create(
|
||||||
|
race_name='example',
|
||||||
|
position=1,
|
||||||
|
global_id=1
|
||||||
|
)
|
||||||
|
UniqueConstraintModel.objects.create(
|
||||||
|
race_name='example',
|
||||||
|
position=2,
|
||||||
|
global_id=2
|
||||||
|
)
|
||||||
|
UniqueConstraintModel.objects.create(
|
||||||
|
race_name='other',
|
||||||
|
position=1,
|
||||||
|
global_id=3
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
serializer = UniqueConstraintSerializer()
|
||||||
|
# the order of validators isn't deterministic so delete
|
||||||
|
# fancy_conditions field that has two of them
|
||||||
|
del serializer.fields['fancy_conditions']
|
||||||
|
expected = dedent("""
|
||||||
|
UniqueConstraintSerializer():
|
||||||
|
id = IntegerField(label='ID', read_only=True)
|
||||||
|
race_name = CharField(max_length=100, required=True)
|
||||||
|
position = IntegerField(required=True)
|
||||||
|
global_id = IntegerField(validators=[<UniqueValidator(queryset=UniqueConstraintModel.objects.all())>])
|
||||||
|
class Meta:
|
||||||
|
validators = [<UniqueTogetherValidator(queryset=<QuerySet [<UniqueConstraintModel: UniqueConstraintModel object (1)>, <UniqueConstraintModel: UniqueConstraintModel object (2)>]>, fields=('race_name', 'position'))>]
|
||||||
|
""")
|
||||||
|
assert repr(serializer) == expected
|
||||||
|
|
||||||
|
def test_unique_together_field(self):
|
||||||
|
"""
|
||||||
|
UniqueConstraint fields and condition attributes must be passed
|
||||||
|
to UniqueTogetherValidator as fields and queryset
|
||||||
|
"""
|
||||||
|
serializer = UniqueConstraintSerializer()
|
||||||
|
assert len(serializer.validators) == 1
|
||||||
|
validator = serializer.validators[0]
|
||||||
|
assert validator.fields == ('race_name', 'position')
|
||||||
|
assert set(validator.queryset.values_list(flat=True)) == set(
|
||||||
|
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_single_field_uniq_validators(self):
|
||||||
|
"""
|
||||||
|
UniqueConstraint with single field must be transformed into
|
||||||
|
field's UniqueValidator
|
||||||
|
"""
|
||||||
|
serializer = UniqueConstraintSerializer()
|
||||||
|
assert len(serializer.validators) == 1
|
||||||
|
validators = serializer.fields['global_id'].validators
|
||||||
|
assert len(validators) == 1
|
||||||
|
assert validators[0].queryset == UniqueConstraintModel.objects
|
||||||
|
|
||||||
|
validators = serializer.fields['fancy_conditions'].validators
|
||||||
|
assert len(validators) == 2
|
||||||
|
ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators}
|
||||||
|
assert ids_in_qs == {frozenset([1]), frozenset([3])}
|
||||||
|
|
||||||
|
|
||||||
# Tests for `UniqueForDateValidator`
|
# Tests for `UniqueForDateValidator`
|
||||||
# ----------------------------------
|
# ----------------------------------
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user