mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-10 19:56:59 +03:00
UniqueTogetherValidator
This commit is contained in:
parent
43fd5a8730
commit
9805a085fb
|
@ -23,6 +23,7 @@ from rest_framework.utils.field_mapping import (
|
||||||
get_relation_kwargs, get_nested_relation_kwargs,
|
get_relation_kwargs, get_nested_relation_kwargs,
|
||||||
ClassLookupDict
|
ClassLookupDict
|
||||||
)
|
)
|
||||||
|
from rest_framework.validators import UniqueTogetherValidator
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
@ -95,7 +96,7 @@ class BaseSerializer(Field):
|
||||||
def is_valid(self, raise_exception=False):
|
def is_valid(self, raise_exception=False):
|
||||||
if not hasattr(self, '_validated_data'):
|
if not hasattr(self, '_validated_data'):
|
||||||
try:
|
try:
|
||||||
self._validated_data = self.to_internal_value(self._initial_data)
|
self._validated_data = self.run_validation(self._initial_data)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
self._validated_data = {}
|
self._validated_data = {}
|
||||||
self._errors = exc.message_dict
|
self._errors = exc.message_dict
|
||||||
|
@ -223,15 +224,43 @@ class Serializer(BaseSerializer):
|
||||||
return html.parse_html_dict(dictionary, prefix=self.field_name)
|
return html.parse_html_dict(dictionary, prefix=self.field_name)
|
||||||
return dictionary.get(self.field_name, empty)
|
return dictionary.get(self.field_name, empty)
|
||||||
|
|
||||||
def to_internal_value(self, data):
|
def run_validation(self, data=empty):
|
||||||
"""
|
"""
|
||||||
Dict of native values <- Dict of primitive datatypes.
|
We override the default `run_validation`, because the validation
|
||||||
|
performed by validators and the `.validate()` method should
|
||||||
|
be coerced into an error dictionary with a 'non_fields_error' key.
|
||||||
"""
|
"""
|
||||||
|
if data is empty:
|
||||||
|
if getattr(self.root, 'partial', False):
|
||||||
|
raise SkipField()
|
||||||
|
if self.required:
|
||||||
|
self.fail('required')
|
||||||
|
return self.get_default()
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
if not self.allow_null:
|
||||||
|
self.fail('null')
|
||||||
|
return None
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise ValidationError({
|
raise ValidationError({
|
||||||
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
|
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
|
||||||
})
|
})
|
||||||
|
|
||||||
|
value = self.to_internal_value(data)
|
||||||
|
try:
|
||||||
|
self.run_validators(value)
|
||||||
|
self.validate(value)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise ValidationError({
|
||||||
|
api_settings.NON_FIELD_ERRORS_KEY: exc.messages
|
||||||
|
})
|
||||||
|
return value
|
||||||
|
|
||||||
|
def to_internal_value(self, data):
|
||||||
|
"""
|
||||||
|
Dict of native values <- Dict of primitive datatypes.
|
||||||
|
"""
|
||||||
ret = {}
|
ret = {}
|
||||||
errors = {}
|
errors = {}
|
||||||
fields = [field for field in self.fields.values() if not field.read_only]
|
fields = [field for field in self.fields.values() if not field.read_only]
|
||||||
|
@ -253,12 +282,7 @@ class Serializer(BaseSerializer):
|
||||||
if errors:
|
if errors:
|
||||||
raise ValidationError(errors)
|
raise ValidationError(errors)
|
||||||
|
|
||||||
try:
|
return ret
|
||||||
return self.validate(ret)
|
|
||||||
except ValidationError as exc:
|
|
||||||
raise ValidationError({
|
|
||||||
api_settings.NON_FIELD_ERRORS_KEY: exc.messages
|
|
||||||
})
|
|
||||||
|
|
||||||
def to_representation(self, instance):
|
def to_representation(self, instance):
|
||||||
"""
|
"""
|
||||||
|
@ -355,6 +379,14 @@ class ModelSerializer(Serializer):
|
||||||
})
|
})
|
||||||
_related_class = PrimaryKeyRelatedField
|
_related_class = PrimaryKeyRelatedField
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(ModelSerializer, self).__init__(*args, **kwargs)
|
||||||
|
if 'validators' not in kwargs:
|
||||||
|
validators = self.get_unique_together_validators()
|
||||||
|
if validators:
|
||||||
|
self.validators.extend(validators)
|
||||||
|
self._kwargs['validators'] = validators
|
||||||
|
|
||||||
def create(self, attrs):
|
def create(self, attrs):
|
||||||
ModelClass = self.Meta.model
|
ModelClass = self.Meta.model
|
||||||
|
|
||||||
|
@ -381,6 +413,36 @@ class ModelSerializer(Serializer):
|
||||||
setattr(obj, attr, value)
|
setattr(obj, attr, value)
|
||||||
obj.save()
|
obj.save()
|
||||||
|
|
||||||
|
def get_unique_together_validators(self):
|
||||||
|
field_names = set([
|
||||||
|
field.source for field in self.fields.values()
|
||||||
|
if (field.source != '*') and ('.' not in field.source)
|
||||||
|
])
|
||||||
|
|
||||||
|
validators = []
|
||||||
|
model_class = self.Meta.model
|
||||||
|
|
||||||
|
for unique_together in model_class._meta.unique_together:
|
||||||
|
if field_names.issuperset(set(unique_together)):
|
||||||
|
validator = UniqueTogetherValidator(
|
||||||
|
queryset=model_class._default_manager,
|
||||||
|
fields=unique_together
|
||||||
|
)
|
||||||
|
validator.serializer_field = self
|
||||||
|
validators.append(validator)
|
||||||
|
|
||||||
|
for parent_class in model_class._meta.parents.keys():
|
||||||
|
for unique_together in parent_class._meta.unique_together:
|
||||||
|
if field_names.issuperset(set(unique_together)):
|
||||||
|
validator = UniqueTogetherValidator(
|
||||||
|
queryset=parent_class._default_manager,
|
||||||
|
fields=unique_together
|
||||||
|
)
|
||||||
|
validator.serializer_field = self
|
||||||
|
validators.append(validator)
|
||||||
|
|
||||||
|
return validators
|
||||||
|
|
||||||
def _get_base_fields(self):
|
def _get_base_fields(self):
|
||||||
declared_fields = copy.deepcopy(self._declared_fields)
|
declared_fields = copy.deepcopy(self._declared_fields)
|
||||||
|
|
||||||
|
|
|
@ -1,18 +1,26 @@
|
||||||
|
"""
|
||||||
|
We perform uniqueness checks explicitly on the serializer class, rather
|
||||||
|
the using Django's `.full_clean()`.
|
||||||
|
|
||||||
|
This gives us better seperation of concerns, allows us to use single-step
|
||||||
|
object creation, and makes it possible to switch between using the implicit
|
||||||
|
`ModelSerializer` class and an equivelent explicit `Serializer` class.
|
||||||
|
"""
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
from rest_framework.utils.representation import smart_repr
|
||||||
|
|
||||||
|
|
||||||
class UniqueValidator:
|
class UniqueValidator:
|
||||||
# Validators with `requires_context` will have the field instance
|
# Validators with `requires_context` will have the field instance
|
||||||
# passed to them when the field is instantiated.
|
# passed to them when the field is instantiated.
|
||||||
requires_context = True
|
requires_context = True
|
||||||
|
message = _('This field must be unique.')
|
||||||
|
|
||||||
def __init__(self, queryset):
|
def __init__(self, queryset):
|
||||||
self.queryset = queryset
|
self.queryset = queryset
|
||||||
self.serializer_field = None
|
self.serializer_field = None
|
||||||
|
|
||||||
def get_queryset(self):
|
|
||||||
return self.queryset.all()
|
|
||||||
|
|
||||||
def __call__(self, value):
|
def __call__(self, value):
|
||||||
field = self.serializer_field
|
field = self.serializer_field
|
||||||
|
|
||||||
|
@ -24,15 +32,22 @@ class UniqueValidator:
|
||||||
|
|
||||||
# Ensure uniqueness.
|
# Ensure uniqueness.
|
||||||
filter_kwargs = {field_name: value}
|
filter_kwargs = {field_name: value}
|
||||||
queryset = self.get_queryset().filter(**filter_kwargs)
|
queryset = self.queryset.filter(**filter_kwargs)
|
||||||
if instance:
|
if instance:
|
||||||
queryset = queryset.exclude(pk=instance.pk)
|
queryset = queryset.exclude(pk=instance.pk)
|
||||||
if queryset.exists():
|
if queryset.exists():
|
||||||
raise ValidationError('This field must be unique.')
|
raise ValidationError(self.message)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<%s(queryset=%s)>' % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
smart_repr(self.queryset)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UniqueTogetherValidator:
|
class UniqueTogetherValidator:
|
||||||
requires_context = True
|
requires_context = True
|
||||||
|
message = _('The fields {field_names} must make a unique set.')
|
||||||
|
|
||||||
def __init__(self, queryset, fields):
|
def __init__(self, queryset, fields):
|
||||||
self.queryset = queryset
|
self.queryset = queryset
|
||||||
|
@ -49,9 +64,16 @@ class UniqueTogetherValidator:
|
||||||
filter_kwargs = dict([
|
filter_kwargs = dict([
|
||||||
(field_name, value[field_name]) for field_name in self.fields
|
(field_name, value[field_name]) for field_name in self.fields
|
||||||
])
|
])
|
||||||
queryset = self.get_queryset().filter(**filter_kwargs)
|
queryset = self.queryset.filter(**filter_kwargs)
|
||||||
if instance:
|
if instance:
|
||||||
queryset = queryset.exclude(pk=instance.pk)
|
queryset = queryset.exclude(pk=instance.pk)
|
||||||
if queryset.exists():
|
if queryset.exists():
|
||||||
field_names = ' and '.join(self.fields)
|
field_names = ', '.join(self.fields)
|
||||||
raise ValidationError('The fields %s must make a unique set.' % field_names)
|
raise ValidationError(self.message.format(field_names=field_names))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<%s(queryset=%s, fields=%s)>' % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
smart_repr(self.queryset),
|
||||||
|
smart_repr(self.fields)
|
||||||
|
)
|
||||||
|
|
|
@ -3,33 +3,148 @@ from django.test import TestCase
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
|
||||||
class ExampleModel(models.Model):
|
def dedent(blocktext):
|
||||||
|
return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for `UniqueValidator`
|
||||||
|
# ---------------------------
|
||||||
|
|
||||||
|
class UniquenessModel(models.Model):
|
||||||
username = models.CharField(unique=True, max_length=100)
|
username = models.CharField(unique=True, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
class ExampleSerializer(serializers.ModelSerializer):
|
class UniquenessSerializer(serializers.ModelSerializer):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = ExampleModel
|
model = UniquenessModel
|
||||||
|
|
||||||
|
|
||||||
class TestUniquenessValidation(TestCase):
|
class TestUniquenessValidation(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.instance = ExampleModel.objects.create(username='existing')
|
self.instance = UniquenessModel.objects.create(username='existing')
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
serializer = UniquenessSerializer()
|
||||||
|
expected = dedent("""
|
||||||
|
UniquenessSerializer():
|
||||||
|
id = IntegerField(label='ID', read_only=True)
|
||||||
|
username = CharField(max_length=100, validators=[<UniqueValidator(queryset=UniquenessModel.objects.all())>])
|
||||||
|
""")
|
||||||
|
assert repr(serializer) == expected
|
||||||
|
|
||||||
def test_is_not_unique(self):
|
def test_is_not_unique(self):
|
||||||
data = {'username': 'existing'}
|
data = {'username': 'existing'}
|
||||||
serializer = ExampleSerializer(data=data)
|
serializer = UniquenessSerializer(data=data)
|
||||||
assert not serializer.is_valid()
|
assert not serializer.is_valid()
|
||||||
assert serializer.errors == {'username': ['This field must be unique.']}
|
assert serializer.errors == {'username': ['This field must be unique.']}
|
||||||
|
|
||||||
def test_is_unique(self):
|
def test_is_unique(self):
|
||||||
data = {'username': 'other'}
|
data = {'username': 'other'}
|
||||||
serializer = ExampleSerializer(data=data)
|
serializer = UniquenessSerializer(data=data)
|
||||||
assert serializer.is_valid()
|
assert serializer.is_valid()
|
||||||
assert serializer.validated_data == {'username': 'other'}
|
assert serializer.validated_data == {'username': 'other'}
|
||||||
|
|
||||||
def test_updated_instance_excluded(self):
|
def test_updated_instance_excluded(self):
|
||||||
data = {'username': 'existing'}
|
data = {'username': 'existing'}
|
||||||
serializer = ExampleSerializer(self.instance, data=data)
|
serializer = UniquenessSerializer(self.instance, data=data)
|
||||||
assert serializer.is_valid()
|
assert serializer.is_valid()
|
||||||
assert serializer.validated_data == {'username': 'existing'}
|
assert serializer.validated_data == {'username': 'existing'}
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for `UniqueTogetherValidator`
|
||||||
|
# -----------------------------------
|
||||||
|
|
||||||
|
class UniquenessTogetherModel(models.Model):
|
||||||
|
race_name = models.CharField(max_length=100)
|
||||||
|
position = models.IntegerField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
unique_together = ('race_name', 'position')
|
||||||
|
|
||||||
|
|
||||||
|
class UniquenessTogetherSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = UniquenessTogetherModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestUniquenessTogetherValidation(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.instance = UniquenessTogetherModel.objects.create(
|
||||||
|
race_name='example',
|
||||||
|
position=1
|
||||||
|
)
|
||||||
|
UniquenessTogetherModel.objects.create(
|
||||||
|
race_name='example',
|
||||||
|
position=2
|
||||||
|
)
|
||||||
|
UniquenessTogetherModel.objects.create(
|
||||||
|
race_name='other',
|
||||||
|
position=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
serializer = UniquenessTogetherSerializer()
|
||||||
|
expected = dedent("""
|
||||||
|
UniquenessTogetherSerializer(validators=[<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>]):
|
||||||
|
id = IntegerField(label='ID', read_only=True)
|
||||||
|
race_name = CharField(max_length=100)
|
||||||
|
position = IntegerField()
|
||||||
|
""")
|
||||||
|
assert repr(serializer) == expected
|
||||||
|
|
||||||
|
def test_is_not_unique_together(self):
|
||||||
|
"""
|
||||||
|
Failing unique together validation should result in non field errors.
|
||||||
|
"""
|
||||||
|
data = {'race_name': 'example', 'position': 2}
|
||||||
|
serializer = UniquenessTogetherSerializer(data=data)
|
||||||
|
print serializer.validators
|
||||||
|
assert not serializer.is_valid()
|
||||||
|
assert serializer.errors == {
|
||||||
|
'non_field_errors': [
|
||||||
|
'The fields race_name, position must make a unique set.'
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_is_unique_together(self):
|
||||||
|
"""
|
||||||
|
In a unique together validation, one field may be non-unique
|
||||||
|
so long as the set as a whole is unique.
|
||||||
|
"""
|
||||||
|
data = {'race_name': 'other', 'position': 2}
|
||||||
|
serializer = UniquenessTogetherSerializer(data=data)
|
||||||
|
assert serializer.is_valid()
|
||||||
|
assert serializer.validated_data == {
|
||||||
|
'race_name': 'other',
|
||||||
|
'position': 2
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_updated_instance_excluded_from_unique_together(self):
|
||||||
|
"""
|
||||||
|
When performing an update, the existing instance does not count
|
||||||
|
as a match against uniqueness.
|
||||||
|
"""
|
||||||
|
data = {'race_name': 'example', 'position': 1}
|
||||||
|
serializer = UniquenessTogetherSerializer(self.instance, data=data)
|
||||||
|
assert serializer.is_valid()
|
||||||
|
assert serializer.validated_data == {
|
||||||
|
'race_name': 'example',
|
||||||
|
'position': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_ignore_exlcuded_fields(self):
|
||||||
|
"""
|
||||||
|
When model fields are not included in a serializer, then uniqueness
|
||||||
|
validtors should not be added for that field.
|
||||||
|
"""
|
||||||
|
class ExcludedFieldSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = UniquenessTogetherModel
|
||||||
|
fields = ('id', 'race_name',)
|
||||||
|
serializer = ExcludedFieldSerializer()
|
||||||
|
expected = dedent("""
|
||||||
|
ExcludedFieldSerializer():
|
||||||
|
id = IntegerField(label='ID', read_only=True)
|
||||||
|
race_name = CharField(max_length=100)
|
||||||
|
""")
|
||||||
|
assert repr(serializer) == expected
|
||||||
|
|
Loading…
Reference in New Issue
Block a user