UniqueTogetherValidator

This commit is contained in:
Tom Christie 2014-09-29 11:23:02 +01:00
parent 43fd5a8730
commit 9805a085fb
3 changed files with 223 additions and 24 deletions

View File

@ -23,6 +23,7 @@ from rest_framework.utils.field_mapping import (
get_relation_kwargs, get_nested_relation_kwargs, get_relation_kwargs, get_nested_relation_kwargs,
ClassLookupDict ClassLookupDict
) )
from rest_framework.validators import UniqueTogetherValidator
import copy import copy
import inspect import inspect
@ -95,7 +96,7 @@ class BaseSerializer(Field):
def is_valid(self, raise_exception=False): def is_valid(self, raise_exception=False):
if not hasattr(self, '_validated_data'): if not hasattr(self, '_validated_data'):
try: try:
self._validated_data = self.to_internal_value(self._initial_data) self._validated_data = self.run_validation(self._initial_data)
except ValidationError as exc: except ValidationError as exc:
self._validated_data = {} self._validated_data = {}
self._errors = exc.message_dict self._errors = exc.message_dict
@ -223,15 +224,43 @@ class Serializer(BaseSerializer):
return html.parse_html_dict(dictionary, prefix=self.field_name) return html.parse_html_dict(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty) return dictionary.get(self.field_name, empty)
def to_internal_value(self, data): def run_validation(self, data=empty):
""" """
Dict of native values <- Dict of primitive datatypes. We override the default `run_validation`, because the validation
performed by validators and the `.validate()` method should
be coerced into an error dictionary with a 'non_fields_error' key.
""" """
if data is empty:
if getattr(self.root, 'partial', False):
raise SkipField()
if self.required:
self.fail('required')
return self.get_default()
if data is None:
if not self.allow_null:
self.fail('null')
return None
if not isinstance(data, dict): if not isinstance(data, dict):
raise ValidationError({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data'] api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
}) })
value = self.to_internal_value(data)
try:
self.run_validators(value)
self.validate(value)
except ValidationError as exc:
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: exc.messages
})
return value
def to_internal_value(self, data):
"""
Dict of native values <- Dict of primitive datatypes.
"""
ret = {} ret = {}
errors = {} errors = {}
fields = [field for field in self.fields.values() if not field.read_only] fields = [field for field in self.fields.values() if not field.read_only]
@ -253,12 +282,7 @@ class Serializer(BaseSerializer):
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
try: return ret
return self.validate(ret)
except ValidationError as exc:
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: exc.messages
})
def to_representation(self, instance): def to_representation(self, instance):
""" """
@ -355,6 +379,14 @@ class ModelSerializer(Serializer):
}) })
_related_class = PrimaryKeyRelatedField _related_class = PrimaryKeyRelatedField
def __init__(self, *args, **kwargs):
super(ModelSerializer, self).__init__(*args, **kwargs)
if 'validators' not in kwargs:
validators = self.get_unique_together_validators()
if validators:
self.validators.extend(validators)
self._kwargs['validators'] = validators
def create(self, attrs): def create(self, attrs):
ModelClass = self.Meta.model ModelClass = self.Meta.model
@ -381,6 +413,36 @@ class ModelSerializer(Serializer):
setattr(obj, attr, value) setattr(obj, attr, value)
obj.save() obj.save()
def get_unique_together_validators(self):
field_names = set([
field.source for field in self.fields.values()
if (field.source != '*') and ('.' not in field.source)
])
validators = []
model_class = self.Meta.model
for unique_together in model_class._meta.unique_together:
if field_names.issuperset(set(unique_together)):
validator = UniqueTogetherValidator(
queryset=model_class._default_manager,
fields=unique_together
)
validator.serializer_field = self
validators.append(validator)
for parent_class in model_class._meta.parents.keys():
for unique_together in parent_class._meta.unique_together:
if field_names.issuperset(set(unique_together)):
validator = UniqueTogetherValidator(
queryset=parent_class._default_manager,
fields=unique_together
)
validator.serializer_field = self
validators.append(validator)
return validators
def _get_base_fields(self): def _get_base_fields(self):
declared_fields = copy.deepcopy(self._declared_fields) declared_fields = copy.deepcopy(self._declared_fields)

View File

@ -1,18 +1,26 @@
"""
We perform uniqueness checks explicitly on the serializer class, rather
the using Django's `.full_clean()`.
This gives us better seperation of concerns, allows us to use single-step
object creation, and makes it possible to switch between using the implicit
`ModelSerializer` class and an equivelent explicit `Serializer` class.
"""
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.utils.translation import ugettext_lazy as _
from rest_framework.utils.representation import smart_repr
class UniqueValidator: class UniqueValidator:
# Validators with `requires_context` will have the field instance # Validators with `requires_context` will have the field instance
# passed to them when the field is instantiated. # passed to them when the field is instantiated.
requires_context = True requires_context = True
message = _('This field must be unique.')
def __init__(self, queryset): def __init__(self, queryset):
self.queryset = queryset self.queryset = queryset
self.serializer_field = None self.serializer_field = None
def get_queryset(self):
return self.queryset.all()
def __call__(self, value): def __call__(self, value):
field = self.serializer_field field = self.serializer_field
@ -24,15 +32,22 @@ class UniqueValidator:
# Ensure uniqueness. # Ensure uniqueness.
filter_kwargs = {field_name: value} filter_kwargs = {field_name: value}
queryset = self.get_queryset().filter(**filter_kwargs) queryset = self.queryset.filter(**filter_kwargs)
if instance: if instance:
queryset = queryset.exclude(pk=instance.pk) queryset = queryset.exclude(pk=instance.pk)
if queryset.exists(): if queryset.exists():
raise ValidationError('This field must be unique.') raise ValidationError(self.message)
def __repr__(self):
return '<%s(queryset=%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset)
)
class UniqueTogetherValidator: class UniqueTogetherValidator:
requires_context = True requires_context = True
message = _('The fields {field_names} must make a unique set.')
def __init__(self, queryset, fields): def __init__(self, queryset, fields):
self.queryset = queryset self.queryset = queryset
@ -49,9 +64,16 @@ class UniqueTogetherValidator:
filter_kwargs = dict([ filter_kwargs = dict([
(field_name, value[field_name]) for field_name in self.fields (field_name, value[field_name]) for field_name in self.fields
]) ])
queryset = self.get_queryset().filter(**filter_kwargs) queryset = self.queryset.filter(**filter_kwargs)
if instance: if instance:
queryset = queryset.exclude(pk=instance.pk) queryset = queryset.exclude(pk=instance.pk)
if queryset.exists(): if queryset.exists():
field_names = ' and '.join(self.fields) field_names = ', '.join(self.fields)
raise ValidationError('The fields %s must make a unique set.' % field_names) raise ValidationError(self.message.format(field_names=field_names))
def __repr__(self):
return '<%s(queryset=%s, fields=%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset),
smart_repr(self.fields)
)

View File

@ -3,33 +3,148 @@ from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
class ExampleModel(models.Model): def dedent(blocktext):
return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
# Tests for `UniqueValidator`
# ---------------------------
class UniquenessModel(models.Model):
username = models.CharField(unique=True, max_length=100) username = models.CharField(unique=True, max_length=100)
class ExampleSerializer(serializers.ModelSerializer): class UniquenessSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ExampleModel model = UniquenessModel
class TestUniquenessValidation(TestCase): class TestUniquenessValidation(TestCase):
def setUp(self): def setUp(self):
self.instance = ExampleModel.objects.create(username='existing') self.instance = UniquenessModel.objects.create(username='existing')
def test_repr(self):
serializer = UniquenessSerializer()
expected = dedent("""
UniquenessSerializer():
id = IntegerField(label='ID', read_only=True)
username = CharField(max_length=100, validators=[<UniqueValidator(queryset=UniquenessModel.objects.all())>])
""")
assert repr(serializer) == expected
def test_is_not_unique(self): def test_is_not_unique(self):
data = {'username': 'existing'} data = {'username': 'existing'}
serializer = ExampleSerializer(data=data) serializer = UniquenessSerializer(data=data)
assert not serializer.is_valid() assert not serializer.is_valid()
assert serializer.errors == {'username': ['This field must be unique.']} assert serializer.errors == {'username': ['This field must be unique.']}
def test_is_unique(self): def test_is_unique(self):
data = {'username': 'other'} data = {'username': 'other'}
serializer = ExampleSerializer(data=data) serializer = UniquenessSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
assert serializer.validated_data == {'username': 'other'} assert serializer.validated_data == {'username': 'other'}
def test_updated_instance_excluded(self): def test_updated_instance_excluded(self):
data = {'username': 'existing'} data = {'username': 'existing'}
serializer = ExampleSerializer(self.instance, data=data) serializer = UniquenessSerializer(self.instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
assert serializer.validated_data == {'username': 'existing'} assert serializer.validated_data == {'username': 'existing'}
# Tests for `UniqueTogetherValidator`
# -----------------------------------
class UniquenessTogetherModel(models.Model):
race_name = models.CharField(max_length=100)
position = models.IntegerField()
class Meta:
unique_together = ('race_name', 'position')
class UniquenessTogetherSerializer(serializers.ModelSerializer):
class Meta:
model = UniquenessTogetherModel
class TestUniquenessTogetherValidation(TestCase):
def setUp(self):
self.instance = UniquenessTogetherModel.objects.create(
race_name='example',
position=1
)
UniquenessTogetherModel.objects.create(
race_name='example',
position=2
)
UniquenessTogetherModel.objects.create(
race_name='other',
position=1
)
def test_repr(self):
serializer = UniquenessTogetherSerializer()
expected = dedent("""
UniquenessTogetherSerializer(validators=[<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>]):
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100)
position = IntegerField()
""")
assert repr(serializer) == expected
def test_is_not_unique_together(self):
"""
Failing unique together validation should result in non field errors.
"""
data = {'race_name': 'example', 'position': 2}
serializer = UniquenessTogetherSerializer(data=data)
print serializer.validators
assert not serializer.is_valid()
assert serializer.errors == {
'non_field_errors': [
'The fields race_name, position must make a unique set.'
]
}
def test_is_unique_together(self):
"""
In a unique together validation, one field may be non-unique
so long as the set as a whole is unique.
"""
data = {'race_name': 'other', 'position': 2}
serializer = UniquenessTogetherSerializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {
'race_name': 'other',
'position': 2
}
def test_updated_instance_excluded_from_unique_together(self):
"""
When performing an update, the existing instance does not count
as a match against uniqueness.
"""
data = {'race_name': 'example', 'position': 1}
serializer = UniquenessTogetherSerializer(self.instance, data=data)
assert serializer.is_valid()
assert serializer.validated_data == {
'race_name': 'example',
'position': 1
}
def test_ignore_exlcuded_fields(self):
"""
When model fields are not included in a serializer, then uniqueness
validtors should not be added for that field.
"""
class ExcludedFieldSerializer(serializers.ModelSerializer):
class Meta:
model = UniquenessTogetherModel
fields = ('id', 'race_name',)
serializer = ExcludedFieldSerializer()
expected = dedent("""
ExcludedFieldSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100)
""")
assert repr(serializer) == expected