mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-10-31 07:57:55 +03:00 
			
		
		
		
	Fix unique together validator doesn't respect condition's fields (#9360)
This commit is contained in:
		
							parent
							
								
									f30c0e2eed
								
							
						
					
					
						commit
						17e95604f5
					
				|  | @ -3,6 +3,9 @@ The `compat` module provides support for backwards compatibility with older | |||
| versions of Django/Python, and compatibility wrappers around optional packages. | ||||
| """ | ||||
| import django | ||||
| from django.db import models | ||||
| from django.db.models.constants import LOOKUP_SEP | ||||
| from django.db.models.sql.query import Node | ||||
| from django.views.generic import View | ||||
| 
 | ||||
| 
 | ||||
|  | @ -157,6 +160,10 @@ if django.VERSION >= (5, 1): | |||
|     #       1) the list of validators and 2) the error message. Starting from | ||||
|     #       Django 5.1 ip_address_validators only returns the list of validators | ||||
|     from django.core.validators import ip_address_validators | ||||
| 
 | ||||
|     def get_referenced_base_fields_from_q(q): | ||||
|         return q.referenced_base_fields | ||||
| 
 | ||||
| else: | ||||
|     # Django <= 5.1: create a compatibility shim for ip_address_validators | ||||
|     from django.core.validators import \ | ||||
|  | @ -165,6 +172,35 @@ else: | |||
|     def ip_address_validators(protocol, unpack_ipv4): | ||||
|         return _ip_address_validators(protocol, unpack_ipv4)[0] | ||||
| 
 | ||||
|     # Django < 5.1: create a compatibility shim for Q.referenced_base_fields | ||||
|     # https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179 | ||||
|     def _get_paths_from_expression(expr): | ||||
|         if isinstance(expr, models.F): | ||||
|             yield expr.name | ||||
|         elif hasattr(expr, 'flatten'): | ||||
|             for child in expr.flatten(): | ||||
|                 if isinstance(child, models.F): | ||||
|                     yield child.name | ||||
|                 elif isinstance(child, models.Q): | ||||
|                     yield from _get_children_from_q(child) | ||||
| 
 | ||||
|     def _get_children_from_q(q): | ||||
|         for child in q.children: | ||||
|             if isinstance(child, Node): | ||||
|                 yield from _get_children_from_q(child) | ||||
|             elif isinstance(child, tuple): | ||||
|                 lhs, rhs = child | ||||
|                 yield lhs | ||||
|                 if hasattr(rhs, 'resolve_expression'): | ||||
|                     yield from _get_paths_from_expression(rhs) | ||||
|             elif hasattr(child, 'resolve_expression'): | ||||
|                 yield from _get_paths_from_expression(child) | ||||
| 
 | ||||
|     def get_referenced_base_fields_from_q(q): | ||||
|         return { | ||||
|             child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q) | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
| # `separators` argument to `json.dumps()` differs between 2.x and 3.x | ||||
| # See: https://bugs.python.org/issue22767 | ||||
|  |  | |||
|  | @ -26,7 +26,9 @@ from django.utils import timezone | |||
| from django.utils.functional import cached_property | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| 
 | ||||
| from rest_framework.compat import postgres_fields | ||||
| from rest_framework.compat import ( | ||||
|     get_referenced_base_fields_from_q, postgres_fields | ||||
| ) | ||||
| from rest_framework.exceptions import ErrorDetail, ValidationError | ||||
| from rest_framework.fields import get_error_detail | ||||
| from rest_framework.settings import api_settings | ||||
|  | @ -1425,20 +1427,20 @@ class ModelSerializer(Serializer): | |||
| 
 | ||||
|     def get_unique_together_constraints(self, model): | ||||
|         """ | ||||
|         Returns iterator of (fields, queryset), each entry describes an unique together | ||||
|         constraint on `fields` in `queryset`. | ||||
|         Returns iterator of (fields, queryset, condition_fields, condition), | ||||
|         each entry describes an unique together constraint on `fields` in `queryset` | ||||
|         with respect of constraint's `condition`. | ||||
|         """ | ||||
|         for parent_class in [model] + list(model._meta.parents): | ||||
|             for unique_together in parent_class._meta.unique_together: | ||||
|                 yield unique_together, model._default_manager | ||||
|                 yield unique_together, model._default_manager, [], None | ||||
|             for constraint in parent_class._meta.constraints: | ||||
|                 if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: | ||||
|                     yield ( | ||||
|                         constraint.fields, | ||||
|                         model._default_manager | ||||
|                         if constraint.condition is None | ||||
|                         else model._default_manager.filter(constraint.condition) | ||||
|                     ) | ||||
|                     if constraint.condition is None: | ||||
|                         condition_fields = [] | ||||
|                     else: | ||||
|                         condition_fields = list(get_referenced_base_fields_from_q(constraint.condition)) | ||||
|                     yield (constraint.fields, model._default_manager, condition_fields, constraint.condition) | ||||
| 
 | ||||
|     def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): | ||||
|         """ | ||||
|  | @ -1470,9 +1472,10 @@ class ModelSerializer(Serializer): | |||
| 
 | ||||
|         # Include each of the `unique_together` and `UniqueConstraint` field names, | ||||
|         # so long as all the field names are included on the serializer. | ||||
|         for unique_together_list, queryset in self.get_unique_together_constraints(model): | ||||
|             if set(field_names).issuperset(unique_together_list): | ||||
|                 unique_constraint_names |= set(unique_together_list) | ||||
|         for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model): | ||||
|             unique_together_list_and_condition_fields = set(unique_together_list) | set(condition_fields) | ||||
|             if set(field_names).issuperset(unique_together_list_and_condition_fields): | ||||
|                 unique_constraint_names |= unique_together_list_and_condition_fields | ||||
| 
 | ||||
|         # Now we have all the field names that have uniqueness constraints | ||||
|         # applied, we can add the extra 'required=...' or 'default=...' | ||||
|  | @ -1594,12 +1597,13 @@ class ModelSerializer(Serializer): | |||
|         # Note that we make sure to check `unique_together` both on the | ||||
|         # base model class, but also on any parent classes. | ||||
|         validators = [] | ||||
|         for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model): | ||||
|         for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model): | ||||
|             # Skip if serializer does not map to all unique together sources | ||||
|             if not set(source_map).issuperset(unique_together): | ||||
|             unique_together_and_condition_fields = set(unique_together) | set(condition_fields) | ||||
|             if not set(source_map).issuperset(unique_together_and_condition_fields): | ||||
|                 continue | ||||
| 
 | ||||
|             for source in unique_together: | ||||
|             for source in unique_together_and_condition_fields: | ||||
|                 assert len(source_map[source]) == 1, ( | ||||
|                     "Unable to create `UniqueTogetherValidator` for " | ||||
|                     "`{model}.{field}` as `{serializer}` has multiple " | ||||
|  | @ -1618,7 +1622,9 @@ class ModelSerializer(Serializer): | |||
|             field_names = tuple(source_map[f][0] for f in unique_together) | ||||
|             validator = UniqueTogetherValidator( | ||||
|                 queryset=queryset, | ||||
|                 fields=field_names | ||||
|                 fields=field_names, | ||||
|                 condition_fields=tuple(source_map[f][0] for f in condition_fields), | ||||
|                 condition=condition, | ||||
|             ) | ||||
|             validators.append(validator) | ||||
|         return validators | ||||
|  |  | |||
|  | @ -6,7 +6,9 @@ This gives us better separation of concerns, allows us to use single-step | |||
| object creation, and makes it possible to switch between using the implicit | ||||
| `ModelSerializer` class and an equivalent explicit `Serializer` class. | ||||
| """ | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db import DataError | ||||
| from django.db.models import Exists | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| 
 | ||||
| from rest_framework.exceptions import ValidationError | ||||
|  | @ -23,6 +25,17 @@ def qs_exists(queryset): | |||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def qs_exists_with_condition(queryset, condition, against): | ||||
|     if condition is None: | ||||
|         return qs_exists(queryset) | ||||
|     try: | ||||
|         # use the same query as UniqueConstraint.validate | ||||
|         # https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672 | ||||
|         return (condition & Exists(queryset.filter(condition))).check(against) | ||||
|     except (TypeError, ValueError, DataError, FieldError): | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def qs_filter(queryset, **kwargs): | ||||
|     try: | ||||
|         return queryset.filter(**kwargs) | ||||
|  | @ -99,10 +112,12 @@ class UniqueTogetherValidator: | |||
|     missing_message = _('This field is required.') | ||||
|     requires_context = True | ||||
| 
 | ||||
|     def __init__(self, queryset, fields, message=None): | ||||
|     def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None): | ||||
|         self.queryset = queryset | ||||
|         self.fields = fields | ||||
|         self.message = message or self.message | ||||
|         self.condition_fields = [] if condition_fields is None else condition_fields | ||||
|         self.condition = condition | ||||
| 
 | ||||
|     def enforce_required_fields(self, attrs, serializer): | ||||
|         """ | ||||
|  | @ -114,7 +129,7 @@ class UniqueTogetherValidator: | |||
| 
 | ||||
|         missing_items = { | ||||
|             field_name: self.missing_message | ||||
|             for field_name in self.fields | ||||
|             for field_name in (*self.fields, *self.condition_fields) | ||||
|             if serializer.fields[field_name].source not in attrs | ||||
|         } | ||||
|         if missing_items: | ||||
|  | @ -173,16 +188,19 @@ class UniqueTogetherValidator: | |||
|                 if attrs[field_name] != getattr(serializer.instance, field_name) | ||||
|             ] | ||||
| 
 | ||||
|         if checked_values and None not in checked_values and qs_exists(queryset): | ||||
|         condition_kwargs = {source: attrs[source] for source in self.condition_fields} | ||||
|         if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs): | ||||
|             field_names = ', '.join(self.fields) | ||||
|             message = self.message.format(field_names=field_names) | ||||
|             raise ValidationError(message, code='unique') | ||||
| 
 | ||||
|     def __repr__(self): | ||||
|         return '<%s(queryset=%s, fields=%s)>' % ( | ||||
|         return '<{}({})>'.format( | ||||
|             self.__class__.__name__, | ||||
|             smart_repr(self.queryset), | ||||
|             smart_repr(self.fields) | ||||
|             ', '.join( | ||||
|                 f'{attr}={smart_repr(getattr(self, attr))}' | ||||
|                 for attr in ('queryset', 'fields', 'condition') | ||||
|                 if getattr(self, attr) is not None) | ||||
|         ) | ||||
| 
 | ||||
|     def __eq__(self, other): | ||||
|  |  | |||
|  | @ -521,7 +521,7 @@ class UniqueConstraintModel(models.Model): | |||
|     race_name = models.CharField(max_length=100) | ||||
|     position = models.IntegerField() | ||||
|     global_id = models.IntegerField() | ||||
|     fancy_conditions = models.IntegerField(null=True) | ||||
|     fancy_conditions = models.IntegerField() | ||||
| 
 | ||||
|     class Meta: | ||||
|         constraints = [ | ||||
|  | @ -543,7 +543,12 @@ class UniqueConstraintModel(models.Model): | |||
|                 name="unique_constraint_model_together_uniq", | ||||
|                 fields=('race_name', 'position'), | ||||
|                 condition=models.Q(race_name='example'), | ||||
|             ) | ||||
|             ), | ||||
|             models.UniqueConstraint( | ||||
|                 name='unique_constraint_model_together_uniq2', | ||||
|                 fields=('race_name', 'position'), | ||||
|                 condition=models.Q(fancy_conditions__gte=10), | ||||
|             ), | ||||
|         ] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -576,17 +581,20 @@ class TestUniqueConstraintValidation(TestCase): | |||
|         self.instance = UniqueConstraintModel.objects.create( | ||||
|             race_name='example', | ||||
|             position=1, | ||||
|             global_id=1 | ||||
|             global_id=1, | ||||
|             fancy_conditions=1 | ||||
|         ) | ||||
|         UniqueConstraintModel.objects.create( | ||||
|             race_name='example', | ||||
|             position=2, | ||||
|             global_id=2 | ||||
|             global_id=2, | ||||
|             fancy_conditions=1 | ||||
|         ) | ||||
|         UniqueConstraintModel.objects.create( | ||||
|             race_name='other', | ||||
|             position=1, | ||||
|             global_id=3 | ||||
|             global_id=3, | ||||
|             fancy_conditions=1 | ||||
|         ) | ||||
| 
 | ||||
|     def test_repr(self): | ||||
|  | @ -601,22 +609,55 @@ class TestUniqueConstraintValidation(TestCase): | |||
|                 position = IntegerField\(.*required=True\) | ||||
|                 global_id = IntegerField\(.*validators=\[<UniqueValidator\(queryset=UniqueConstraintModel.objects.all\(\)\)>\]\) | ||||
|                 class Meta: | ||||
|                     validators = \[<UniqueTogetherValidator\(queryset=<QuerySet \[<UniqueConstraintModel: UniqueConstraintModel object \(1\)>, <UniqueConstraintModel: UniqueConstraintModel object \(2\)>\]>, fields=\('race_name', 'position'\)\)>\] | ||||
|                     validators = \[<UniqueTogetherValidator\(queryset=UniqueConstraintModel.objects.all\(\), fields=\('race_name', 'position'\), condition=<Q: \(AND: \('race_name', 'example'\)\)>\)>\] | ||||
|         """) | ||||
|         assert re.search(expected, repr(serializer)) is not None | ||||
| 
 | ||||
|     def test_unique_together_field(self): | ||||
|     def test_unique_together_condition(self): | ||||
|         """ | ||||
|         UniqueConstraint fields and condition attributes must be passed | ||||
|         to UniqueTogetherValidator as fields and queryset | ||||
|         Fields used in UniqueConstraint's condition must be included | ||||
|         into queryset existence check | ||||
|         """ | ||||
|         serializer = UniqueConstraintSerializer() | ||||
|         assert len(serializer.validators) == 1 | ||||
|         validator = serializer.validators[0] | ||||
|         assert validator.fields == ('race_name', 'position') | ||||
|         assert set(validator.queryset.values_list(flat=True)) == set( | ||||
|             UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True) | ||||
|         UniqueConstraintModel.objects.create( | ||||
|             race_name='condition', | ||||
|             position=1, | ||||
|             global_id=10, | ||||
|             fancy_conditions=10, | ||||
|         ) | ||||
|         serializer = UniqueConstraintSerializer(data={ | ||||
|             'race_name': 'condition', | ||||
|             'position': 1, | ||||
|             'global_id': 11, | ||||
|             'fancy_conditions': 9, | ||||
|         }) | ||||
|         assert serializer.is_valid() | ||||
|         serializer = UniqueConstraintSerializer(data={ | ||||
|             'race_name': 'condition', | ||||
|             'position': 1, | ||||
|             'global_id': 11, | ||||
|             'fancy_conditions': 11, | ||||
|         }) | ||||
|         assert not serializer.is_valid() | ||||
| 
 | ||||
|     def test_unique_together_condition_fields_required(self): | ||||
|         """ | ||||
|         Fields used in UniqueConstraint's condition must be present in serializer | ||||
|         """ | ||||
|         serializer = UniqueConstraintSerializer(data={ | ||||
|             'race_name': 'condition', | ||||
|             'position': 1, | ||||
|             'global_id': 11, | ||||
|         }) | ||||
|         assert not serializer.is_valid() | ||||
|         assert serializer.errors == {'fancy_conditions': ['This field is required.']} | ||||
| 
 | ||||
|         class NoFieldsSerializer(serializers.ModelSerializer): | ||||
|             class Meta: | ||||
|                 model = UniqueConstraintModel | ||||
|                 fields = ('race_name', 'position', 'global_id') | ||||
| 
 | ||||
|         serializer = NoFieldsSerializer() | ||||
|         assert len(serializer.validators) == 1 | ||||
| 
 | ||||
|     def test_single_field_uniq_validators(self): | ||||
|         """ | ||||
|  | @ -625,9 +666,8 @@ class TestUniqueConstraintValidation(TestCase): | |||
|         """ | ||||
|         # Django 5 includes Max and Min values validators for IntergerField | ||||
|         extra_validators_qty = 2 if django_version[0] >= 5 else 0 | ||||
|         # | ||||
|         serializer = UniqueConstraintSerializer() | ||||
|         assert len(serializer.validators) == 1 | ||||
|         assert len(serializer.validators) == 2 | ||||
|         validators = serializer.fields['global_id'].validators | ||||
|         assert len(validators) == 1 + extra_validators_qty | ||||
|         assert validators[0].queryset == UniqueConstraintModel.objects | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user