django-rest-framework/rest_framework/validators.py

243 lines
8.6 KiB
Python
Raw Normal View History

2014-09-29 14:23:02 +04:00
"""
We perform uniqueness checks explicitly on the serializer class, rather
than using Django's `.full_clean()`.
This gives us better separation of concerns, allows us to use single-step
2014-09-29 14:23:02 +04:00
object creation, and makes it possible to switch between using the implicit
`ModelSerializer` class and an equivalent explicit `Serializer` class.
"""
from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import ValidationError
2014-09-29 14:23:02 +04:00
from rest_framework.utils.representation import smart_repr
2014-09-29 12:24:03 +04:00
class UniqueValidator:
    """
    Field-level counterpart of `unique=True` on a model field.

    Attach it to a single serializer field. Validation fails when the
    queryset already contains a row whose value for that field equals the
    incoming value, ignoring the instance that is currently being updated.
    """
    message = _('This field must be unique.')

    def __init__(self, queryset, message=None):
        # A falsy `message` falls back to the class-level default.
        self.message = message if message else self.message
        self.serializer_field = None
        self.queryset = queryset

    def set_context(self, serializer_field):
        """
        Hook invoked by the serializer instance before validation runs.

        Records the lookup name to filter on and the instance being
        updated, if any.
        """
        # The first entry of `source_attrs` is used as the lookup name; it
        # can differ from the serializer field name when `source=<>` is set.
        lookup_name = serializer_field.source_attrs[0]
        self.field_name = lookup_name

        # `None` here means the parent has no `instance` (or it is unset),
        # in which case nothing is excluded from the uniqueness check.
        owner = serializer_field.parent
        self.instance = getattr(owner, 'instance', None)

    def filter_queryset(self, value, queryset):
        """
        Narrow `queryset` to the rows whose field equals `value`.
        """
        return queryset.filter(**{self.field_name: value})

    def exclude_current_instance(self, queryset):
        """
        Drop the instance under update from `queryset`, so that it is not
        reported as conflicting with itself.
        """
        if self.instance is None:
            return queryset
        return queryset.exclude(pk=self.instance.pk)

    def __call__(self, value):
        """
        Raise `ValidationError` if another row already holds `value`.
        """
        matches = self.filter_queryset(value, self.queryset)
        matches = self.exclude_current_instance(matches)
        if not matches.exists():
            return
        raise ValidationError(self.message)

    def __repr__(self):
        class_name = self.__class__.__name__
        return '<%s(queryset=%s)>' % (class_name, smart_repr(self.queryset))
2014-09-29 12:24:03 +04:00
class UniqueTogetherValidator:
    """
    Validator that corresponds to `unique_together = (...)` on a model class.

    Should be applied to the serializer class, not to an individual field.

    On create, every field named in `fields` must be present in the
    validated attributes. On update, fields missing from the attributes are
    read from the existing instance instead, so partial updates are checked
    against the values the row will actually hold. If any member of the set
    is `None` the check is skipped, matching Django's own unique checks,
    which do not treat NULL values as conflicting.
    """
    message = _('The fields {field_names} must make a unique set.')
    missing_message = _('This field is required.')

    def __init__(self, queryset, fields, message=None):
        self.queryset = queryset
        self.fields = fields
        self.serializer_field = None
        # Until `set_context` says otherwise, behave as a create operation.
        self.instance = None
        self.message = message or self.message

    def set_context(self, serializer):
        """
        This hook is called by the serializer instance,
        prior to the validation call being made.
        """
        # Determine the existing instance, if this is an update operation.
        self.instance = getattr(serializer, 'instance', None)

    def _lookup_values(self, attrs):
        """
        Return a `{field_name: value}` dict covering every name in
        `self.fields`, without mutating `attrs`.

        Values are taken from `attrs`. On update, a name missing from
        `attrs` is read from the existing instance. On create, a missing
        name raises `KeyError`.
        """
        lookup = {}
        for field_name in self.fields:
            if field_name not in attrs and self.instance is not None:
                # NOTE(review): assumes the instance attribute has the same
                # name as the serializer field; confirm if `source=` is used.
                lookup[field_name] = getattr(self.instance, field_name)
            else:
                lookup[field_name] = attrs[field_name]
        return lookup

    def enforce_required_fields(self, attrs):
        """
        The `UniqueTogetherValidator` forces an implied 'required' state on
        the fields it applies to when creating a new object.

        On update the existing instance supplies any value that was not
        submitted, so nothing is required here.

        Raises `ValidationError` mapping each missing field name to
        `missing_message`.
        """
        if self.instance is not None:
            return

        missing = dict(
            (field_name, self.missing_message)
            for field_name in self.fields
            if field_name not in attrs
        )
        if missing:
            raise ValidationError(missing)

    def filter_queryset(self, attrs, queryset):
        """
        Filter the queryset to all instances matching the given attributes.

        On update, attributes that were not submitted are taken from the
        existing instance.
        """
        return queryset.filter(**self._lookup_values(attrs))

    def exclude_current_instance(self, attrs, queryset):
        """
        If an instance is being updated, then do not include
        that instance itself as a uniqueness conflict.
        """
        if self.instance is not None:
            return queryset.exclude(pk=self.instance.pk)
        return queryset

    def __call__(self, attrs):
        """
        Raise `ValidationError` if another row already holds the same
        combination of values for `self.fields`.
        """
        self.enforce_required_fields(attrs)

        # A `None` member would become an `IS NULL` lookup and report a
        # conflict that Django's unique checks would not, so skip the query.
        if any(value is None for value in self._lookup_values(attrs).values()):
            return

        queryset = self.queryset
        queryset = self.filter_queryset(attrs, queryset)
        queryset = self.exclude_current_instance(attrs, queryset)
        if queryset.exists():
            field_names = ', '.join(self.fields)
            raise ValidationError(self.message.format(field_names=field_names))

    def __repr__(self):
        return '<%s(queryset=%s, fields=%s)>' % (
            self.__class__.__name__,
            smart_repr(self.queryset),
            smart_repr(self.fields)
        )
class BaseUniqueForValidator:
    """
    Shared machinery for the `UniqueFor<Range>Validator` classes.

    Subclasses supply `message` (formatted with `date_field`) and implement
    `filter_queryset`. This base class handles context capture, the implied
    'required' check, exclusion of the instance under update, and raising
    the error.
    """
    message = None
    missing_message = _('This field is required.')

    def __init__(self, queryset, field, date_field, message=None):
        self.queryset = queryset
        self.field = field
        self.date_field = date_field
        # A falsy `message` falls back to the class-level default.
        self.message = message if message else self.message

    def set_context(self, serializer):
        """
        Hook invoked by the serializer instance before validation runs.
        """
        # Resolve the underlying lookup names; these may differ from the
        # serializer field names when `source=<>` is set.
        value_field = serializer.fields[self.field]
        self.field_name = value_field.source_attrs[0]
        range_field = serializer.fields[self.date_field]
        self.date_field_name = range_field.source_attrs[0]

        # `None` unless the serializer carries an instance being updated.
        self.instance = getattr(serializer, 'instance', None)

    def enforce_required_fields(self, attrs):
        """
        The `UniqueFor<Range>Validator` classes always treat both the value
        field and the date field as required.

        Raises `ValidationError` mapping each absent name to
        `missing_message`.
        """
        absent = {}
        for name in (self.field, self.date_field):
            if name not in attrs:
                absent[name] = self.missing_message
        if absent:
            raise ValidationError(absent)

    def filter_queryset(self, attrs, queryset):
        raise NotImplementedError('`filter_queryset` must be implemented.')

    def exclude_current_instance(self, attrs, queryset):
        """
        Drop the instance under update from `queryset`, so that it is not
        reported as conflicting with itself.
        """
        if self.instance is None:
            return queryset
        return queryset.exclude(pk=self.instance.pk)

    def __call__(self, attrs):
        """
        Raise `ValidationError`, keyed by `self.field`, if a conflicting row
        exists in the filtered queryset.
        """
        self.enforce_required_fields(attrs)
        candidates = self.filter_queryset(attrs, self.queryset)
        candidates = self.exclude_current_instance(attrs, candidates)
        if not candidates.exists():
            return
        message = self.message.format(date_field=self.date_field)
        raise ValidationError({self.field: message})

    def __repr__(self):
        parts = (
            self.__class__.__name__,
            smart_repr(self.queryset),
            smart_repr(self.field),
            smart_repr(self.date_field),
        )
        return '<%s(queryset=%s, field=%s, date_field=%s)>' % parts
class UniqueForDateValidator(BaseUniqueForValidator):
    """
    Require the field's value to be unique among rows that share the same
    day, month and year in the date field.
    """
    message = _('This field must be unique for the "{date_field}" date.')

    def filter_queryset(self, attrs, queryset):
        """
        Narrow `queryset` to rows with the same value on the same date.
        """
        candidate = attrs[self.field]
        when = attrs[self.date_field]
        lookup = {
            self.field_name: candidate,
            '%s__day' % self.date_field_name: when.day,
            '%s__month' % self.date_field_name: when.month,
            '%s__year' % self.date_field_name: when.year,
        }
        return queryset.filter(**lookup)
class UniqueForMonthValidator(BaseUniqueForValidator):
    """
    Require the field's value to be unique among rows whose date field
    falls in the same month. Only the month number is compared, not the
    year.
    """
    message = _('This field must be unique for the "{date_field}" month.')

    def filter_queryset(self, attrs, queryset):
        """
        Narrow `queryset` to rows with the same value in the same month.
        """
        candidate = attrs[self.field]
        when = attrs[self.date_field]
        lookup = {
            self.field_name: candidate,
            '%s__month' % self.date_field_name: when.month,
        }
        return queryset.filter(**lookup)
class UniqueForYearValidator(BaseUniqueForValidator):
    """
    Require the field's value to be unique among rows whose date field
    falls in the same year.
    """
    message = _('This field must be unique for the "{date_field}" year.')

    def filter_queryset(self, attrs, queryset):
        """
        Narrow `queryset` to rows with the same value in the same year.
        """
        candidate = attrs[self.field]
        when = attrs[self.date_field]
        lookup = {
            self.field_name: candidate,
            '%s__year' % self.date_field_name: when.year,
        }
        return queryset.filter(**lookup)