mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-28 04:24:00 +03:00
03a0536cbd
I've noticed that this validator is using a per-character loop. Replacing it with a regex results in a pretty significant speedup. Here are results from my benchmark: String length Old implementation New implementation time (sec) time (sec) 1 2.833e-07 1.765e-07 10 5.885e-07 2.030e-07 100 3.598e-06 4.144e-07 1000 3.329e-05 2.463e-06 10000 0.0003338 2.449e-05 100000 0.003338 0.0002284 1000000 0.03333 0.002278 10000000 0.3389 0.02377 100000000 3.250 0.2365 For large strings, the speedups are more than an order of magnitude.
331 lines
12 KiB
Python
331 lines
12 KiB
Python
"""
|
|
We perform uniqueness checks explicitly on the serializer class, rather
|
|
the using Django's `.full_clean()`.
|
|
|
|
This gives us better separation of concerns, allows us to use single-step
|
|
object creation, and makes it possible to switch between using the implicit
|
|
`ModelSerializer` class and an equivalent explicit `Serializer` class.
|
|
"""
|
|
import re
|
|
|
|
from django.db import DataError
|
|
from django.utils.translation import gettext_lazy as _
|
|
|
|
from rest_framework.exceptions import ValidationError
|
|
from rest_framework.utils.representation import smart_repr
|
|
|
|
|
|
# Robust filter and exist implementations. Ensures that queryset.exists() for
|
|
# an invalid value returns `False`, rather than raising an error.
|
|
# Refs https://github.com/encode/django-rest-framework/issues/3381
|
|
def qs_exists(queryset):
|
|
try:
|
|
return queryset.exists()
|
|
except (TypeError, ValueError, DataError):
|
|
return False
|
|
|
|
|
|
def qs_filter(queryset, **kwargs):
|
|
try:
|
|
return queryset.filter(**kwargs)
|
|
except (TypeError, ValueError, DataError):
|
|
return queryset.none()
|
|
|
|
|
|
class UniqueValidator:
|
|
"""
|
|
Validator that corresponds to `unique=True` on a model field.
|
|
|
|
Should be applied to an individual field on the serializer.
|
|
"""
|
|
message = _('This field must be unique.')
|
|
requires_context = True
|
|
|
|
def __init__(self, queryset, message=None, lookup='exact'):
|
|
self.queryset = queryset
|
|
self.message = message or self.message
|
|
self.lookup = lookup
|
|
|
|
def filter_queryset(self, value, queryset, field_name):
|
|
"""
|
|
Filter the queryset to all instances matching the given attribute.
|
|
"""
|
|
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
|
|
return qs_filter(queryset, **filter_kwargs)
|
|
|
|
def exclude_current_instance(self, queryset, instance):
|
|
"""
|
|
If an instance is being updated, then do not include
|
|
that instance itself as a uniqueness conflict.
|
|
"""
|
|
if instance is not None:
|
|
return queryset.exclude(pk=instance.pk)
|
|
return queryset
|
|
|
|
def __call__(self, value, serializer_field):
|
|
# Determine the underlying model field name. This may not be the
|
|
# same as the serializer field name if `source=<>` is set.
|
|
field_name = serializer_field.source_attrs[-1]
|
|
# Determine the existing instance, if this is an update operation.
|
|
instance = getattr(serializer_field.parent, 'instance', None)
|
|
|
|
queryset = self.queryset
|
|
queryset = self.filter_queryset(value, queryset, field_name)
|
|
queryset = self.exclude_current_instance(queryset, instance)
|
|
if qs_exists(queryset):
|
|
raise ValidationError(self.message, code='unique')
|
|
|
|
def __repr__(self):
|
|
return '<%s(queryset=%s)>' % (
|
|
self.__class__.__name__,
|
|
smart_repr(self.queryset)
|
|
)
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, self.__class__):
|
|
return NotImplemented
|
|
return (self.message == other.message
|
|
and self.requires_context == other.requires_context
|
|
and self.queryset == other.queryset
|
|
and self.lookup == other.lookup
|
|
)
|
|
|
|
|
|
class UniqueTogetherValidator:
|
|
"""
|
|
Validator that corresponds to `unique_together = (...)` on a model class.
|
|
|
|
Should be applied to the serializer class, not to an individual field.
|
|
"""
|
|
message = _('The fields {field_names} must make a unique set.')
|
|
missing_message = _('This field is required.')
|
|
requires_context = True
|
|
|
|
def __init__(self, queryset, fields, message=None):
|
|
self.queryset = queryset
|
|
self.fields = fields
|
|
self.message = message or self.message
|
|
|
|
def enforce_required_fields(self, attrs, serializer):
|
|
"""
|
|
The `UniqueTogetherValidator` always forces an implied 'required'
|
|
state on the fields it applies to.
|
|
"""
|
|
if serializer.instance is not None:
|
|
return
|
|
|
|
missing_items = {
|
|
field_name: self.missing_message
|
|
for field_name in self.fields
|
|
if serializer.fields[field_name].source not in attrs
|
|
}
|
|
if missing_items:
|
|
raise ValidationError(missing_items, code='required')
|
|
|
|
def filter_queryset(self, attrs, queryset, serializer):
|
|
"""
|
|
Filter the queryset to all instances matching the given attributes.
|
|
"""
|
|
# field names => field sources
|
|
sources = [
|
|
serializer.fields[field_name].source
|
|
for field_name in self.fields
|
|
]
|
|
|
|
# If this is an update, then any unprovided field should
|
|
# have it's value set based on the existing instance attribute.
|
|
if serializer.instance is not None:
|
|
for source in sources:
|
|
if source not in attrs:
|
|
attrs[source] = getattr(serializer.instance, source)
|
|
|
|
# Determine the filter keyword arguments and filter the queryset.
|
|
filter_kwargs = {
|
|
source: attrs[source]
|
|
for source in sources
|
|
}
|
|
return qs_filter(queryset, **filter_kwargs)
|
|
|
|
def exclude_current_instance(self, attrs, queryset, instance):
|
|
"""
|
|
If an instance is being updated, then do not include
|
|
that instance itself as a uniqueness conflict.
|
|
"""
|
|
if instance is not None:
|
|
return queryset.exclude(pk=instance.pk)
|
|
return queryset
|
|
|
|
def __call__(self, attrs, serializer):
|
|
self.enforce_required_fields(attrs, serializer)
|
|
queryset = self.queryset
|
|
queryset = self.filter_queryset(attrs, queryset, serializer)
|
|
queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
|
|
|
|
checked_names = [
|
|
serializer.fields[field_name].source for field_name in self.fields
|
|
]
|
|
# Ignore validation if any field is None
|
|
if serializer.instance is None:
|
|
checked_values = [attrs[field_name] for field_name in checked_names]
|
|
else:
|
|
# Ignore validation if all field values are unchanged
|
|
checked_values = [
|
|
attrs[field_name]
|
|
for field_name in checked_names
|
|
if attrs[field_name] != getattr(serializer.instance, field_name)
|
|
]
|
|
|
|
if checked_values and None not in checked_values and qs_exists(queryset):
|
|
field_names = ', '.join(self.fields)
|
|
message = self.message.format(field_names=field_names)
|
|
raise ValidationError(message, code='unique')
|
|
|
|
def __repr__(self):
|
|
return '<%s(queryset=%s, fields=%s)>' % (
|
|
self.__class__.__name__,
|
|
smart_repr(self.queryset),
|
|
smart_repr(self.fields)
|
|
)
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, self.__class__):
|
|
return NotImplemented
|
|
return (self.message == other.message
|
|
and self.requires_context == other.requires_context
|
|
and self.missing_message == other.missing_message
|
|
and self.queryset == other.queryset
|
|
and self.fields == other.fields
|
|
)
|
|
|
|
|
|
class ProhibitSurrogateCharactersValidator:
|
|
_regex = re.compile(r'[\ud800-\udfff]')
|
|
|
|
message = _('Surrogate characters are not allowed: U+{code_point:X}.')
|
|
code = 'surrogate_characters_not_allowed'
|
|
|
|
def __call__(self, value):
|
|
if match := self._regex.search(str(value)):
|
|
message = self.message.format(code_point=ord(match.group()))
|
|
raise ValidationError(message, code=self.code)
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, self.__class__):
|
|
return NotImplemented
|
|
return (self.message == other.message
|
|
and self.code == other.code
|
|
)
|
|
|
|
|
|
class BaseUniqueForValidator:
|
|
message = None
|
|
missing_message = _('This field is required.')
|
|
requires_context = True
|
|
|
|
def __init__(self, queryset, field, date_field, message=None):
|
|
self.queryset = queryset
|
|
self.field = field
|
|
self.date_field = date_field
|
|
self.message = message or self.message
|
|
|
|
def enforce_required_fields(self, attrs):
|
|
"""
|
|
The `UniqueFor<Range>Validator` classes always force an implied
|
|
'required' state on the fields they are applied to.
|
|
"""
|
|
missing_items = {
|
|
field_name: self.missing_message
|
|
for field_name in [self.field, self.date_field]
|
|
if field_name not in attrs
|
|
}
|
|
if missing_items:
|
|
raise ValidationError(missing_items, code='required')
|
|
|
|
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
|
|
raise NotImplementedError('`filter_queryset` must be implemented.')
|
|
|
|
def exclude_current_instance(self, attrs, queryset, instance):
|
|
"""
|
|
If an instance is being updated, then do not include
|
|
that instance itself as a uniqueness conflict.
|
|
"""
|
|
if instance is not None:
|
|
return queryset.exclude(pk=instance.pk)
|
|
return queryset
|
|
|
|
def __call__(self, attrs, serializer):
|
|
# Determine the underlying model field names. These may not be the
|
|
# same as the serializer field names if `source=<>` is set.
|
|
field_name = serializer.fields[self.field].source_attrs[-1]
|
|
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
|
|
|
|
self.enforce_required_fields(attrs)
|
|
queryset = self.queryset
|
|
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
|
|
queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
|
|
if qs_exists(queryset):
|
|
message = self.message.format(date_field=self.date_field)
|
|
raise ValidationError({
|
|
self.field: message
|
|
}, code='unique')
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, self.__class__):
|
|
return NotImplemented
|
|
return (self.message == other.message
|
|
and self.missing_message == other.missing_message
|
|
and self.requires_context == other.requires_context
|
|
and self.queryset == other.queryset
|
|
and self.field == other.field
|
|
and self.date_field == other.date_field
|
|
)
|
|
|
|
def __repr__(self):
|
|
return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
|
|
self.__class__.__name__,
|
|
smart_repr(self.queryset),
|
|
smart_repr(self.field),
|
|
smart_repr(self.date_field)
|
|
)
|
|
|
|
|
|
class UniqueForDateValidator(BaseUniqueForValidator):
|
|
message = _('This field must be unique for the "{date_field}" date.')
|
|
|
|
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
|
|
value = attrs[self.field]
|
|
date = attrs[self.date_field]
|
|
|
|
filter_kwargs = {}
|
|
filter_kwargs[field_name] = value
|
|
filter_kwargs['%s__day' % date_field_name] = date.day
|
|
filter_kwargs['%s__month' % date_field_name] = date.month
|
|
filter_kwargs['%s__year' % date_field_name] = date.year
|
|
return qs_filter(queryset, **filter_kwargs)
|
|
|
|
|
|
class UniqueForMonthValidator(BaseUniqueForValidator):
|
|
message = _('This field must be unique for the "{date_field}" month.')
|
|
|
|
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
|
|
value = attrs[self.field]
|
|
date = attrs[self.date_field]
|
|
|
|
filter_kwargs = {}
|
|
filter_kwargs[field_name] = value
|
|
filter_kwargs['%s__month' % date_field_name] = date.month
|
|
return qs_filter(queryset, **filter_kwargs)
|
|
|
|
|
|
class UniqueForYearValidator(BaseUniqueForValidator):
|
|
message = _('This field must be unique for the "{date_field}" year.')
|
|
|
|
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
|
|
value = attrs[self.field]
|
|
date = attrs[self.date_field]
|
|
|
|
filter_kwargs = {}
|
|
filter_kwargs[field_name] = value
|
|
filter_kwargs['%s__year' % date_field_name] = date.year
|
|
return qs_filter(queryset, **filter_kwargs)
|