From d09a612863eb41dc2010aa6869a83a6137be1337 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 4 Dec 2019 16:43:31 -0800 Subject: [PATCH] Pass serializer instead of model to validator The `UniqueTogetherValidator` may need to access attributes on the serializer instead of just the model instance. For example, this is useful for handling field sources. --- rest_framework/validators.py | 23 +++++++++-------------- tests/test_validators.py | 7 ++++++- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 2907312a9..8c9a7f831 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -97,12 +97,12 @@ class UniqueTogetherValidator: self.serializer_field = None self.message = message or self.message - def enforce_required_fields(self, attrs, instance): + def enforce_required_fields(self, attrs, serializer): """ The `UniqueTogetherValidator` always forces an implied 'required' state on the fields it applies to. """ - if instance is not None: + if serializer.instance is not None: return missing_items = { @@ -113,16 +113,16 @@ class UniqueTogetherValidator: if missing_items: raise ValidationError(missing_items, code='required') - def filter_queryset(self, attrs, queryset, instance): + def filter_queryset(self, attrs, queryset, serializer): """ Filter the queryset to all instances matching the given attributes. """ # If this is an update, then any unprovided field should # have it's value set based on the existing instance attribute. - if instance is not None: + if serializer.instance is not None: for field_name in self.fields: if field_name not in attrs: - attrs[field_name] = getattr(instance, field_name) + attrs[field_name] = getattr(serializer.instance, field_name) # Determine the filter keyword arguments and filter the queryset. filter_kwargs = { @@ -141,13 +141,10 @@ class UniqueTogetherValidator: return queryset def __call__(self, attrs, serializer): - # Determine the existing instance, if this is an update operation. - instance = getattr(serializer, 'instance', None) - - self.enforce_required_fields(attrs, instance) + self.enforce_required_fields(attrs, serializer) queryset = self.queryset - queryset = self.filter_queryset(attrs, queryset, instance) - queryset = self.exclude_current_instance(attrs, queryset, instance) + queryset = self.filter_queryset(attrs, queryset, serializer) + queryset = self.exclude_current_instance(attrs, queryset, serializer.instance) # Ignore validation if any field is None checked_values = [ @@ -207,13 +204,11 @@ class BaseUniqueForValidator: # same as the serializer field names if `source=<>` is set. field_name = serializer.fields[self.field].source_attrs[-1] date_field_name = serializer.fields[self.date_field].source_attrs[-1] - # Determine the existing instance, if this is an update operation. - instance = getattr(serializer, 'instance', None) self.enforce_required_fields(attrs) queryset = self.queryset queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name) - queryset = self.exclude_current_instance(attrs, queryset, instance) + queryset = self.exclude_current_instance(attrs, queryset, serializer.instance) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) raise ValidationError({ diff --git a/tests/test_validators.py b/tests/test_validators.py index bb29a4305..5c4a62b31 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -357,11 +357,16 @@ class TestUniquenessTogetherValidation(TestCase): def filter(self, **kwargs): self.called_with = kwargs + class MockSerializer: + def __init__(self, instance): + self.instance = instance + data = {'race_name': 'bar'} queryset = MockQueryset() + serializer = MockSerializer(instance=self.instance) validator = UniqueTogetherValidator(queryset, fields=('race_name', 'position')) - validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance) + validator.filter_queryset(attrs=data, queryset=queryset, serializer=serializer) assert queryset.called_with == {'race_name': 'bar', 'position': 1}