mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 09:57:55 +03:00 
			
		
		
		
	UniqueTogetherValidator
This commit is contained in:
		
							parent
							
								
									43fd5a8730
								
							
						
					
					
						commit
						9805a085fb
					
				| 
						 | 
				
			
			@ -23,6 +23,7 @@ from rest_framework.utils.field_mapping import (
 | 
			
		|||
    get_relation_kwargs, get_nested_relation_kwargs,
 | 
			
		||||
    ClassLookupDict
 | 
			
		||||
)
 | 
			
		||||
from rest_framework.validators import UniqueTogetherValidator
 | 
			
		||||
import copy
 | 
			
		||||
import inspect
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -95,7 +96,7 @@ class BaseSerializer(Field):
 | 
			
		|||
    def is_valid(self, raise_exception=False):
 | 
			
		||||
        if not hasattr(self, '_validated_data'):
 | 
			
		||||
            try:
 | 
			
		||||
                self._validated_data = self.to_internal_value(self._initial_data)
 | 
			
		||||
                self._validated_data = self.run_validation(self._initial_data)
 | 
			
		||||
            except ValidationError as exc:
 | 
			
		||||
                self._validated_data = {}
 | 
			
		||||
                self._errors = exc.message_dict
 | 
			
		||||
| 
						 | 
				
			
			@ -223,15 +224,43 @@ class Serializer(BaseSerializer):
 | 
			
		|||
            return html.parse_html_dict(dictionary, prefix=self.field_name)
 | 
			
		||||
        return dictionary.get(self.field_name, empty)
 | 
			
		||||
 | 
			
		||||
    def to_internal_value(self, data):
 | 
			
		||||
    def run_validation(self, data=empty):
 | 
			
		||||
        """
 | 
			
		||||
        Dict of native values <- Dict of primitive datatypes.
 | 
			
		||||
        We override the default `run_validation`, because the validation
 | 
			
		||||
        performed by validators and the `.validate()` method should
 | 
			
		||||
        be coerced into an error dictionary with a 'non_fields_error' key.
 | 
			
		||||
        """
 | 
			
		||||
        if data is empty:
 | 
			
		||||
            if getattr(self.root, 'partial', False):
 | 
			
		||||
                raise SkipField()
 | 
			
		||||
            if self.required:
 | 
			
		||||
                self.fail('required')
 | 
			
		||||
            return self.get_default()
 | 
			
		||||
 | 
			
		||||
        if data is None:
 | 
			
		||||
            if not self.allow_null:
 | 
			
		||||
                self.fail('null')
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        if not isinstance(data, dict):
 | 
			
		||||
            raise ValidationError({
 | 
			
		||||
                api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
        value = self.to_internal_value(data)
 | 
			
		||||
        try:
 | 
			
		||||
            self.run_validators(value)
 | 
			
		||||
            self.validate(value)
 | 
			
		||||
        except ValidationError as exc:
 | 
			
		||||
            raise ValidationError({
 | 
			
		||||
                api_settings.NON_FIELD_ERRORS_KEY: exc.messages
 | 
			
		||||
            })
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def to_internal_value(self, data):
 | 
			
		||||
        """
 | 
			
		||||
        Dict of native values <- Dict of primitive datatypes.
 | 
			
		||||
        """
 | 
			
		||||
        ret = {}
 | 
			
		||||
        errors = {}
 | 
			
		||||
        fields = [field for field in self.fields.values() if not field.read_only]
 | 
			
		||||
| 
						 | 
				
			
			@ -253,12 +282,7 @@ class Serializer(BaseSerializer):
 | 
			
		|||
        if errors:
 | 
			
		||||
            raise ValidationError(errors)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            return self.validate(ret)
 | 
			
		||||
        except ValidationError as exc:
 | 
			
		||||
            raise ValidationError({
 | 
			
		||||
                api_settings.NON_FIELD_ERRORS_KEY: exc.messages
 | 
			
		||||
            })
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
    def to_representation(self, instance):
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -355,6 +379,14 @@ class ModelSerializer(Serializer):
 | 
			
		|||
    })
 | 
			
		||||
    _related_class = PrimaryKeyRelatedField
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super(ModelSerializer, self).__init__(*args, **kwargs)
 | 
			
		||||
        if 'validators' not in kwargs:
 | 
			
		||||
            validators = self.get_unique_together_validators()
 | 
			
		||||
            if validators:
 | 
			
		||||
                self.validators.extend(validators)
 | 
			
		||||
                self._kwargs['validators'] = validators
 | 
			
		||||
 | 
			
		||||
    def create(self, attrs):
 | 
			
		||||
        ModelClass = self.Meta.model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -381,6 +413,36 @@ class ModelSerializer(Serializer):
 | 
			
		|||
            setattr(obj, attr, value)
 | 
			
		||||
        obj.save()
 | 
			
		||||
 | 
			
		||||
    def get_unique_together_validators(self):
 | 
			
		||||
        field_names = set([
 | 
			
		||||
            field.source for field in self.fields.values()
 | 
			
		||||
            if (field.source != '*') and ('.' not in field.source)
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
        validators = []
 | 
			
		||||
        model_class = self.Meta.model
 | 
			
		||||
 | 
			
		||||
        for unique_together in model_class._meta.unique_together:
 | 
			
		||||
            if field_names.issuperset(set(unique_together)):
 | 
			
		||||
                validator = UniqueTogetherValidator(
 | 
			
		||||
                    queryset=model_class._default_manager,
 | 
			
		||||
                    fields=unique_together
 | 
			
		||||
                )
 | 
			
		||||
                validator.serializer_field = self
 | 
			
		||||
                validators.append(validator)
 | 
			
		||||
 | 
			
		||||
        for parent_class in model_class._meta.parents.keys():
 | 
			
		||||
            for unique_together in parent_class._meta.unique_together:
 | 
			
		||||
                if field_names.issuperset(set(unique_together)):
 | 
			
		||||
                    validator = UniqueTogetherValidator(
 | 
			
		||||
                        queryset=parent_class._default_manager,
 | 
			
		||||
                        fields=unique_together
 | 
			
		||||
                    )
 | 
			
		||||
                    validator.serializer_field = self
 | 
			
		||||
                    validators.append(validator)
 | 
			
		||||
 | 
			
		||||
        return validators
 | 
			
		||||
 | 
			
		||||
    def _get_base_fields(self):
 | 
			
		||||
        declared_fields = copy.deepcopy(self._declared_fields)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,18 +1,26 @@
 | 
			
		|||
"""
 | 
			
		||||
We perform uniqueness checks explicitly on the serializer class, rather
 | 
			
		||||
the using Django's `.full_clean()`.
 | 
			
		||||
 | 
			
		||||
This gives us better seperation of concerns, allows us to use single-step
 | 
			
		||||
object creation, and makes it possible to switch between using the implicit
 | 
			
		||||
`ModelSerializer` class and an equivelent explicit `Serializer` class.
 | 
			
		||||
"""
 | 
			
		||||
from django.core.exceptions import ValidationError
 | 
			
		||||
from django.utils.translation import ugettext_lazy as _
 | 
			
		||||
from rest_framework.utils.representation import smart_repr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UniqueValidator:
 | 
			
		||||
    # Validators with `requires_context` will have the field instance
 | 
			
		||||
    # passed to them when the field is instantiated.
 | 
			
		||||
    requires_context = True
 | 
			
		||||
    message = _('This field must be unique.')
 | 
			
		||||
 | 
			
		||||
    def __init__(self, queryset):
 | 
			
		||||
        self.queryset = queryset
 | 
			
		||||
        self.serializer_field = None
 | 
			
		||||
 | 
			
		||||
    def get_queryset(self):
 | 
			
		||||
        return self.queryset.all()
 | 
			
		||||
 | 
			
		||||
    def __call__(self, value):
 | 
			
		||||
        field = self.serializer_field
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -24,15 +32,22 @@ class UniqueValidator:
 | 
			
		|||
 | 
			
		||||
        # Ensure uniqueness.
 | 
			
		||||
        filter_kwargs = {field_name: value}
 | 
			
		||||
        queryset = self.get_queryset().filter(**filter_kwargs)
 | 
			
		||||
        queryset = self.queryset.filter(**filter_kwargs)
 | 
			
		||||
        if instance:
 | 
			
		||||
            queryset = queryset.exclude(pk=instance.pk)
 | 
			
		||||
        if queryset.exists():
 | 
			
		||||
            raise ValidationError('This field must be unique.')
 | 
			
		||||
            raise ValidationError(self.message)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return '<%s(queryset=%s)>' % (
 | 
			
		||||
            self.__class__.__name__,
 | 
			
		||||
            smart_repr(self.queryset)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UniqueTogetherValidator:
 | 
			
		||||
    requires_context = True
 | 
			
		||||
    message = _('The fields {field_names} must make a unique set.')
 | 
			
		||||
 | 
			
		||||
    def __init__(self, queryset, fields):
 | 
			
		||||
        self.queryset = queryset
 | 
			
		||||
| 
						 | 
				
			
			@ -49,9 +64,16 @@ class UniqueTogetherValidator:
 | 
			
		|||
        filter_kwargs = dict([
 | 
			
		||||
            (field_name, value[field_name]) for field_name in self.fields
 | 
			
		||||
        ])
 | 
			
		||||
        queryset = self.get_queryset().filter(**filter_kwargs)
 | 
			
		||||
        queryset = self.queryset.filter(**filter_kwargs)
 | 
			
		||||
        if instance:
 | 
			
		||||
            queryset = queryset.exclude(pk=instance.pk)
 | 
			
		||||
        if queryset.exists():
 | 
			
		||||
            field_names = ' and '.join(self.fields)
 | 
			
		||||
            raise ValidationError('The fields %s must make a unique set.' % field_names)
 | 
			
		||||
            field_names = ', '.join(self.fields)
 | 
			
		||||
            raise ValidationError(self.message.format(field_names=field_names))
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return '<%s(queryset=%s, fields=%s)>' % (
 | 
			
		||||
            self.__class__.__name__,
 | 
			
		||||
            smart_repr(self.queryset),
 | 
			
		||||
            smart_repr(self.fields)
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,33 +3,148 @@ from django.test import TestCase
 | 
			
		|||
from rest_framework import serializers
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ExampleModel(models.Model):
 | 
			
		||||
def dedent(blocktext):
 | 
			
		||||
    return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Tests for `UniqueValidator`
 | 
			
		||||
# ---------------------------
 | 
			
		||||
 | 
			
		||||
class UniquenessModel(models.Model):
 | 
			
		||||
    username = models.CharField(unique=True, max_length=100)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ExampleSerializer(serializers.ModelSerializer):
 | 
			
		||||
class UniquenessSerializer(serializers.ModelSerializer):
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = ExampleModel
 | 
			
		||||
        model = UniquenessModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestUniquenessValidation(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.instance = ExampleModel.objects.create(username='existing')
 | 
			
		||||
        self.instance = UniquenessModel.objects.create(username='existing')
 | 
			
		||||
 | 
			
		||||
    def test_repr(self):
 | 
			
		||||
        serializer = UniquenessSerializer()
 | 
			
		||||
        expected = dedent("""
 | 
			
		||||
            UniquenessSerializer():
 | 
			
		||||
                id = IntegerField(label='ID', read_only=True)
 | 
			
		||||
                username = CharField(max_length=100, validators=[<UniqueValidator(queryset=UniquenessModel.objects.all())>])
 | 
			
		||||
        """)
 | 
			
		||||
        assert repr(serializer) == expected
 | 
			
		||||
 | 
			
		||||
    def test_is_not_unique(self):
 | 
			
		||||
        data = {'username': 'existing'}
 | 
			
		||||
        serializer = ExampleSerializer(data=data)
 | 
			
		||||
        serializer = UniquenessSerializer(data=data)
 | 
			
		||||
        assert not serializer.is_valid()
 | 
			
		||||
        assert serializer.errors == {'username': ['This field must be unique.']}
 | 
			
		||||
 | 
			
		||||
    def test_is_unique(self):
 | 
			
		||||
        data = {'username': 'other'}
 | 
			
		||||
        serializer = ExampleSerializer(data=data)
 | 
			
		||||
        serializer = UniquenessSerializer(data=data)
 | 
			
		||||
        assert serializer.is_valid()
 | 
			
		||||
        assert serializer.validated_data == {'username': 'other'}
 | 
			
		||||
 | 
			
		||||
    def test_updated_instance_excluded(self):
 | 
			
		||||
        data = {'username': 'existing'}
 | 
			
		||||
        serializer = ExampleSerializer(self.instance, data=data)
 | 
			
		||||
        serializer = UniquenessSerializer(self.instance, data=data)
 | 
			
		||||
        assert serializer.is_valid()
 | 
			
		||||
        assert serializer.validated_data == {'username': 'existing'}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Tests for `UniqueTogetherValidator`
 | 
			
		||||
# -----------------------------------
 | 
			
		||||
 | 
			
		||||
class UniquenessTogetherModel(models.Model):
 | 
			
		||||
    race_name = models.CharField(max_length=100)
 | 
			
		||||
    position = models.IntegerField()
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        unique_together = ('race_name', 'position')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UniquenessTogetherSerializer(serializers.ModelSerializer):
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = UniquenessTogetherModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestUniquenessTogetherValidation(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.instance = UniquenessTogetherModel.objects.create(
 | 
			
		||||
            race_name='example',
 | 
			
		||||
            position=1
 | 
			
		||||
        )
 | 
			
		||||
        UniquenessTogetherModel.objects.create(
 | 
			
		||||
            race_name='example',
 | 
			
		||||
            position=2
 | 
			
		||||
        )
 | 
			
		||||
        UniquenessTogetherModel.objects.create(
 | 
			
		||||
            race_name='other',
 | 
			
		||||
            position=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_repr(self):
 | 
			
		||||
        serializer = UniquenessTogetherSerializer()
 | 
			
		||||
        expected = dedent("""
 | 
			
		||||
            UniquenessTogetherSerializer(validators=[<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>]):
 | 
			
		||||
                id = IntegerField(label='ID', read_only=True)
 | 
			
		||||
                race_name = CharField(max_length=100)
 | 
			
		||||
                position = IntegerField()
 | 
			
		||||
        """)
 | 
			
		||||
        assert repr(serializer) == expected
 | 
			
		||||
 | 
			
		||||
    def test_is_not_unique_together(self):
 | 
			
		||||
        """
 | 
			
		||||
        Failing unique together validation should result in non field errors.
 | 
			
		||||
        """
 | 
			
		||||
        data = {'race_name': 'example', 'position': 2}
 | 
			
		||||
        serializer = UniquenessTogetherSerializer(data=data)
 | 
			
		||||
        print serializer.validators
 | 
			
		||||
        assert not serializer.is_valid()
 | 
			
		||||
        assert serializer.errors == {
 | 
			
		||||
            'non_field_errors': [
 | 
			
		||||
                'The fields race_name, position must make a unique set.'
 | 
			
		||||
            ]
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def test_is_unique_together(self):
 | 
			
		||||
        """
 | 
			
		||||
        In a unique together validation, one field may be non-unique
 | 
			
		||||
        so long as the set as a whole is unique.
 | 
			
		||||
        """
 | 
			
		||||
        data = {'race_name': 'other', 'position': 2}
 | 
			
		||||
        serializer = UniquenessTogetherSerializer(data=data)
 | 
			
		||||
        assert serializer.is_valid()
 | 
			
		||||
        assert serializer.validated_data == {
 | 
			
		||||
            'race_name': 'other',
 | 
			
		||||
            'position': 2
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def test_updated_instance_excluded_from_unique_together(self):
 | 
			
		||||
        """
 | 
			
		||||
        When performing an update, the existing instance does not count
 | 
			
		||||
        as a match against uniqueness.
 | 
			
		||||
        """
 | 
			
		||||
        data = {'race_name': 'example', 'position': 1}
 | 
			
		||||
        serializer = UniquenessTogetherSerializer(self.instance, data=data)
 | 
			
		||||
        assert serializer.is_valid()
 | 
			
		||||
        assert serializer.validated_data == {
 | 
			
		||||
            'race_name': 'example',
 | 
			
		||||
            'position': 1
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def test_ignore_exlcuded_fields(self):
 | 
			
		||||
        """
 | 
			
		||||
        When model fields are not included in a serializer, then uniqueness
 | 
			
		||||
        validtors should not be added for that field.
 | 
			
		||||
        """
 | 
			
		||||
        class ExcludedFieldSerializer(serializers.ModelSerializer):
 | 
			
		||||
            class Meta:
 | 
			
		||||
                model = UniquenessTogetherModel
 | 
			
		||||
                fields = ('id', 'race_name',)
 | 
			
		||||
        serializer = ExcludedFieldSerializer()
 | 
			
		||||
        expected = dedent("""
 | 
			
		||||
            ExcludedFieldSerializer():
 | 
			
		||||
                id = IntegerField(label='ID', read_only=True)
 | 
			
		||||
                race_name = CharField(max_length=100)
 | 
			
		||||
        """)
 | 
			
		||||
        assert repr(serializer) == expected
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user