Drop set_context() (#7062)

* Do not persist the context in validators

Fixes encode/django-rest-framework#5760

* Drop set_context() in favour of 'requires_context = True'
This commit is contained in:
Tom Christie 2019-12-03 11:16:27 +00:00 committed by GitHub
parent 9325c3f654
commit 070cff5a03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 131 additions and 93 deletions

View File

@ -50,7 +50,19 @@ If set, this gives the default value that will be used for the field if no input
The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned. The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned.
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context). May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `requires_context = True` attribute, then the serializer field will be passed as an argument.
For example:
class CurrentUserDefault:
"""
May be applied as a `default=...` value on a serializer field.
Returns the current user.
"""
requires_context = True
def __call__(self, serializer_field):
return serializer_field.context['request'].user
When serializing the instance, default will be used if the object attribute or dictionary key is not present in the instance. When serializing the instance, default will be used if the object attribute or dictionary key is not present in the instance.

View File

@ -291,13 +291,17 @@ To write a class-based validator, use the `__call__` method. Class-based validat
message = 'This field must be a multiple of %d.' % self.base message = 'This field must be a multiple of %d.' % self.base
raise serializers.ValidationError(message) raise serializers.ValidationError(message)
#### Using `set_context()` #### Accessing the context
In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator. In some advanced cases you might want a validator to be passed the serializer
field it is being used with as additional context. You can do so by setting
a `requires_context = True` attribute on the validator. The `__call__` method
will then be called with the `serializer_field`
or `serializer` as an additional argument.
def set_context(self, serializer_field): requires_context = True
# Determine if this is an update or a create operation.
# In `__call__` we can then use that information to modify the validation behavior. def __call__(self, value, serializer_field):
self.is_update = serializer_field.parent.instance is not None ...
[cite]: https://docs.djangoproject.com/en/stable/ref/validators/ [cite]: https://docs.djangoproject.com/en/stable/ref/validators/

View File

@ -5,6 +5,7 @@ import functools
import inspect import inspect
import re import re
import uuid import uuid
import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping
@ -249,19 +250,30 @@ class CreateOnlyDefault:
for create operations, but that do not return any value for update for create operations, but that do not return any value for update
operations. operations.
""" """
requires_context = True
def __init__(self, default): def __init__(self, default):
self.default = default self.default = default
def set_context(self, serializer_field): def __call__(self, serializer_field):
self.is_update = serializer_field.parent.instance is not None is_update = serializer_field.parent.instance is not None
if callable(self.default) and hasattr(self.default, 'set_context') and not self.is_update: if is_update:
self.default.set_context(serializer_field)
def __call__(self):
if self.is_update:
raise SkipField() raise SkipField()
if callable(self.default): if callable(self.default):
return self.default() if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
)
self.default.set_context(self)
if getattr(self.default, 'requires_context', False):
return self.default(serializer_field)
else:
return self.default()
return self.default return self.default
def __repr__(self): def __repr__(self):
@ -269,11 +281,10 @@ class CreateOnlyDefault:
class CurrentUserDefault: class CurrentUserDefault:
def set_context(self, serializer_field): requires_context = True
self.user = serializer_field.context['request'].user
def __call__(self): def __call__(self, serializer_field):
return self.user return serializer_field.context['request'].user
def __repr__(self): def __repr__(self):
return '%s()' % self.__class__.__name__ return '%s()' % self.__class__.__name__
@ -489,8 +500,20 @@ class Field:
raise SkipField() raise SkipField()
if callable(self.default): if callable(self.default):
if hasattr(self.default, 'set_context'): if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
)
self.default.set_context(self) self.default.set_context(self)
return self.default()
if getattr(self.default, 'requires_context', False):
return self.default(self)
else:
return self.default()
return self.default return self.default
def validate_empty_values(self, data): def validate_empty_values(self, data):
@ -551,10 +574,20 @@ class Field:
errors = [] errors = []
for validator in self.validators: for validator in self.validators:
if hasattr(validator, 'set_context'): if hasattr(validator, 'set_context'):
warnings.warn(
"Method `set_context` on validators is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
)
validator.set_context(self) validator.set_context(self)
try: try:
validator(value) if getattr(validator, 'requires_context', False):
validator(value, self)
else:
validator(value)
except ValidationError as exc: except ValidationError as exc:
# If the validation error contains a mapping of fields to # If the validation error contains a mapping of fields to
# errors then simply raise it immediately rather than # errors then simply raise it immediately rather than

View File

@ -37,6 +37,7 @@ class UniqueValidator:
Should be applied to an individual field on the serializer. Should be applied to an individual field on the serializer.
""" """
message = _('This field must be unique.') message = _('This field must be unique.')
requires_context = True
def __init__(self, queryset, message=None, lookup='exact'): def __init__(self, queryset, message=None, lookup='exact'):
self.queryset = queryset self.queryset = queryset
@ -44,37 +45,32 @@ class UniqueValidator:
self.message = message or self.message self.message = message or self.message
self.lookup = lookup self.lookup = lookup
def set_context(self, serializer_field): def filter_queryset(self, value, queryset, field_name):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
self.field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer_field.parent, 'instance', None)
def filter_queryset(self, value, queryset):
""" """
Filter the queryset to all instances matching the given attribute. Filter the queryset to all instances matching the given attribute.
""" """
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value} filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, queryset): def exclude_current_instance(self, queryset, instance):
""" """
If an instance is being updated, then do not include If an instance is being updated, then do not include
that instance itself as a uniqueness conflict. that instance itself as a uniqueness conflict.
""" """
if self.instance is not None: if instance is not None:
return queryset.exclude(pk=self.instance.pk) return queryset.exclude(pk=instance.pk)
return queryset return queryset
def __call__(self, value): def __call__(self, value, serializer_field):
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer_field.parent, 'instance', None)
queryset = self.queryset queryset = self.queryset
queryset = self.filter_queryset(value, queryset) queryset = self.filter_queryset(value, queryset, field_name)
queryset = self.exclude_current_instance(queryset) queryset = self.exclude_current_instance(queryset, instance)
if qs_exists(queryset): if qs_exists(queryset):
raise ValidationError(self.message, code='unique') raise ValidationError(self.message, code='unique')
@ -93,6 +89,7 @@ class UniqueTogetherValidator:
""" """
message = _('The fields {field_names} must make a unique set.') message = _('The fields {field_names} must make a unique set.')
missing_message = _('This field is required.') missing_message = _('This field is required.')
requires_context = True
def __init__(self, queryset, fields, message=None): def __init__(self, queryset, fields, message=None):
self.queryset = queryset self.queryset = queryset
@ -100,20 +97,12 @@ class UniqueTogetherValidator:
self.serializer_field = None self.serializer_field = None
self.message = message or self.message self.message = message or self.message
def set_context(self, serializer): def enforce_required_fields(self, attrs, instance):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)
def enforce_required_fields(self, attrs):
""" """
The `UniqueTogetherValidator` always forces an implied 'required' The `UniqueTogetherValidator` always forces an implied 'required'
state on the fields it applies to. state on the fields it applies to.
""" """
if self.instance is not None: if instance is not None:
return return
missing_items = { missing_items = {
@ -124,16 +113,16 @@ class UniqueTogetherValidator:
if missing_items: if missing_items:
raise ValidationError(missing_items, code='required') raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset, instance):
""" """
Filter the queryset to all instances matching the given attributes. Filter the queryset to all instances matching the given attributes.
""" """
# If this is an update, then any unprovided field should # If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute. # have it's value set based on the existing instance attribute.
if self.instance is not None: if instance is not None:
for field_name in self.fields: for field_name in self.fields:
if field_name not in attrs: if field_name not in attrs:
attrs[field_name] = getattr(self.instance, field_name) attrs[field_name] = getattr(instance, field_name)
# Determine the filter keyword arguments and filter the queryset. # Determine the filter keyword arguments and filter the queryset.
filter_kwargs = { filter_kwargs = {
@ -142,20 +131,23 @@ class UniqueTogetherValidator:
} }
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, attrs, queryset): def exclude_current_instance(self, attrs, queryset, instance):
""" """
If an instance is being updated, then do not include If an instance is being updated, then do not include
that instance itself as a uniqueness conflict. that instance itself as a uniqueness conflict.
""" """
if self.instance is not None: if instance is not None:
return queryset.exclude(pk=self.instance.pk) return queryset.exclude(pk=instance.pk)
return queryset return queryset
def __call__(self, attrs): def __call__(self, attrs, serializer):
self.enforce_required_fields(attrs) # Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)
self.enforce_required_fields(attrs, instance)
queryset = self.queryset queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset) queryset = self.filter_queryset(attrs, queryset, instance)
queryset = self.exclude_current_instance(attrs, queryset) queryset = self.exclude_current_instance(attrs, queryset, instance)
# Ignore validation if any field is None # Ignore validation if any field is None
checked_values = [ checked_values = [
@ -177,6 +169,7 @@ class UniqueTogetherValidator:
class BaseUniqueForValidator: class BaseUniqueForValidator:
message = None message = None
missing_message = _('This field is required.') missing_message = _('This field is required.')
requires_context = True
def __init__(self, queryset, field, date_field, message=None): def __init__(self, queryset, field, date_field, message=None):
self.queryset = queryset self.queryset = queryset
@ -184,18 +177,6 @@ class BaseUniqueForValidator:
self.date_field = date_field self.date_field = date_field
self.message = message or self.message self.message = message or self.message
def set_context(self, serializer):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the underlying model field names. These may not be the
# same as the serializer field names if `source=<>` is set.
self.field_name = serializer.fields[self.field].source_attrs[-1]
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)
def enforce_required_fields(self, attrs): def enforce_required_fields(self, attrs):
""" """
The `UniqueFor<Range>Validator` classes always force an implied The `UniqueFor<Range>Validator` classes always force an implied
@ -209,23 +190,30 @@ class BaseUniqueForValidator:
if missing_items: if missing_items:
raise ValidationError(missing_items, code='required') raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset, field_name, date_field_name):
raise NotImplementedError('`filter_queryset` must be implemented.') raise NotImplementedError('`filter_queryset` must be implemented.')
def exclude_current_instance(self, attrs, queryset): def exclude_current_instance(self, attrs, queryset, instance):
""" """
If an instance is being updated, then do not include If an instance is being updated, then do not include
that instance itself as a uniqueness conflict. that instance itself as a uniqueness conflict.
""" """
if self.instance is not None: if instance is not None:
return queryset.exclude(pk=self.instance.pk) return queryset.exclude(pk=instance.pk)
return queryset return queryset
def __call__(self, attrs): def __call__(self, attrs, serializer):
# Determine the underlying model field names. These may not be the
# same as the serializer field names if `source=<>` is set.
field_name = serializer.fields[self.field].source_attrs[-1]
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)
self.enforce_required_fields(attrs) self.enforce_required_fields(attrs)
queryset = self.queryset queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset) queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
queryset = self.exclude_current_instance(attrs, queryset) queryset = self.exclude_current_instance(attrs, queryset, instance)
if qs_exists(queryset): if qs_exists(queryset):
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
raise ValidationError({ raise ValidationError({
@ -244,39 +232,39 @@ class BaseUniqueForValidator:
class UniqueForDateValidator(BaseUniqueForValidator): class UniqueForDateValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" date.') message = _('This field must be unique for the "{date_field}" date.')
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field] value = attrs[self.field]
date = attrs[self.date_field] date = attrs[self.date_field]
filter_kwargs = {} filter_kwargs = {}
filter_kwargs[self.field_name] = value filter_kwargs[field_name] = value
filter_kwargs['%s__day' % self.date_field_name] = date.day filter_kwargs['%s__day' % date_field_name] = date.day
filter_kwargs['%s__month' % self.date_field_name] = date.month filter_kwargs['%s__month' % date_field_name] = date.month
filter_kwargs['%s__year' % self.date_field_name] = date.year filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
class UniqueForMonthValidator(BaseUniqueForValidator): class UniqueForMonthValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" month.') message = _('This field must be unique for the "{date_field}" month.')
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field] value = attrs[self.field]
date = attrs[self.date_field] date = attrs[self.date_field]
filter_kwargs = {} filter_kwargs = {}
filter_kwargs[self.field_name] = value filter_kwargs[field_name] = value
filter_kwargs['%s__month' % self.date_field_name] = date.month filter_kwargs['%s__month' % date_field_name] = date.month
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
class UniqueForYearValidator(BaseUniqueForValidator): class UniqueForYearValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" year.') message = _('This field must be unique for the "{date_field}" year.')
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field] value = attrs[self.field]
date = attrs[self.date_field] date = attrs[self.date_field]
filter_kwargs = {} filter_kwargs = {}
filter_kwargs[self.field_name] = value filter_kwargs[field_name] = value
filter_kwargs['%s__year' % self.date_field_name] = date.year filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)

View File

@ -361,8 +361,7 @@ class TestUniquenessTogetherValidation(TestCase):
queryset = MockQueryset() queryset = MockQueryset()
validator = UniqueTogetherValidator(queryset, fields=('race_name', validator = UniqueTogetherValidator(queryset, fields=('race_name',
'position')) 'position'))
validator.instance = self.instance validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance)
validator.filter_queryset(attrs=data, queryset=queryset)
assert queryset.called_with == {'race_name': 'bar', 'position': 1} assert queryset.called_with == {'race_name': 'bar', 'position': 1}
@ -586,4 +585,6 @@ class ValidatorsTests(TestCase):
validator = BaseUniqueForValidator(queryset=object(), field='foo', validator = BaseUniqueForValidator(queryset=object(), field='foo',
date_field='bar') date_field='bar')
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
validator.filter_queryset(attrs=None, queryset=None) validator.filter_queryset(
attrs=None, queryset=None, field_name='', date_field_name=''
)