From f387cd89da55ef88fcac504f5795ea9b591f3fba Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 10 Nov 2014 12:21:27 +0000 Subject: [PATCH] Uniqueness constraints imply a forced 'required=True'. Refs #1945 --- docs/api-guide/validators.md | 10 ++++ rest_framework/validators.py | 97 +++++++++++++++++++++++++++++------- tests/test_validators.py | 11 ++++ 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/docs/api-guide/validators.md b/docs/api-guide/validators.md index 6a0ef4ff4..bb073f576 100644 --- a/docs/api-guide/validators.md +++ b/docs/api-guide/validators.md @@ -93,6 +93,12 @@ The validator should be applied to *serializer classes*, like so: ) ] +--- + +**Note**: The `UniqueTogetherValidation` class always imposes an implicit constraint that all the fields it applies to are always treated as required. Fields with `default` values are an exception to this as they always supply a value even when omitted from user input. + +--- + ## UniqueForDateValidator ## UniqueForMonthValidator @@ -146,6 +152,10 @@ If you want the date field to be entirely hidden from the user, then use `Hidden --- +**Note**: The `UniqueForValidation` classes always imposes an implicit constraint that the fields they are applied to are always treated as required. Fields with `default` values are an exception to this as they always supply a value even when omitted from user input. + +--- + # Writing custom validators You can use any of Django's existing validators, or write your own custom validators. diff --git a/rest_framework/validators.py b/rest_framework/validators.py index f3773f176..d7f847aaf 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -25,6 +25,10 @@ class UniqueValidator: self.message = message or self.message def set_context(self, serializer_field): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ # Determine the underlying model field name. This may not be the # same as the serializer field name if `source=<>` is set. self.field_name = serializer_field.source_attrs[0] @@ -54,6 +58,7 @@ class UniqueTogetherValidator: Should be applied to the serializer class, not to an individual field. """ message = _('The fields {field_names} must make a unique set.') + missing_message = _('This field is required.') def __init__(self, queryset, fields, message=None): self.queryset = queryset @@ -62,17 +67,49 @@ class UniqueTogetherValidator: self.message = message or self.message def set_context(self, serializer): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ # Determine the existing instance, if this is an update operation. self.instance = getattr(serializer, 'instance', None) - def __call__(self, attrs): - # Ensure uniqueness. + def enforce_required_fields(self, attrs): + """ + The `UniqueTogetherValidator` always forces an implied 'required' + state on the fields it applies to. + """ + missing = dict([ + (field_name, self.missing_message) + for field_name in self.fields + if field_name not in attrs + ]) + if missing: + raise ValidationError(missing) + + def filter_queryset(self, attrs, queryset): + """ + Filter the queryset to all instances matching the given attributes. + """ filter_kwargs = dict([ (field_name, attrs[field_name]) for field_name in self.fields ]) - queryset = self.queryset.filter(**filter_kwargs) + return queryset.filter(**filter_kwargs) + + def exclude_current_instance(self, attrs, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ if self.instance is not None: - queryset = queryset.exclude(pk=self.instance.pk) + return queryset.exclude(pk=self.instance.pk) + return queryset + + def __call__(self, attrs): + self.enforce_required_fields(attrs) + queryset = self.queryset + queryset = self.filter_queryset(attrs, queryset) + queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): field_names = ', '.join(self.fields) raise ValidationError(self.message.format(field_names=field_names)) @@ -87,6 +124,7 @@ class UniqueTogetherValidator: class BaseUniqueForValidator: message = None + missing_message = _('This field is required.') def __init__(self, queryset, field, date_field, message=None): self.queryset = queryset @@ -95,6 +133,10 @@ class BaseUniqueForValidator: self.message = message or self.message def set_context(self, serializer): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ # Determine the underlying model field names. These may not be the # same as the serializer field names if `source=<>` is set. self.field_name = serializer.fields[self.field].source_attrs[0] @@ -102,15 +144,36 @@ class BaseUniqueForValidator: # Determine the existing instance, if this is an update operation. self.instance = getattr(serializer, 'instance', None) - def get_filter_kwargs(self, attrs): - raise NotImplementedError('`get_filter_kwargs` must be implemented.') + def enforce_required_fields(self, attrs): + """ + The `UniqueForValidator` classes always force an implied + 'required' state on the fields they are applied to. + """ + missing = dict([ + (field_name, self.missing_message) + for field_name in [self.field, self.date_field] + if field_name not in attrs + ]) + if missing: + raise ValidationError(missing) + + def filter_queryset(self, attrs, queryset): + raise NotImplementedError('`filter_queryset` must be implemented.') + + def exclude_current_instance(self, attrs, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ + if self.instance is not None: + return queryset.exclude(pk=self.instance.pk) + return queryset def __call__(self, attrs): - filter_kwargs = self.get_filter_kwargs(attrs) - - queryset = self.queryset.filter(**filter_kwargs) - if self.instance is not None: - queryset = queryset.exclude(pk=self.instance.pk) + self.enforce_required_fields(attrs) + queryset = self.queryset + queryset = self.filter_queryset(attrs, queryset) + queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): message = self.message.format(date_field=self.date_field) raise ValidationError({self.field: message}) @@ -127,7 +190,7 @@ class BaseUniqueForValidator: class UniqueForDateValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" date.') - def get_filter_kwargs(self, attrs): + def filter_queryset(self, attrs, queryset): value = attrs[self.field] date = attrs[self.date_field] @@ -136,30 +199,30 @@ class UniqueForDateValidator(BaseUniqueForValidator): filter_kwargs['%s__day' % self.date_field_name] = date.day filter_kwargs['%s__month' % self.date_field_name] = date.month filter_kwargs['%s__year' % self.date_field_name] = date.year - return filter_kwargs + return queryset.filter(**filter_kwargs) class UniqueForMonthValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" month.') - def get_filter_kwargs(self, attrs): + def filter_queryset(self, attrs, queryset): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} filter_kwargs[self.field_name] = value filter_kwargs['%s__month' % self.date_field_name] = date.month - return filter_kwargs + return queryset.filter(**filter_kwargs) class UniqueForYearValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" year.') - def get_filter_kwargs(self, attrs): + def filter_queryset(self, attrs, queryset): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} filter_kwargs[self.field_name] = value filter_kwargs['%s__year' % self.date_field_name] = date.year - return filter_kwargs + return queryset.filter(**filter_kwargs) diff --git a/tests/test_validators.py b/tests/test_validators.py index e6e0b23a8..86614b109 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -134,6 +134,17 @@ class TestUniquenessTogetherValidation(TestCase): 'position': 1 } + def test_unique_together_is_required(self): + """ + In a unique together validation, all fields are required. + """ + data = {'position': 2} + serializer = UniquenessTogetherSerializer(data=data, partial=True) + assert not serializer.is_valid() + assert serializer.errors == { + 'race_name': ['This field is required.'] + } + def test_ignore_excluded_fields(self): """ When model fields are not included in a serializer, then uniqueness