mirror of
https://github.com/encode/django-rest-framework.git
synced 2026-02-17 20:50:35 +03:00
Merge fa13a1e3a6 into e45518a132
This commit is contained in:
commit
e7f5ba00f0
|
|
@ -1435,20 +1435,26 @@ class ModelSerializer(Serializer):
|
|||
|
||||
def get_unique_together_constraints(self, model):
|
||||
"""
|
||||
Returns iterator of (fields, queryset, condition_fields, condition),
|
||||
Returns iterator of (fields, queryset, condition_fields, condition, nulls_distinct),
|
||||
each entry describes an unique together constraint on `fields` in `queryset`
|
||||
with respect of constraint's `condition`.
|
||||
with respect of constraint's `condition` and `nulls_distinct` option.
|
||||
"""
|
||||
for parent_class in [model] + list(model._meta.parents):
|
||||
for unique_together in parent_class._meta.unique_together:
|
||||
yield unique_together, model._default_manager, [], None
|
||||
yield unique_together, model._default_manager, [], None, None
|
||||
for constraint in parent_class._meta.constraints:
|
||||
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
|
||||
if constraint.condition is None:
|
||||
condition_fields = []
|
||||
else:
|
||||
condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
|
||||
yield (constraint.fields, model._default_manager, condition_fields, constraint.condition)
|
||||
yield (
|
||||
constraint.fields,
|
||||
model._default_manager,
|
||||
condition_fields,
|
||||
constraint.condition,
|
||||
getattr(constraint, 'nulls_distinct', None),
|
||||
)
|
||||
|
||||
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
|
||||
"""
|
||||
|
|
@ -1481,7 +1487,7 @@ class ModelSerializer(Serializer):
|
|||
|
||||
# Include each of the `unique_together` and `UniqueConstraint` field names,
|
||||
# so long as all the field names are included on the serializer.
|
||||
for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
|
||||
for unique_together_list, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(model):
|
||||
unique_together_list_and_condition_fields = set(unique_together_list) | set(condition_fields)
|
||||
if model_fields_names.issuperset(unique_together_list_and_condition_fields):
|
||||
unique_constraint_names |= unique_together_list_and_condition_fields
|
||||
|
|
@ -1624,7 +1630,7 @@ class ModelSerializer(Serializer):
|
|||
# Note that we make sure to check `unique_together` both on the
|
||||
# base model class, but also on any parent classes.
|
||||
validators = []
|
||||
for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
|
||||
for unique_together, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(self.Meta.model):
|
||||
# Skip if serializer does not map to all unique together sources
|
||||
unique_together_and_condition_fields = set(unique_together) | set(condition_fields)
|
||||
if not set(source_map).issuperset(unique_together_and_condition_fields):
|
||||
|
|
@ -1658,6 +1664,7 @@ class ModelSerializer(Serializer):
|
|||
condition=condition,
|
||||
message=violation_error_message,
|
||||
code=getattr(constraint, 'violation_error_code', None),
|
||||
nulls_distinct=nulls_distinct,
|
||||
)
|
||||
validators.append(validator)
|
||||
return validators
|
||||
|
|
|
|||
|
|
@ -113,13 +113,14 @@ class UniqueTogetherValidator:
|
|||
requires_context = True
|
||||
code = 'unique'
|
||||
|
||||
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None):
|
||||
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None, nulls_distinct=None):
|
||||
self.queryset = queryset
|
||||
self.fields = fields
|
||||
self.message = message or self.message
|
||||
self.condition_fields = [] if condition_fields is None else condition_fields
|
||||
self.condition = condition
|
||||
self.code = code or self.code
|
||||
self.nulls_distinct = nulls_distinct
|
||||
|
||||
def enforce_required_fields(self, attrs, serializer):
|
||||
"""
|
||||
|
|
@ -197,17 +198,21 @@ class UniqueTogetherValidator:
|
|||
else getattr(serializer.instance, source)
|
||||
for source in condition_sources
|
||||
}
|
||||
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
|
||||
field_names = ', '.join(self.fields)
|
||||
message = self.message.format(field_names=field_names)
|
||||
raise ValidationError(message, code=self.code)
|
||||
if checked_values:
|
||||
# Skip validation for None values unless nulls_distinct is False
|
||||
if self.nulls_distinct is not False and None in checked_values:
|
||||
return
|
||||
if qs_exists_with_condition(queryset, self.condition, condition_kwargs):
|
||||
field_names = ', '.join(self.fields)
|
||||
message = self.message.format(field_names=field_names)
|
||||
raise ValidationError(message, code=self.code)
|
||||
|
||||
def __repr__(self):
|
||||
return '<{}({})>'.format(
|
||||
self.__class__.__name__,
|
||||
', '.join(
|
||||
f'{attr}={smart_repr(getattr(self, attr))}'
|
||||
for attr in ('queryset', 'fields', 'condition')
|
||||
for attr in ('queryset', 'fields', 'condition', 'nulls_distinct')
|
||||
if getattr(self, attr) is not None)
|
||||
)
|
||||
|
||||
|
|
@ -220,6 +225,7 @@ class UniqueTogetherValidator:
|
|||
and self.queryset == other.queryset
|
||||
and self.fields == other.fields
|
||||
and self.code == other.code
|
||||
and self.nulls_distinct == other.nulls_distinct
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -616,6 +616,23 @@ class UniqueConstraintNullableModel(models.Model):
|
|||
]
|
||||
|
||||
|
||||
# Only define nulls_distinct model for Django 5.0+
|
||||
if django_version >= (5, 0):
|
||||
class UniqueConstraintNullsDistinctModel(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
code = models.CharField(max_length=100, null=True)
|
||||
category = models.CharField(max_length=100, null=True)
|
||||
|
||||
class Meta:
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
name='unique_code_category_nulls_not_distinct',
|
||||
fields=('code', 'category'),
|
||||
nulls_distinct=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UniqueConstraintCustomMessageCodeModel(models.Model):
|
||||
username = models.CharField(max_length=32)
|
||||
company_id = models.IntegerField()
|
||||
|
|
@ -1063,3 +1080,118 @@ class ValidatorsTests(TestCase):
|
|||
assert validator == validator2
|
||||
validator2.date_field = "bar2"
|
||||
assert validator != validator2
|
||||
|
||||
|
||||
# Tests for `nulls_distinct` option (Django 5.0+)
|
||||
# -----------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(
|
||||
django_version < (5, 0),
|
||||
reason="nulls_distinct requires Django 5.0+"
|
||||
)
|
||||
class TestUniqueConstraintNullsDistinct(TestCase):
|
||||
"""
|
||||
Tests for UniqueConstraint with nulls_distinct=False option.
|
||||
When nulls_distinct=False, NULL values should be treated as equal
|
||||
for uniqueness validation.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
from tests.test_validators import UniqueConstraintNullsDistinctModel
|
||||
|
||||
class UniqueConstraintNullsDistinctSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = UniqueConstraintNullsDistinctModel
|
||||
fields = ('name', 'code', 'category')
|
||||
|
||||
self.serializer_class = UniqueConstraintNullsDistinctSerializer
|
||||
|
||||
def test_nulls_distinct_false_validates_null_as_duplicate(self):
|
||||
"""
|
||||
When nulls_distinct=False, creating a second record with NULL values
|
||||
in the constrained fields should fail validation.
|
||||
"""
|
||||
from tests.test_validators import UniqueConstraintNullsDistinctModel
|
||||
|
||||
# Create first record with NULL values
|
||||
UniqueConstraintNullsDistinctModel.objects.create(
|
||||
name='First',
|
||||
code=None,
|
||||
category=None
|
||||
)
|
||||
|
||||
# Attempt to create second record with same NULL values
|
||||
serializer = self.serializer_class(data={
|
||||
'name': 'Second',
|
||||
'code': None,
|
||||
'category': None
|
||||
})
|
||||
|
||||
# Should fail validation because nulls_distinct=False
|
||||
assert not serializer.is_valid()
|
||||
|
||||
def test_nulls_distinct_false_allows_different_non_null_values(self):
|
||||
"""
|
||||
Non-NULL values should still work normally with uniqueness validation.
|
||||
"""
|
||||
from tests.test_validators import UniqueConstraintNullsDistinctModel
|
||||
|
||||
# Create first record with non-NULL values
|
||||
UniqueConstraintNullsDistinctModel.objects.create(
|
||||
name='First',
|
||||
code='A',
|
||||
category='X'
|
||||
)
|
||||
|
||||
# Create second record with different values - should pass
|
||||
serializer = self.serializer_class(data={
|
||||
'name': 'Second',
|
||||
'code': 'B',
|
||||
'category': 'Y'
|
||||
})
|
||||
assert serializer.is_valid(), serializer.errors
|
||||
|
||||
def test_nulls_distinct_false_rejects_duplicate_non_null_values(self):
|
||||
"""
|
||||
Duplicate non-NULL values should still fail validation.
|
||||
"""
|
||||
from tests.test_validators import UniqueConstraintNullsDistinctModel
|
||||
|
||||
# Create first record
|
||||
UniqueConstraintNullsDistinctModel.objects.create(
|
||||
name='First',
|
||||
code='A',
|
||||
category='X'
|
||||
)
|
||||
|
||||
# Attempt to create duplicate - should fail
|
||||
serializer = self.serializer_class(data={
|
||||
'name': 'Second',
|
||||
'code': 'A',
|
||||
'category': 'X'
|
||||
})
|
||||
assert not serializer.is_valid()
|
||||
|
||||
def test_unique_together_validator_nulls_distinct_equality(self):
|
||||
"""
|
||||
Test that UniqueTogetherValidator equality considers nulls_distinct.
|
||||
"""
|
||||
mock_queryset = MagicMock()
|
||||
validator1 = UniqueTogetherValidator(
|
||||
queryset=mock_queryset,
|
||||
fields=('a', 'b'),
|
||||
nulls_distinct=False
|
||||
)
|
||||
validator2 = UniqueTogetherValidator(
|
||||
queryset=mock_queryset,
|
||||
fields=('a', 'b'),
|
||||
nulls_distinct=False
|
||||
)
|
||||
validator3 = UniqueTogetherValidator(
|
||||
queryset=mock_queryset,
|
||||
fields=('a', 'b'),
|
||||
nulls_distinct=True
|
||||
)
|
||||
|
||||
assert validator1 == validator2
|
||||
assert validator1 != validator3
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user