mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-10-25 13:11:26 +03:00 
			
		
		
		
	Drop set_context() (#7062)
				
					
				
			* Do not persist the context in validators Fixes encode/django-rest-framework#5760 * Drop set_context() in favour of 'requires_context = True'
This commit is contained in:
		
							parent
							
								
									9325c3f654
								
							
						
					
					
						commit
						070cff5a03
					
				|  | @ -50,7 +50,19 @@ If set, this gives the default value that will be used for the field if no input | ||||||
| 
 | 
 | ||||||
| The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned. | The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned. | ||||||
| 
 | 
 | ||||||
| May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context). | May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `requires_context = True` attribute, then the serializer field will be passed as an argument. | ||||||
|  | 
 | ||||||
|  | For example: | ||||||
|  | 
 | ||||||
|  |     class CurrentUserDefault: | ||||||
|  |         """ | ||||||
|  |         May be applied as a `default=...` value on a serializer field. | ||||||
|  |         Returns the current user. | ||||||
|  |         """ | ||||||
|  |         requires_context = True | ||||||
|  | 
 | ||||||
|  |         def __call__(self, serializer_field): | ||||||
|  |             return serializer_field.context['request'].user | ||||||
| 
 | 
 | ||||||
| When serializing the instance, default will be used if the object attribute or dictionary key is not present in the instance. | When serializing the instance, default will be used if the object attribute or dictionary key is not present in the instance. | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -291,13 +291,17 @@ To write a class-based validator, use the `__call__` method. Class-based validat | ||||||
|                 message = 'This field must be a multiple of %d.' % self.base |                 message = 'This field must be a multiple of %d.' % self.base | ||||||
|                 raise serializers.ValidationError(message) |                 raise serializers.ValidationError(message) | ||||||
| 
 | 
 | ||||||
| #### Using `set_context()` | #### Accessing the context | ||||||
| 
 | 
 | ||||||
| In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator. | In some advanced cases you might want a validator to be passed the serializer | ||||||
|  | field it is being used with as additional context. You can do so by setting | ||||||
|  | a `requires_context = True` attribute on the validator. The `__call__` method | ||||||
|  | will then be called with the `serializer_field` | ||||||
|  | or `serializer` as an additional argument. | ||||||
| 
 | 
 | ||||||
|     def set_context(self, serializer_field): |     requires_context = True | ||||||
|         # Determine if this is an update or a create operation. | 
 | ||||||
|         # In `__call__` we can then use that information to modify the validation behavior. |     def __call__(self, value, serializer_field): | ||||||
|         self.is_update = serializer_field.parent.instance is not None |         ... | ||||||
| 
 | 
 | ||||||
| [cite]: https://docs.djangoproject.com/en/stable/ref/validators/ | [cite]: https://docs.djangoproject.com/en/stable/ref/validators/ | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import functools | ||||||
| import inspect | import inspect | ||||||
| import re | import re | ||||||
| import uuid | import uuid | ||||||
|  | import warnings | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from collections.abc import Mapping | from collections.abc import Mapping | ||||||
| 
 | 
 | ||||||
|  | @ -249,18 +250,29 @@ class CreateOnlyDefault: | ||||||
|     for create operations, but that do not return any value for update |     for create operations, but that do not return any value for update | ||||||
|     operations. |     operations. | ||||||
|     """ |     """ | ||||||
|  |     requires_context = True | ||||||
|  | 
 | ||||||
|     def __init__(self, default): |     def __init__(self, default): | ||||||
|         self.default = default |         self.default = default | ||||||
| 
 | 
 | ||||||
|     def set_context(self, serializer_field): |     def __call__(self, serializer_field): | ||||||
|         self.is_update = serializer_field.parent.instance is not None |         is_update = serializer_field.parent.instance is not None | ||||||
|         if callable(self.default) and hasattr(self.default, 'set_context') and not self.is_update: |         if is_update: | ||||||
|             self.default.set_context(serializer_field) |  | ||||||
| 
 |  | ||||||
|     def __call__(self): |  | ||||||
|         if self.is_update: |  | ||||||
|             raise SkipField() |             raise SkipField() | ||||||
|         if callable(self.default): |         if callable(self.default): | ||||||
|  |             if hasattr(self.default, 'set_context'): | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "Method `set_context` on defaults is deprecated and will " | ||||||
|  |                     "no longer be called starting with 3.12. Instead set " | ||||||
|  |                     "`requires_context = True` on the class, and accept the " | ||||||
|  |                     "context as an additional argument.", | ||||||
|  |                     DeprecationWarning, stacklevel=2 | ||||||
|  |                 ) | ||||||
|  |                 self.default.set_context(self) | ||||||
|  | 
 | ||||||
|  |             if getattr(self.default, 'requires_context', False): | ||||||
|  |                 return self.default(serializer_field) | ||||||
|  |             else: | ||||||
|                 return self.default() |                 return self.default() | ||||||
|         return self.default |         return self.default | ||||||
| 
 | 
 | ||||||
|  | @ -269,11 +281,10 @@ class CreateOnlyDefault: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CurrentUserDefault: | class CurrentUserDefault: | ||||||
|     def set_context(self, serializer_field): |     requires_context = True | ||||||
|         self.user = serializer_field.context['request'].user |  | ||||||
| 
 | 
 | ||||||
|     def __call__(self): |     def __call__(self, serializer_field): | ||||||
|         return self.user |         return serializer_field.context['request'].user | ||||||
| 
 | 
 | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return '%s()' % self.__class__.__name__ |         return '%s()' % self.__class__.__name__ | ||||||
|  | @ -489,8 +500,20 @@ class Field: | ||||||
|             raise SkipField() |             raise SkipField() | ||||||
|         if callable(self.default): |         if callable(self.default): | ||||||
|             if hasattr(self.default, 'set_context'): |             if hasattr(self.default, 'set_context'): | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "Method `set_context` on defaults is deprecated and will " | ||||||
|  |                     "no longer be called starting with 3.12. Instead set " | ||||||
|  |                     "`requires_context = True` on the class, and accept the " | ||||||
|  |                     "context as an additional argument.", | ||||||
|  |                     DeprecationWarning, stacklevel=2 | ||||||
|  |                 ) | ||||||
|                 self.default.set_context(self) |                 self.default.set_context(self) | ||||||
|  | 
 | ||||||
|  |             if getattr(self.default, 'requires_context', False): | ||||||
|  |                 return self.default(self) | ||||||
|  |             else: | ||||||
|                 return self.default() |                 return self.default() | ||||||
|  | 
 | ||||||
|         return self.default |         return self.default | ||||||
| 
 | 
 | ||||||
|     def validate_empty_values(self, data): |     def validate_empty_values(self, data): | ||||||
|  | @ -551,9 +574,19 @@ class Field: | ||||||
|         errors = [] |         errors = [] | ||||||
|         for validator in self.validators: |         for validator in self.validators: | ||||||
|             if hasattr(validator, 'set_context'): |             if hasattr(validator, 'set_context'): | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "Method `set_context` on validators is deprecated and will " | ||||||
|  |                     "no longer be called starting with 3.12. Instead set " | ||||||
|  |                     "`requires_context = True` on the class, and accept the " | ||||||
|  |                     "context as an additional argument.", | ||||||
|  |                     DeprecationWarning, stacklevel=2 | ||||||
|  |                 ) | ||||||
|                 validator.set_context(self) |                 validator.set_context(self) | ||||||
| 
 | 
 | ||||||
|             try: |             try: | ||||||
|  |                 if getattr(validator, 'requires_context', False): | ||||||
|  |                     validator(value, self) | ||||||
|  |                 else: | ||||||
|                     validator(value) |                     validator(value) | ||||||
|             except ValidationError as exc: |             except ValidationError as exc: | ||||||
|                 # If the validation error contains a mapping of fields to |                 # If the validation error contains a mapping of fields to | ||||||
|  |  | ||||||
|  | @ -37,6 +37,7 @@ class UniqueValidator: | ||||||
|     Should be applied to an individual field on the serializer. |     Should be applied to an individual field on the serializer. | ||||||
|     """ |     """ | ||||||
|     message = _('This field must be unique.') |     message = _('This field must be unique.') | ||||||
|  |     requires_context = True | ||||||
| 
 | 
 | ||||||
|     def __init__(self, queryset, message=None, lookup='exact'): |     def __init__(self, queryset, message=None, lookup='exact'): | ||||||
|         self.queryset = queryset |         self.queryset = queryset | ||||||
|  | @ -44,37 +45,32 @@ class UniqueValidator: | ||||||
|         self.message = message or self.message |         self.message = message or self.message | ||||||
|         self.lookup = lookup |         self.lookup = lookup | ||||||
| 
 | 
 | ||||||
|     def set_context(self, serializer_field): |     def filter_queryset(self, value, queryset, field_name): | ||||||
|         """ |  | ||||||
|         This hook is called by the serializer instance, |  | ||||||
|         prior to the validation call being made. |  | ||||||
|         """ |  | ||||||
|         # Determine the underlying model field name. This may not be the |  | ||||||
|         # same as the serializer field name if `source=<>` is set. |  | ||||||
|         self.field_name = serializer_field.source_attrs[-1] |  | ||||||
|         # Determine the existing instance, if this is an update operation. |  | ||||||
|         self.instance = getattr(serializer_field.parent, 'instance', None) |  | ||||||
| 
 |  | ||||||
|     def filter_queryset(self, value, queryset): |  | ||||||
|         """ |         """ | ||||||
|         Filter the queryset to all instances matching the given attribute. |         Filter the queryset to all instances matching the given attribute. | ||||||
|         """ |         """ | ||||||
|         filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value} |         filter_kwargs = {'%s__%s' % (field_name, self.lookup): value} | ||||||
|         return qs_filter(queryset, **filter_kwargs) |         return qs_filter(queryset, **filter_kwargs) | ||||||
| 
 | 
 | ||||||
|     def exclude_current_instance(self, queryset): |     def exclude_current_instance(self, queryset, instance): | ||||||
|         """ |         """ | ||||||
|         If an instance is being updated, then do not include |         If an instance is being updated, then do not include | ||||||
|         that instance itself as a uniqueness conflict. |         that instance itself as a uniqueness conflict. | ||||||
|         """ |         """ | ||||||
|         if self.instance is not None: |         if instance is not None: | ||||||
|             return queryset.exclude(pk=self.instance.pk) |             return queryset.exclude(pk=instance.pk) | ||||||
|         return queryset |         return queryset | ||||||
| 
 | 
 | ||||||
|     def __call__(self, value): |     def __call__(self, value, serializer_field): | ||||||
|  |         # Determine the underlying model field name. This may not be the | ||||||
|  |         # same as the serializer field name if `source=<>` is set. | ||||||
|  |         field_name = serializer_field.source_attrs[-1] | ||||||
|  |         # Determine the existing instance, if this is an update operation. | ||||||
|  |         instance = getattr(serializer_field.parent, 'instance', None) | ||||||
|  | 
 | ||||||
|         queryset = self.queryset |         queryset = self.queryset | ||||||
|         queryset = self.filter_queryset(value, queryset) |         queryset = self.filter_queryset(value, queryset, field_name) | ||||||
|         queryset = self.exclude_current_instance(queryset) |         queryset = self.exclude_current_instance(queryset, instance) | ||||||
|         if qs_exists(queryset): |         if qs_exists(queryset): | ||||||
|             raise ValidationError(self.message, code='unique') |             raise ValidationError(self.message, code='unique') | ||||||
| 
 | 
 | ||||||
|  | @ -93,6 +89,7 @@ class UniqueTogetherValidator: | ||||||
|     """ |     """ | ||||||
|     message = _('The fields {field_names} must make a unique set.') |     message = _('The fields {field_names} must make a unique set.') | ||||||
|     missing_message = _('This field is required.') |     missing_message = _('This field is required.') | ||||||
|  |     requires_context = True | ||||||
| 
 | 
 | ||||||
|     def __init__(self, queryset, fields, message=None): |     def __init__(self, queryset, fields, message=None): | ||||||
|         self.queryset = queryset |         self.queryset = queryset | ||||||
|  | @ -100,20 +97,12 @@ class UniqueTogetherValidator: | ||||||
|         self.serializer_field = None |         self.serializer_field = None | ||||||
|         self.message = message or self.message |         self.message = message or self.message | ||||||
| 
 | 
 | ||||||
|     def set_context(self, serializer): |     def enforce_required_fields(self, attrs, instance): | ||||||
|         """ |  | ||||||
|         This hook is called by the serializer instance, |  | ||||||
|         prior to the validation call being made. |  | ||||||
|         """ |  | ||||||
|         # Determine the existing instance, if this is an update operation. |  | ||||||
|         self.instance = getattr(serializer, 'instance', None) |  | ||||||
| 
 |  | ||||||
|     def enforce_required_fields(self, attrs): |  | ||||||
|         """ |         """ | ||||||
|         The `UniqueTogetherValidator` always forces an implied 'required' |         The `UniqueTogetherValidator` always forces an implied 'required' | ||||||
|         state on the fields it applies to. |         state on the fields it applies to. | ||||||
|         """ |         """ | ||||||
|         if self.instance is not None: |         if instance is not None: | ||||||
|             return |             return | ||||||
| 
 | 
 | ||||||
|         missing_items = { |         missing_items = { | ||||||
|  | @ -124,16 +113,16 @@ class UniqueTogetherValidator: | ||||||
|         if missing_items: |         if missing_items: | ||||||
|             raise ValidationError(missing_items, code='required') |             raise ValidationError(missing_items, code='required') | ||||||
| 
 | 
 | ||||||
|     def filter_queryset(self, attrs, queryset): |     def filter_queryset(self, attrs, queryset, instance): | ||||||
|         """ |         """ | ||||||
|         Filter the queryset to all instances matching the given attributes. |         Filter the queryset to all instances matching the given attributes. | ||||||
|         """ |         """ | ||||||
|         # If this is an update, then any unprovided field should |         # If this is an update, then any unprovided field should | ||||||
|         # have it's value set based on the existing instance attribute. |         # have it's value set based on the existing instance attribute. | ||||||
|         if self.instance is not None: |         if instance is not None: | ||||||
|             for field_name in self.fields: |             for field_name in self.fields: | ||||||
|                 if field_name not in attrs: |                 if field_name not in attrs: | ||||||
|                     attrs[field_name] = getattr(self.instance, field_name) |                     attrs[field_name] = getattr(instance, field_name) | ||||||
| 
 | 
 | ||||||
|         # Determine the filter keyword arguments and filter the queryset. |         # Determine the filter keyword arguments and filter the queryset. | ||||||
|         filter_kwargs = { |         filter_kwargs = { | ||||||
|  | @ -142,20 +131,23 @@ class UniqueTogetherValidator: | ||||||
|         } |         } | ||||||
|         return qs_filter(queryset, **filter_kwargs) |         return qs_filter(queryset, **filter_kwargs) | ||||||
| 
 | 
 | ||||||
|     def exclude_current_instance(self, attrs, queryset): |     def exclude_current_instance(self, attrs, queryset, instance): | ||||||
|         """ |         """ | ||||||
|         If an instance is being updated, then do not include |         If an instance is being updated, then do not include | ||||||
|         that instance itself as a uniqueness conflict. |         that instance itself as a uniqueness conflict. | ||||||
|         """ |         """ | ||||||
|         if self.instance is not None: |         if instance is not None: | ||||||
|             return queryset.exclude(pk=self.instance.pk) |             return queryset.exclude(pk=instance.pk) | ||||||
|         return queryset |         return queryset | ||||||
| 
 | 
 | ||||||
|     def __call__(self, attrs): |     def __call__(self, attrs, serializer): | ||||||
|         self.enforce_required_fields(attrs) |         # Determine the existing instance, if this is an update operation. | ||||||
|  |         instance = getattr(serializer, 'instance', None) | ||||||
|  | 
 | ||||||
|  |         self.enforce_required_fields(attrs, instance) | ||||||
|         queryset = self.queryset |         queryset = self.queryset | ||||||
|         queryset = self.filter_queryset(attrs, queryset) |         queryset = self.filter_queryset(attrs, queryset, instance) | ||||||
|         queryset = self.exclude_current_instance(attrs, queryset) |         queryset = self.exclude_current_instance(attrs, queryset, instance) | ||||||
| 
 | 
 | ||||||
|         # Ignore validation if any field is None |         # Ignore validation if any field is None | ||||||
|         checked_values = [ |         checked_values = [ | ||||||
|  | @ -177,6 +169,7 @@ class UniqueTogetherValidator: | ||||||
| class BaseUniqueForValidator: | class BaseUniqueForValidator: | ||||||
|     message = None |     message = None | ||||||
|     missing_message = _('This field is required.') |     missing_message = _('This field is required.') | ||||||
|  |     requires_context = True | ||||||
| 
 | 
 | ||||||
|     def __init__(self, queryset, field, date_field, message=None): |     def __init__(self, queryset, field, date_field, message=None): | ||||||
|         self.queryset = queryset |         self.queryset = queryset | ||||||
|  | @ -184,18 +177,6 @@ class BaseUniqueForValidator: | ||||||
|         self.date_field = date_field |         self.date_field = date_field | ||||||
|         self.message = message or self.message |         self.message = message or self.message | ||||||
| 
 | 
 | ||||||
|     def set_context(self, serializer): |  | ||||||
|         """ |  | ||||||
|         This hook is called by the serializer instance, |  | ||||||
|         prior to the validation call being made. |  | ||||||
|         """ |  | ||||||
|         # Determine the underlying model field names. These may not be the |  | ||||||
|         # same as the serializer field names if `source=<>` is set. |  | ||||||
|         self.field_name = serializer.fields[self.field].source_attrs[-1] |  | ||||||
|         self.date_field_name = serializer.fields[self.date_field].source_attrs[-1] |  | ||||||
|         # Determine the existing instance, if this is an update operation. |  | ||||||
|         self.instance = getattr(serializer, 'instance', None) |  | ||||||
| 
 |  | ||||||
|     def enforce_required_fields(self, attrs): |     def enforce_required_fields(self, attrs): | ||||||
|         """ |         """ | ||||||
|         The `UniqueFor<Range>Validator` classes always force an implied |         The `UniqueFor<Range>Validator` classes always force an implied | ||||||
|  | @ -209,23 +190,30 @@ class BaseUniqueForValidator: | ||||||
|         if missing_items: |         if missing_items: | ||||||
|             raise ValidationError(missing_items, code='required') |             raise ValidationError(missing_items, code='required') | ||||||
| 
 | 
 | ||||||
|     def filter_queryset(self, attrs, queryset): |     def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||||||
|         raise NotImplementedError('`filter_queryset` must be implemented.') |         raise NotImplementedError('`filter_queryset` must be implemented.') | ||||||
| 
 | 
 | ||||||
|     def exclude_current_instance(self, attrs, queryset): |     def exclude_current_instance(self, attrs, queryset, instance): | ||||||
|         """ |         """ | ||||||
|         If an instance is being updated, then do not include |         If an instance is being updated, then do not include | ||||||
|         that instance itself as a uniqueness conflict. |         that instance itself as a uniqueness conflict. | ||||||
|         """ |         """ | ||||||
|         if self.instance is not None: |         if instance is not None: | ||||||
|             return queryset.exclude(pk=self.instance.pk) |             return queryset.exclude(pk=instance.pk) | ||||||
|         return queryset |         return queryset | ||||||
| 
 | 
 | ||||||
|     def __call__(self, attrs): |     def __call__(self, attrs, serializer): | ||||||
|  |         # Determine the underlying model field names. These may not be the | ||||||
|  |         # same as the serializer field names if `source=<>` is set. | ||||||
|  |         field_name = serializer.fields[self.field].source_attrs[-1] | ||||||
|  |         date_field_name = serializer.fields[self.date_field].source_attrs[-1] | ||||||
|  |         # Determine the existing instance, if this is an update operation. | ||||||
|  |         instance = getattr(serializer, 'instance', None) | ||||||
|  | 
 | ||||||
|         self.enforce_required_fields(attrs) |         self.enforce_required_fields(attrs) | ||||||
|         queryset = self.queryset |         queryset = self.queryset | ||||||
|         queryset = self.filter_queryset(attrs, queryset) |         queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name) | ||||||
|         queryset = self.exclude_current_instance(attrs, queryset) |         queryset = self.exclude_current_instance(attrs, queryset, instance) | ||||||
|         if qs_exists(queryset): |         if qs_exists(queryset): | ||||||
|             message = self.message.format(date_field=self.date_field) |             message = self.message.format(date_field=self.date_field) | ||||||
|             raise ValidationError({ |             raise ValidationError({ | ||||||
|  | @ -244,39 +232,39 @@ class BaseUniqueForValidator: | ||||||
| class UniqueForDateValidator(BaseUniqueForValidator): | class UniqueForDateValidator(BaseUniqueForValidator): | ||||||
|     message = _('This field must be unique for the "{date_field}" date.') |     message = _('This field must be unique for the "{date_field}" date.') | ||||||
| 
 | 
 | ||||||
|     def filter_queryset(self, attrs, queryset): |     def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||||||
|         value = attrs[self.field] |         value = attrs[self.field] | ||||||
|         date = attrs[self.date_field] |         date = attrs[self.date_field] | ||||||
| 
 | 
 | ||||||
|         filter_kwargs = {} |         filter_kwargs = {} | ||||||
|         filter_kwargs[self.field_name] = value |         filter_kwargs[field_name] = value | ||||||
|         filter_kwargs['%s__day' % self.date_field_name] = date.day |         filter_kwargs['%s__day' % date_field_name] = date.day | ||||||
|         filter_kwargs['%s__month' % self.date_field_name] = date.month |         filter_kwargs['%s__month' % date_field_name] = date.month | ||||||
|         filter_kwargs['%s__year' % self.date_field_name] = date.year |         filter_kwargs['%s__year' % date_field_name] = date.year | ||||||
|         return qs_filter(queryset, **filter_kwargs) |         return qs_filter(queryset, **filter_kwargs) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class UniqueForMonthValidator(BaseUniqueForValidator): | class UniqueForMonthValidator(BaseUniqueForValidator): | ||||||
|     message = _('This field must be unique for the "{date_field}" month.') |     message = _('This field must be unique for the "{date_field}" month.') | ||||||
| 
 | 
 | ||||||
|     def filter_queryset(self, attrs, queryset): |     def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||||||
|         value = attrs[self.field] |         value = attrs[self.field] | ||||||
|         date = attrs[self.date_field] |         date = attrs[self.date_field] | ||||||
| 
 | 
 | ||||||
|         filter_kwargs = {} |         filter_kwargs = {} | ||||||
|         filter_kwargs[self.field_name] = value |         filter_kwargs[field_name] = value | ||||||
|         filter_kwargs['%s__month' % self.date_field_name] = date.month |         filter_kwargs['%s__month' % date_field_name] = date.month | ||||||
|         return qs_filter(queryset, **filter_kwargs) |         return qs_filter(queryset, **filter_kwargs) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class UniqueForYearValidator(BaseUniqueForValidator): | class UniqueForYearValidator(BaseUniqueForValidator): | ||||||
|     message = _('This field must be unique for the "{date_field}" year.') |     message = _('This field must be unique for the "{date_field}" year.') | ||||||
| 
 | 
 | ||||||
|     def filter_queryset(self, attrs, queryset): |     def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||||||
|         value = attrs[self.field] |         value = attrs[self.field] | ||||||
|         date = attrs[self.date_field] |         date = attrs[self.date_field] | ||||||
| 
 | 
 | ||||||
|         filter_kwargs = {} |         filter_kwargs = {} | ||||||
|         filter_kwargs[self.field_name] = value |         filter_kwargs[field_name] = value | ||||||
|         filter_kwargs['%s__year' % self.date_field_name] = date.year |         filter_kwargs['%s__year' % date_field_name] = date.year | ||||||
|         return qs_filter(queryset, **filter_kwargs) |         return qs_filter(queryset, **filter_kwargs) | ||||||
|  |  | ||||||
|  | @ -361,8 +361,7 @@ class TestUniquenessTogetherValidation(TestCase): | ||||||
|         queryset = MockQueryset() |         queryset = MockQueryset() | ||||||
|         validator = UniqueTogetherValidator(queryset, fields=('race_name', |         validator = UniqueTogetherValidator(queryset, fields=('race_name', | ||||||
|                                                               'position')) |                                                               'position')) | ||||||
|         validator.instance = self.instance |         validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance) | ||||||
|         validator.filter_queryset(attrs=data, queryset=queryset) |  | ||||||
|         assert queryset.called_with == {'race_name': 'bar', 'position': 1} |         assert queryset.called_with == {'race_name': 'bar', 'position': 1} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -586,4 +585,6 @@ class ValidatorsTests(TestCase): | ||||||
|         validator = BaseUniqueForValidator(queryset=object(), field='foo', |         validator = BaseUniqueForValidator(queryset=object(), field='foo', | ||||||
|                                            date_field='bar') |                                            date_field='bar') | ||||||
|         with pytest.raises(NotImplementedError): |         with pytest.raises(NotImplementedError): | ||||||
|             validator.filter_queryset(attrs=None, queryset=None) |             validator.filter_queryset( | ||||||
|  |                 attrs=None, queryset=None, field_name='', date_field_name='' | ||||||
|  |             ) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user