Pass serializer instead of model to validator

The `UniqueTogetherValidator` may need to access attributes on the
serializer instead of just the model instance. For example, this is
useful for handling field sources.
This commit is contained in:
Ryan P Kilby 2019-12-04 16:43:31 -08:00
parent 988f5b2034
commit d09a612863
2 changed files with 15 additions and 15 deletions

View File

@ -97,12 +97,12 @@ class UniqueTogetherValidator:
self.serializer_field = None self.serializer_field = None
self.message = message or self.message self.message = message or self.message
def enforce_required_fields(self, attrs, instance): def enforce_required_fields(self, attrs, serializer):
""" """
The `UniqueTogetherValidator` always forces an implied 'required' The `UniqueTogetherValidator` always forces an implied 'required'
state on the fields it applies to. state on the fields it applies to.
""" """
if instance is not None: if serializer.instance is not None:
return return
missing_items = { missing_items = {
@ -113,16 +113,16 @@ class UniqueTogetherValidator:
if missing_items: if missing_items:
raise ValidationError(missing_items, code='required') raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset, instance): def filter_queryset(self, attrs, queryset, serializer):
""" """
Filter the queryset to all instances matching the given attributes. Filter the queryset to all instances matching the given attributes.
""" """
# If this is an update, then any unprovided field should # If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute. # have it's value set based on the existing instance attribute.
if instance is not None: if serializer.instance is not None:
for field_name in self.fields: for field_name in self.fields:
if field_name not in attrs: if field_name not in attrs:
attrs[field_name] = getattr(instance, field_name) attrs[field_name] = getattr(serializer.instance, field_name)
# Determine the filter keyword arguments and filter the queryset. # Determine the filter keyword arguments and filter the queryset.
filter_kwargs = { filter_kwargs = {
@ -141,13 +141,10 @@ class UniqueTogetherValidator:
return queryset return queryset
def __call__(self, attrs, serializer): def __call__(self, attrs, serializer):
# Determine the existing instance, if this is an update operation. self.enforce_required_fields(attrs, serializer)
instance = getattr(serializer, 'instance', None)
self.enforce_required_fields(attrs, instance)
queryset = self.queryset queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset, instance) queryset = self.filter_queryset(attrs, queryset, serializer)
queryset = self.exclude_current_instance(attrs, queryset, instance) queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
# Ignore validation if any field is None # Ignore validation if any field is None
checked_values = [ checked_values = [
@ -207,13 +204,11 @@ class BaseUniqueForValidator:
# same as the serializer field names if `source=<>` is set. # same as the serializer field names if `source=<>` is set.
field_name = serializer.fields[self.field].source_attrs[-1] field_name = serializer.fields[self.field].source_attrs[-1]
date_field_name = serializer.fields[self.date_field].source_attrs[-1] date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)
self.enforce_required_fields(attrs) self.enforce_required_fields(attrs)
queryset = self.queryset queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name) queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
queryset = self.exclude_current_instance(attrs, queryset, instance) queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
if qs_exists(queryset): if qs_exists(queryset):
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
raise ValidationError({ raise ValidationError({

View File

@ -357,11 +357,16 @@ class TestUniquenessTogetherValidation(TestCase):
def filter(self, **kwargs): def filter(self, **kwargs):
self.called_with = kwargs self.called_with = kwargs
class MockSerializer:
def __init__(self, instance):
self.instance = instance
data = {'race_name': 'bar'} data = {'race_name': 'bar'}
queryset = MockQueryset() queryset = MockQueryset()
serializer = MockSerializer(instance=self.instance)
validator = UniqueTogetherValidator(queryset, fields=('race_name', validator = UniqueTogetherValidator(queryset, fields=('race_name',
'position')) 'position'))
validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance) validator.filter_queryset(attrs=data, queryset=queryset, serializer=serializer)
assert queryset.called_with == {'race_name': 'bar', 'position': 1} assert queryset.called_with == {'race_name': 'bar', 'position': 1}