diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index 29cb5aec9..e964458f9 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -50,7 +50,19 @@ If set, this gives the default value that will be used for the field if no input The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned. -May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context). +May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `requires_context = True` attribute, then the serializer field will be passed as an argument. + +For example: + + class CurrentUserDefault: + """ + May be applied as a `default=...` value on a serializer field. + Returns the current user. + """ + requires_context = True + + def __call__(self, serializer_field): + return serializer_field.context['request'].user When serializing the instance, default will be used if the object attribute or dictionary key is not present in the instance. diff --git a/docs/api-guide/validators.md b/docs/api-guide/validators.md index 49685838a..009cd2468 100644 --- a/docs/api-guide/validators.md +++ b/docs/api-guide/validators.md @@ -291,13 +291,17 @@ To write a class-based validator, use the `__call__` method. Class-based validat message = 'This field must be a multiple of %d.' % self.base raise serializers.ValidationError(message) -#### Using `set_context()` +#### Accessing the context -In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator. +In some advanced cases you might want a validator to be passed the serializer +field it is being used with as additional context. You can do so by setting +a `requires_context = True` attribute on the validator. The `__call__` method +will then be called with the `serializer_field` +or `serializer` as an additional argument. - def set_context(self, serializer_field): - # Determine if this is an update or a create operation. - # In `__call__` we can then use that information to modify the validation behavior. - self.is_update = serializer_field.parent.instance is not None + requires_context = True + + def __call__(self, value, serializer_field): + ... [cite]: https://docs.djangoproject.com/en/stable/ref/validators/ diff --git a/rest_framework/fields.py b/rest_framework/fields.py index ea8f47b2d..9507914e8 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -5,6 +5,7 @@ import functools import inspect import re import uuid +import warnings from collections import OrderedDict from collections.abc import Mapping @@ -249,19 +250,30 @@ class CreateOnlyDefault: for create operations, but that do not return any value for update operations. """ + requires_context = True + def __init__(self, default): self.default = default - def set_context(self, serializer_field): - self.is_update = serializer_field.parent.instance is not None - if callable(self.default) and hasattr(self.default, 'set_context') and not self.is_update: - self.default.set_context(serializer_field) - - def __call__(self): - if self.is_update: + def __call__(self, serializer_field): + is_update = serializer_field.parent.instance is not None + if is_update: raise SkipField() if callable(self.default): - return self.default() + if hasattr(self.default, 'set_context'): + warnings.warn( + "Method `set_context` on defaults is deprecated and will " + "no longer be called starting with 3.12. Instead set " + "`requires_context = True` on the class, and accept the " + "context as an additional argument.", + DeprecationWarning, stacklevel=2 + ) + self.default.set_context(self) + + if getattr(self.default, 'requires_context', False): + return self.default(serializer_field) + else: + return self.default() return self.default def __repr__(self): @@ -269,11 +281,10 @@ class CreateOnlyDefault: class CurrentUserDefault: - def set_context(self, serializer_field): - self.user = serializer_field.context['request'].user + requires_context = True - def __call__(self): - return self.user + def __call__(self, serializer_field): + return serializer_field.context['request'].user def __repr__(self): return '%s()' % self.__class__.__name__ @@ -489,8 +500,20 @@ class Field: raise SkipField() if callable(self.default): if hasattr(self.default, 'set_context'): + warnings.warn( + "Method `set_context` on defaults is deprecated and will " + "no longer be called starting with 3.12. Instead set " + "`requires_context = True` on the class, and accept the " + "context as an additional argument.", + DeprecationWarning, stacklevel=2 + ) self.default.set_context(self) - return self.default() + + if getattr(self.default, 'requires_context', False): + return self.default(self) + else: + return self.default() + return self.default def validate_empty_values(self, data): @@ -551,10 +574,20 @@ class Field: errors = [] for validator in self.validators: if hasattr(validator, 'set_context'): + warnings.warn( + "Method `set_context` on validators is deprecated and will " + "no longer be called starting with 3.12. Instead set " + "`requires_context = True` on the class, and accept the " + "context as an additional argument.", + DeprecationWarning, stacklevel=2 + ) validator.set_context(self) try: - validator(value) + if getattr(validator, 'requires_context', False): + validator(value, self) + else: + validator(value) except ValidationError as exc: # If the validation error contains a mapping of fields to # errors then simply raise it immediately rather than diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 1cbe31b5e..2907312a9 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -37,6 +37,7 @@ class UniqueValidator: Should be applied to an individual field on the serializer. """ message = _('This field must be unique.') + requires_context = True def __init__(self, queryset, message=None, lookup='exact'): self.queryset = queryset @@ -44,37 +45,32 @@ class UniqueValidator: self.message = message or self.message self.lookup = lookup - def set_context(self, serializer_field): - """ - This hook is called by the serializer instance, - prior to the validation call being made. - """ - # Determine the underlying model field name. This may not be the - # same as the serializer field name if `source=<>` is set. - self.field_name = serializer_field.source_attrs[-1] - # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer_field.parent, 'instance', None) - - def filter_queryset(self, value, queryset): + def filter_queryset(self, value, queryset, field_name): """ Filter the queryset to all instances matching the given attribute. """ - filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value} + filter_kwargs = {'%s__%s' % (field_name, self.lookup): value} return qs_filter(queryset, **filter_kwargs) - def exclude_current_instance(self, queryset): + def exclude_current_instance(self, queryset, instance): """ If an instance is being updated, then do not include that instance itself as a uniqueness conflict. """ - if self.instance is not None: - return queryset.exclude(pk=self.instance.pk) + if instance is not None: + return queryset.exclude(pk=instance.pk) return queryset - def __call__(self, value): + def __call__(self, value, serializer_field): + # Determine the underlying model field name. This may not be the + # same as the serializer field name if `source=<>` is set. + field_name = serializer_field.source_attrs[-1] + # Determine the existing instance, if this is an update operation. + instance = getattr(serializer_field.parent, 'instance', None) + queryset = self.queryset - queryset = self.filter_queryset(value, queryset) - queryset = self.exclude_current_instance(queryset) + queryset = self.filter_queryset(value, queryset, field_name) + queryset = self.exclude_current_instance(queryset, instance) if qs_exists(queryset): raise ValidationError(self.message, code='unique') @@ -93,6 +89,7 @@ class UniqueTogetherValidator: """ message = _('The fields {field_names} must make a unique set.') missing_message = _('This field is required.') + requires_context = True def __init__(self, queryset, fields, message=None): self.queryset = queryset @@ -100,20 +97,12 @@ class UniqueTogetherValidator: self.serializer_field = None self.message = message or self.message - def set_context(self, serializer): - """ - This hook is called by the serializer instance, - prior to the validation call being made. - """ - # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer, 'instance', None) - - def enforce_required_fields(self, attrs): + def enforce_required_fields(self, attrs, instance): """ The `UniqueTogetherValidator` always forces an implied 'required' state on the fields it applies to. """ - if self.instance is not None: + if instance is not None: return missing_items = { @@ -124,16 +113,16 @@ class UniqueTogetherValidator: if missing_items: raise ValidationError(missing_items, code='required') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, instance): """ Filter the queryset to all instances matching the given attributes. """ # If this is an update, then any unprovided field should # have it's value set based on the existing instance attribute. - if self.instance is not None: + if instance is not None: for field_name in self.fields: if field_name not in attrs: - attrs[field_name] = getattr(self.instance, field_name) + attrs[field_name] = getattr(instance, field_name) # Determine the filter keyword arguments and filter the queryset. filter_kwargs = { @@ -142,20 +131,23 @@ class UniqueTogetherValidator: } return qs_filter(queryset, **filter_kwargs) - def exclude_current_instance(self, attrs, queryset): + def exclude_current_instance(self, attrs, queryset, instance): """ If an instance is being updated, then do not include that instance itself as a uniqueness conflict. """ - if self.instance is not None: - return queryset.exclude(pk=self.instance.pk) + if instance is not None: + return queryset.exclude(pk=instance.pk) return queryset - def __call__(self, attrs): - self.enforce_required_fields(attrs) + def __call__(self, attrs, serializer): + # Determine the existing instance, if this is an update operation. + instance = getattr(serializer, 'instance', None) + + self.enforce_required_fields(attrs, instance) queryset = self.queryset - queryset = self.filter_queryset(attrs, queryset) - queryset = self.exclude_current_instance(attrs, queryset) + queryset = self.filter_queryset(attrs, queryset, instance) + queryset = self.exclude_current_instance(attrs, queryset, instance) # Ignore validation if any field is None checked_values = [ @@ -177,6 +169,7 @@ class UniqueTogetherValidator: class BaseUniqueForValidator: message = None missing_message = _('This field is required.') + requires_context = True def __init__(self, queryset, field, date_field, message=None): self.queryset = queryset @@ -184,18 +177,6 @@ class BaseUniqueForValidator: self.date_field = date_field self.message = message or self.message - def set_context(self, serializer): - """ - This hook is called by the serializer instance, - prior to the validation call being made. - """ - # Determine the underlying model field names. These may not be the - # same as the serializer field names if `source=<>` is set. - self.field_name = serializer.fields[self.field].source_attrs[-1] - self.date_field_name = serializer.fields[self.date_field].source_attrs[-1] - # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer, 'instance', None) - def enforce_required_fields(self, attrs): """ The `UniqueForValidator` classes always force an implied @@ -209,23 +190,30 @@ class BaseUniqueForValidator: if missing_items: raise ValidationError(missing_items, code='required') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): raise NotImplementedError('`filter_queryset` must be implemented.') - def exclude_current_instance(self, attrs, queryset): + def exclude_current_instance(self, attrs, queryset, instance): """ If an instance is being updated, then do not include that instance itself as a uniqueness conflict. """ - if self.instance is not None: - return queryset.exclude(pk=self.instance.pk) + if instance is not None: + return queryset.exclude(pk=instance.pk) return queryset - def __call__(self, attrs): + def __call__(self, attrs, serializer): + # Determine the underlying model field names. These may not be the + # same as the serializer field names if `source=<>` is set. + field_name = serializer.fields[self.field].source_attrs[-1] + date_field_name = serializer.fields[self.date_field].source_attrs[-1] + # Determine the existing instance, if this is an update operation. + instance = getattr(serializer, 'instance', None) + self.enforce_required_fields(attrs) queryset = self.queryset - queryset = self.filter_queryset(attrs, queryset) - queryset = self.exclude_current_instance(attrs, queryset) + queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name) + queryset = self.exclude_current_instance(attrs, queryset, instance) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) raise ValidationError({ @@ -244,39 +232,39 @@ class BaseUniqueForValidator: class UniqueForDateValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" date.') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} - filter_kwargs[self.field_name] = value - filter_kwargs['%s__day' % self.date_field_name] = date.day - filter_kwargs['%s__month' % self.date_field_name] = date.month - filter_kwargs['%s__year' % self.date_field_name] = date.year + filter_kwargs[field_name] = value + filter_kwargs['%s__day' % date_field_name] = date.day + filter_kwargs['%s__month' % date_field_name] = date.month + filter_kwargs['%s__year' % date_field_name] = date.year return qs_filter(queryset, **filter_kwargs) class UniqueForMonthValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" month.') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} - filter_kwargs[self.field_name] = value - filter_kwargs['%s__month' % self.date_field_name] = date.month + filter_kwargs[field_name] = value + filter_kwargs['%s__month' % date_field_name] = date.month return qs_filter(queryset, **filter_kwargs) class UniqueForYearValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" year.') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} - filter_kwargs[self.field_name] = value - filter_kwargs['%s__year' % self.date_field_name] = date.year + filter_kwargs[field_name] = value + filter_kwargs['%s__year' % date_field_name] = date.year return qs_filter(queryset, **filter_kwargs) diff --git a/tests/test_validators.py b/tests/test_validators.py index fe31ba235..bb29a4305 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -361,8 +361,7 @@ class TestUniquenessTogetherValidation(TestCase): queryset = MockQueryset() validator = UniqueTogetherValidator(queryset, fields=('race_name', 'position')) - validator.instance = self.instance - validator.filter_queryset(attrs=data, queryset=queryset) + validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance) assert queryset.called_with == {'race_name': 'bar', 'position': 1} @@ -586,4 +585,6 @@ class ValidatorsTests(TestCase): validator = BaseUniqueForValidator(queryset=object(), field='foo', date_field='bar') with pytest.raises(NotImplementedError): - validator.filter_queryset(attrs=None, queryset=None) + validator.filter_queryset( + attrs=None, queryset=None, field_name='', date_field_name='' + )