Drop set_context() (#7062)

* Do not persist the context in validators

Fixes encode/django-rest-framework#5760

* Drop set_context() in favour of 'requires_context = True'
This commit is contained in:
Tom Christie 2019-12-03 11:16:27 +00:00 committed by GitHub
parent 9325c3f654
commit 070cff5a03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 131 additions and 93 deletions

View File

@ -50,7 +50,19 @@ If set, this gives the default value that will be used for the field if no input
The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned.
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context).
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `requires_context = True` attribute, then the serializer field will be passed as an argument.
For example:
class CurrentUserDefault:
"""
May be applied as a `default=...` value on a serializer field.
Returns the current user.
"""
requires_context = True
def __call__(self, serializer_field):
return serializer_field.context['request'].user
When serializing the instance, default will be used if the object attribute or dictionary key is not present in the instance.

View File

@ -291,13 +291,17 @@ To write a class-based validator, use the `__call__` method. Class-based validat
message = 'This field must be a multiple of %d.' % self.base
raise serializers.ValidationError(message)
#### Using `set_context()`
#### Accessing the context
In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator.
In some advanced cases you might want a validator to be passed the serializer
field it is being used with as additional context. You can do so by setting
a `requires_context = True` attribute on the validator. The `__call__` method
will then be called with the `serializer_field`
or `serializer` as an additional argument.
def set_context(self, serializer_field):
# Determine if this is an update or a create operation.
# In `__call__` we can then use that information to modify the validation behavior.
self.is_update = serializer_field.parent.instance is not None
requires_context = True
def __call__(self, value, serializer_field):
...
[cite]: https://docs.djangoproject.com/en/stable/ref/validators/

View File

@ -5,6 +5,7 @@ import functools
import inspect
import re
import uuid
import warnings
from collections import OrderedDict
from collections.abc import Mapping
@ -249,18 +250,29 @@ class CreateOnlyDefault:
for create operations, but that do not return any value for update
operations.
"""
requires_context = True
def __init__(self, default):
self.default = default
def set_context(self, serializer_field):
self.is_update = serializer_field.parent.instance is not None
if callable(self.default) and hasattr(self.default, 'set_context') and not self.is_update:
self.default.set_context(serializer_field)
def __call__(self):
if self.is_update:
def __call__(self, serializer_field):
is_update = serializer_field.parent.instance is not None
if is_update:
raise SkipField()
if callable(self.default):
if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
)
self.default.set_context(self)
if getattr(self.default, 'requires_context', False):
return self.default(serializer_field)
else:
return self.default()
return self.default
@ -269,11 +281,10 @@ class CreateOnlyDefault:
class CurrentUserDefault:
def set_context(self, serializer_field):
self.user = serializer_field.context['request'].user
requires_context = True
def __call__(self):
return self.user
def __call__(self, serializer_field):
return serializer_field.context['request'].user
def __repr__(self):
return '%s()' % self.__class__.__name__
@ -489,8 +500,20 @@ class Field:
raise SkipField()
if callable(self.default):
if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
)
self.default.set_context(self)
if getattr(self.default, 'requires_context', False):
return self.default(self)
else:
return self.default()
return self.default
def validate_empty_values(self, data):
@ -551,9 +574,19 @@ class Field:
errors = []
for validator in self.validators:
if hasattr(validator, 'set_context'):
warnings.warn(
"Method `set_context` on validators is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
)
validator.set_context(self)
try:
if getattr(validator, 'requires_context', False):
validator(value, self)
else:
validator(value)
except ValidationError as exc:
# If the validation error contains a mapping of fields to

View File

@ -37,6 +37,7 @@ class UniqueValidator:
Should be applied to an individual field on the serializer.
"""
message = _('This field must be unique.')
requires_context = True
def __init__(self, queryset, message=None, lookup='exact'):
self.queryset = queryset
@ -44,37 +45,32 @@ class UniqueValidator:
self.message = message or self.message
self.lookup = lookup
def set_context(self, serializer_field):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
self.field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer_field.parent, 'instance', None)
def filter_queryset(self, value, queryset):
def filter_queryset(self, value, queryset, field_name):
"""
Filter the queryset to all instances matching the given attribute.
"""
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value}
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, queryset):
def exclude_current_instance(self, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if self.instance is not None:
return queryset.exclude(pk=self.instance.pk)
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset
def __call__(self, value):
def __call__(self, value, serializer_field):
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer_field.parent, 'instance', None)
queryset = self.queryset
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
queryset = self.filter_queryset(value, queryset, field_name)
queryset = self.exclude_current_instance(queryset, instance)
if qs_exists(queryset):
raise ValidationError(self.message, code='unique')
@ -93,6 +89,7 @@ class UniqueTogetherValidator:
"""
message = _('The fields {field_names} must make a unique set.')
missing_message = _('This field is required.')
requires_context = True
def __init__(self, queryset, fields, message=None):
self.queryset = queryset
@ -100,20 +97,12 @@ class UniqueTogetherValidator:
self.serializer_field = None
self.message = message or self.message
def set_context(self, serializer):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)
def enforce_required_fields(self, attrs):
def enforce_required_fields(self, attrs, instance):
"""
The `UniqueTogetherValidator` always forces an implied 'required'
state on the fields it applies to.
"""
if self.instance is not None:
if instance is not None:
return
missing_items = {
@ -124,16 +113,16 @@ class UniqueTogetherValidator:
if missing_items:
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, instance):
"""
Filter the queryset to all instances matching the given attributes.
"""
# If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute.
if self.instance is not None:
if instance is not None:
for field_name in self.fields:
if field_name not in attrs:
attrs[field_name] = getattr(self.instance, field_name)
attrs[field_name] = getattr(instance, field_name)
# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = {
@ -142,20 +131,23 @@ class UniqueTogetherValidator:
}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, attrs, queryset):
def exclude_current_instance(self, attrs, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if self.instance is not None:
return queryset.exclude(pk=self.instance.pk)
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset
def __call__(self, attrs):
self.enforce_required_fields(attrs)
def __call__(self, attrs, serializer):
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)
self.enforce_required_fields(attrs, instance)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset)
queryset = self.exclude_current_instance(attrs, queryset)
queryset = self.filter_queryset(attrs, queryset, instance)
queryset = self.exclude_current_instance(attrs, queryset, instance)
# Ignore validation if any field is None
checked_values = [
@ -177,6 +169,7 @@ class UniqueTogetherValidator:
class BaseUniqueForValidator:
message = None
missing_message = _('This field is required.')
requires_context = True
def __init__(self, queryset, field, date_field, message=None):
self.queryset = queryset
@ -184,18 +177,6 @@ class BaseUniqueForValidator:
self.date_field = date_field
self.message = message or self.message
def set_context(self, serializer):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the underlying model field names. These may not be the
# same as the serializer field names if `source=<>` is set.
self.field_name = serializer.fields[self.field].source_attrs[-1]
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)
def enforce_required_fields(self, attrs):
"""
The `UniqueFor<Range>Validator` classes always force an implied
@ -209,23 +190,30 @@ class BaseUniqueForValidator:
if missing_items:
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
raise NotImplementedError('`filter_queryset` must be implemented.')
def exclude_current_instance(self, attrs, queryset):
def exclude_current_instance(self, attrs, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if self.instance is not None:
return queryset.exclude(pk=self.instance.pk)
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset
def __call__(self, attrs):
def __call__(self, attrs, serializer):
# Determine the underlying model field names. These may not be the
# same as the serializer field names if `source=<>` is set.
field_name = serializer.fields[self.field].source_attrs[-1]
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)
self.enforce_required_fields(attrs)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset)
queryset = self.exclude_current_instance(attrs, queryset)
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
queryset = self.exclude_current_instance(attrs, queryset, instance)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({
@ -244,39 +232,39 @@ class BaseUniqueForValidator:
class UniqueForDateValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" date.')
def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]
filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__day' % self.date_field_name] = date.day
filter_kwargs['%s__month' % self.date_field_name] = date.month
filter_kwargs['%s__year' % self.date_field_name] = date.year
filter_kwargs[field_name] = value
filter_kwargs['%s__day' % date_field_name] = date.day
filter_kwargs['%s__month' % date_field_name] = date.month
filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)
class UniqueForMonthValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" month.')
def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]
filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__month' % self.date_field_name] = date.month
filter_kwargs[field_name] = value
filter_kwargs['%s__month' % date_field_name] = date.month
return qs_filter(queryset, **filter_kwargs)
class UniqueForYearValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" year.')
def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]
filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__year' % self.date_field_name] = date.year
filter_kwargs[field_name] = value
filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)

View File

@ -361,8 +361,7 @@ class TestUniquenessTogetherValidation(TestCase):
queryset = MockQueryset()
validator = UniqueTogetherValidator(queryset, fields=('race_name',
'position'))
validator.instance = self.instance
validator.filter_queryset(attrs=data, queryset=queryset)
validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance)
assert queryset.called_with == {'race_name': 'bar', 'position': 1}
@ -586,4 +585,6 @@ class ValidatorsTests(TestCase):
validator = BaseUniqueForValidator(queryset=object(), field='foo',
date_field='bar')
with pytest.raises(NotImplementedError):
validator.filter_queryset(attrs=None, queryset=None)
validator.filter_queryset(
attrs=None, queryset=None, field_name='', date_field_name=''
)