mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-23 15:54:16 +03:00
UniqueTogetherValidator
This commit is contained in:
parent
43fd5a8730
commit
9805a085fb
|
@ -23,6 +23,7 @@ from rest_framework.utils.field_mapping import (
|
|||
get_relation_kwargs, get_nested_relation_kwargs,
|
||||
ClassLookupDict
|
||||
)
|
||||
from rest_framework.validators import UniqueTogetherValidator
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
|
@ -95,7 +96,7 @@ class BaseSerializer(Field):
|
|||
def is_valid(self, raise_exception=False):
|
||||
if not hasattr(self, '_validated_data'):
|
||||
try:
|
||||
self._validated_data = self.to_internal_value(self._initial_data)
|
||||
self._validated_data = self.run_validation(self._initial_data)
|
||||
except ValidationError as exc:
|
||||
self._validated_data = {}
|
||||
self._errors = exc.message_dict
|
||||
|
@ -223,15 +224,43 @@ class Serializer(BaseSerializer):
|
|||
return html.parse_html_dict(dictionary, prefix=self.field_name)
|
||||
return dictionary.get(self.field_name, empty)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
def run_validation(self, data=empty):
|
||||
"""
|
||||
Dict of native values <- Dict of primitive datatypes.
|
||||
We override the default `run_validation`, because the validation
|
||||
performed by validators and the `.validate()` method should
|
||||
be coerced into an error dictionary with a 'non_fields_error' key.
|
||||
"""
|
||||
if data is empty:
|
||||
if getattr(self.root, 'partial', False):
|
||||
raise SkipField()
|
||||
if self.required:
|
||||
self.fail('required')
|
||||
return self.get_default()
|
||||
|
||||
if data is None:
|
||||
if not self.allow_null:
|
||||
self.fail('null')
|
||||
return None
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValidationError({
|
||||
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
|
||||
})
|
||||
|
||||
value = self.to_internal_value(data)
|
||||
try:
|
||||
self.run_validators(value)
|
||||
self.validate(value)
|
||||
except ValidationError as exc:
|
||||
raise ValidationError({
|
||||
api_settings.NON_FIELD_ERRORS_KEY: exc.messages
|
||||
})
|
||||
return value
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Dict of native values <- Dict of primitive datatypes.
|
||||
"""
|
||||
ret = {}
|
||||
errors = {}
|
||||
fields = [field for field in self.fields.values() if not field.read_only]
|
||||
|
@ -253,12 +282,7 @@ class Serializer(BaseSerializer):
|
|||
if errors:
|
||||
raise ValidationError(errors)
|
||||
|
||||
try:
|
||||
return self.validate(ret)
|
||||
except ValidationError as exc:
|
||||
raise ValidationError({
|
||||
api_settings.NON_FIELD_ERRORS_KEY: exc.messages
|
||||
})
|
||||
return ret
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""
|
||||
|
@ -355,6 +379,14 @@ class ModelSerializer(Serializer):
|
|||
})
|
||||
_related_class = PrimaryKeyRelatedField
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModelSerializer, self).__init__(*args, **kwargs)
|
||||
if 'validators' not in kwargs:
|
||||
validators = self.get_unique_together_validators()
|
||||
if validators:
|
||||
self.validators.extend(validators)
|
||||
self._kwargs['validators'] = validators
|
||||
|
||||
def create(self, attrs):
|
||||
ModelClass = self.Meta.model
|
||||
|
||||
|
@ -381,6 +413,36 @@ class ModelSerializer(Serializer):
|
|||
setattr(obj, attr, value)
|
||||
obj.save()
|
||||
|
||||
def get_unique_together_validators(self):
|
||||
field_names = set([
|
||||
field.source for field in self.fields.values()
|
||||
if (field.source != '*') and ('.' not in field.source)
|
||||
])
|
||||
|
||||
validators = []
|
||||
model_class = self.Meta.model
|
||||
|
||||
for unique_together in model_class._meta.unique_together:
|
||||
if field_names.issuperset(set(unique_together)):
|
||||
validator = UniqueTogetherValidator(
|
||||
queryset=model_class._default_manager,
|
||||
fields=unique_together
|
||||
)
|
||||
validator.serializer_field = self
|
||||
validators.append(validator)
|
||||
|
||||
for parent_class in model_class._meta.parents.keys():
|
||||
for unique_together in parent_class._meta.unique_together:
|
||||
if field_names.issuperset(set(unique_together)):
|
||||
validator = UniqueTogetherValidator(
|
||||
queryset=parent_class._default_manager,
|
||||
fields=unique_together
|
||||
)
|
||||
validator.serializer_field = self
|
||||
validators.append(validator)
|
||||
|
||||
return validators
|
||||
|
||||
def _get_base_fields(self):
|
||||
declared_fields = copy.deepcopy(self._declared_fields)
|
||||
|
||||
|
|
|
@ -1,18 +1,26 @@
|
|||
"""
|
||||
We perform uniqueness checks explicitly on the serializer class, rather
|
||||
the using Django's `.full_clean()`.
|
||||
|
||||
This gives us better seperation of concerns, allows us to use single-step
|
||||
object creation, and makes it possible to switch between using the implicit
|
||||
`ModelSerializer` class and an equivelent explicit `Serializer` class.
|
||||
"""
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from rest_framework.utils.representation import smart_repr
|
||||
|
||||
|
||||
class UniqueValidator:
|
||||
# Validators with `requires_context` will have the field instance
|
||||
# passed to them when the field is instantiated.
|
||||
requires_context = True
|
||||
message = _('This field must be unique.')
|
||||
|
||||
def __init__(self, queryset):
|
||||
self.queryset = queryset
|
||||
self.serializer_field = None
|
||||
|
||||
def get_queryset(self):
|
||||
return self.queryset.all()
|
||||
|
||||
def __call__(self, value):
|
||||
field = self.serializer_field
|
||||
|
||||
|
@ -24,15 +32,22 @@ class UniqueValidator:
|
|||
|
||||
# Ensure uniqueness.
|
||||
filter_kwargs = {field_name: value}
|
||||
queryset = self.get_queryset().filter(**filter_kwargs)
|
||||
queryset = self.queryset.filter(**filter_kwargs)
|
||||
if instance:
|
||||
queryset = queryset.exclude(pk=instance.pk)
|
||||
if queryset.exists():
|
||||
raise ValidationError('This field must be unique.')
|
||||
raise ValidationError(self.message)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s(queryset=%s)>' % (
|
||||
self.__class__.__name__,
|
||||
smart_repr(self.queryset)
|
||||
)
|
||||
|
||||
|
||||
class UniqueTogetherValidator:
|
||||
requires_context = True
|
||||
message = _('The fields {field_names} must make a unique set.')
|
||||
|
||||
def __init__(self, queryset, fields):
|
||||
self.queryset = queryset
|
||||
|
@ -49,9 +64,16 @@ class UniqueTogetherValidator:
|
|||
filter_kwargs = dict([
|
||||
(field_name, value[field_name]) for field_name in self.fields
|
||||
])
|
||||
queryset = self.get_queryset().filter(**filter_kwargs)
|
||||
queryset = self.queryset.filter(**filter_kwargs)
|
||||
if instance:
|
||||
queryset = queryset.exclude(pk=instance.pk)
|
||||
if queryset.exists():
|
||||
field_names = ' and '.join(self.fields)
|
||||
raise ValidationError('The fields %s must make a unique set.' % field_names)
|
||||
field_names = ', '.join(self.fields)
|
||||
raise ValidationError(self.message.format(field_names=field_names))
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s(queryset=%s, fields=%s)>' % (
|
||||
self.__class__.__name__,
|
||||
smart_repr(self.queryset),
|
||||
smart_repr(self.fields)
|
||||
)
|
||||
|
|
|
@ -3,33 +3,148 @@ from django.test import TestCase
|
|||
from rest_framework import serializers
|
||||
|
||||
|
||||
class ExampleModel(models.Model):
|
||||
def dedent(blocktext):
|
||||
return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
|
||||
|
||||
|
||||
# Tests for `UniqueValidator`
|
||||
# ---------------------------
|
||||
|
||||
class UniquenessModel(models.Model):
|
||||
username = models.CharField(unique=True, max_length=100)
|
||||
|
||||
|
||||
class ExampleSerializer(serializers.ModelSerializer):
|
||||
class UniquenessSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ExampleModel
|
||||
model = UniquenessModel
|
||||
|
||||
|
||||
class TestUniquenessValidation(TestCase):
|
||||
def setUp(self):
|
||||
self.instance = ExampleModel.objects.create(username='existing')
|
||||
self.instance = UniquenessModel.objects.create(username='existing')
|
||||
|
||||
def test_repr(self):
|
||||
serializer = UniquenessSerializer()
|
||||
expected = dedent("""
|
||||
UniquenessSerializer():
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
username = CharField(max_length=100, validators=[<UniqueValidator(queryset=UniquenessModel.objects.all())>])
|
||||
""")
|
||||
assert repr(serializer) == expected
|
||||
|
||||
def test_is_not_unique(self):
|
||||
data = {'username': 'existing'}
|
||||
serializer = ExampleSerializer(data=data)
|
||||
serializer = UniquenessSerializer(data=data)
|
||||
assert not serializer.is_valid()
|
||||
assert serializer.errors == {'username': ['This field must be unique.']}
|
||||
|
||||
def test_is_unique(self):
|
||||
data = {'username': 'other'}
|
||||
serializer = ExampleSerializer(data=data)
|
||||
serializer = UniquenessSerializer(data=data)
|
||||
assert serializer.is_valid()
|
||||
assert serializer.validated_data == {'username': 'other'}
|
||||
|
||||
def test_updated_instance_excluded(self):
|
||||
data = {'username': 'existing'}
|
||||
serializer = ExampleSerializer(self.instance, data=data)
|
||||
serializer = UniquenessSerializer(self.instance, data=data)
|
||||
assert serializer.is_valid()
|
||||
assert serializer.validated_data == {'username': 'existing'}
|
||||
|
||||
|
||||
# Tests for `UniqueTogetherValidator`
|
||||
# -----------------------------------
|
||||
|
||||
class UniquenessTogetherModel(models.Model):
|
||||
race_name = models.CharField(max_length=100)
|
||||
position = models.IntegerField()
|
||||
|
||||
class Meta:
|
||||
unique_together = ('race_name', 'position')
|
||||
|
||||
|
||||
class UniquenessTogetherSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = UniquenessTogetherModel
|
||||
|
||||
|
||||
class TestUniquenessTogetherValidation(TestCase):
|
||||
def setUp(self):
|
||||
self.instance = UniquenessTogetherModel.objects.create(
|
||||
race_name='example',
|
||||
position=1
|
||||
)
|
||||
UniquenessTogetherModel.objects.create(
|
||||
race_name='example',
|
||||
position=2
|
||||
)
|
||||
UniquenessTogetherModel.objects.create(
|
||||
race_name='other',
|
||||
position=1
|
||||
)
|
||||
|
||||
def test_repr(self):
|
||||
serializer = UniquenessTogetherSerializer()
|
||||
expected = dedent("""
|
||||
UniquenessTogetherSerializer(validators=[<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>]):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
race_name = CharField(max_length=100)
|
||||
position = IntegerField()
|
||||
""")
|
||||
assert repr(serializer) == expected
|
||||
|
||||
def test_is_not_unique_together(self):
|
||||
"""
|
||||
Failing unique together validation should result in non field errors.
|
||||
"""
|
||||
data = {'race_name': 'example', 'position': 2}
|
||||
serializer = UniquenessTogetherSerializer(data=data)
|
||||
print serializer.validators
|
||||
assert not serializer.is_valid()
|
||||
assert serializer.errors == {
|
||||
'non_field_errors': [
|
||||
'The fields race_name, position must make a unique set.'
|
||||
]
|
||||
}
|
||||
|
||||
def test_is_unique_together(self):
|
||||
"""
|
||||
In a unique together validation, one field may be non-unique
|
||||
so long as the set as a whole is unique.
|
||||
"""
|
||||
data = {'race_name': 'other', 'position': 2}
|
||||
serializer = UniquenessTogetherSerializer(data=data)
|
||||
assert serializer.is_valid()
|
||||
assert serializer.validated_data == {
|
||||
'race_name': 'other',
|
||||
'position': 2
|
||||
}
|
||||
|
||||
def test_updated_instance_excluded_from_unique_together(self):
|
||||
"""
|
||||
When performing an update, the existing instance does not count
|
||||
as a match against uniqueness.
|
||||
"""
|
||||
data = {'race_name': 'example', 'position': 1}
|
||||
serializer = UniquenessTogetherSerializer(self.instance, data=data)
|
||||
assert serializer.is_valid()
|
||||
assert serializer.validated_data == {
|
||||
'race_name': 'example',
|
||||
'position': 1
|
||||
}
|
||||
|
||||
def test_ignore_exlcuded_fields(self):
|
||||
"""
|
||||
When model fields are not included in a serializer, then uniqueness
|
||||
validtors should not be added for that field.
|
||||
"""
|
||||
class ExcludedFieldSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = UniquenessTogetherModel
|
||||
fields = ('id', 'race_name',)
|
||||
serializer = ExcludedFieldSerializer()
|
||||
expected = dedent("""
|
||||
ExcludedFieldSerializer():
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
race_name = CharField(max_length=100)
|
||||
""")
|
||||
assert repr(serializer) == expected
|
||||
|
|
Loading…
Reference in New Issue
Block a user