mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-02 20:54:42 +03:00
Uniqueness validation
This commit is contained in:
parent
ce04d59a53
commit
43fd5a8730
|
@ -150,6 +150,10 @@ class Field(object):
|
||||||
messages.update(error_messages or {})
|
messages.update(error_messages or {})
|
||||||
self.error_messages = messages
|
self.error_messages = messages
|
||||||
|
|
||||||
|
for validator in validators:
|
||||||
|
if getattr(validator, 'requires_context', False):
|
||||||
|
validator.serializer_field = self
|
||||||
|
|
||||||
def bind(self, field_name, parent):
|
def bind(self, field_name, parent):
|
||||||
"""
|
"""
|
||||||
Initializes the field name and parent for the field instance.
|
Initializes the field name and parent for the field instance.
|
||||||
|
|
|
@ -6,6 +6,7 @@ from django.core import validators
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.text import capfirst
|
from django.utils.text import capfirst
|
||||||
from rest_framework.compat import clean_manytomany_helptext
|
from rest_framework.compat import clean_manytomany_helptext
|
||||||
|
from rest_framework.validators import UniqueValidator
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
@ -156,6 +157,10 @@ def get_field_kwargs(field_name, model_field):
|
||||||
if validator is not validators.validate_slug
|
if validator is not validators.validate_slug
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if getattr(model_field, 'unique', False):
|
||||||
|
validator = UniqueValidator(queryset=model_field.model._default_manager)
|
||||||
|
validator_kwarg.append(validator)
|
||||||
|
|
||||||
max_digits = getattr(model_field, 'max_digits', None)
|
max_digits = getattr(model_field, 'max_digits', None)
|
||||||
if max_digits is not None:
|
if max_digits is not None:
|
||||||
kwargs['max_digits'] = max_digits
|
kwargs['max_digits'] = max_digits
|
||||||
|
|
57
rest_framework/validators.py
Normal file
57
rest_framework/validators.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class UniqueValidator:
|
||||||
|
# Validators with `requires_context` will have the field instance
|
||||||
|
# passed to them when the field is instantiated.
|
||||||
|
requires_context = True
|
||||||
|
|
||||||
|
def __init__(self, queryset):
|
||||||
|
self.queryset = queryset
|
||||||
|
self.serializer_field = None
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
return self.queryset.all()
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
field = self.serializer_field
|
||||||
|
|
||||||
|
# Determine the model field name that the serializer field corresponds to.
|
||||||
|
field_name = field.source_attrs[0] if field.source_attrs else field.field_name
|
||||||
|
|
||||||
|
# Determine the existing instance, if this is an update operation.
|
||||||
|
instance = getattr(field.parent, 'instance', None)
|
||||||
|
|
||||||
|
# Ensure uniqueness.
|
||||||
|
filter_kwargs = {field_name: value}
|
||||||
|
queryset = self.get_queryset().filter(**filter_kwargs)
|
||||||
|
if instance:
|
||||||
|
queryset = queryset.exclude(pk=instance.pk)
|
||||||
|
if queryset.exists():
|
||||||
|
raise ValidationError('This field must be unique.')
|
||||||
|
|
||||||
|
|
||||||
|
class UniqueTogetherValidator:
|
||||||
|
requires_context = True
|
||||||
|
|
||||||
|
def __init__(self, queryset, fields):
|
||||||
|
self.queryset = queryset
|
||||||
|
self.fields = fields
|
||||||
|
self.serializer_field = None
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
serializer = self.serializer_field
|
||||||
|
|
||||||
|
# Determine the existing instance, if this is an update operation.
|
||||||
|
instance = getattr(serializer, 'instance', None)
|
||||||
|
|
||||||
|
# Ensure uniqueness.
|
||||||
|
filter_kwargs = dict([
|
||||||
|
(field_name, value[field_name]) for field_name in self.fields
|
||||||
|
])
|
||||||
|
queryset = self.get_queryset().filter(**filter_kwargs)
|
||||||
|
if instance:
|
||||||
|
queryset = queryset.exclude(pk=instance.pk)
|
||||||
|
if queryset.exists():
|
||||||
|
field_names = ' and '.join(self.fields)
|
||||||
|
raise ValidationError('The fields %s must make a unique set.' % field_names)
|
35
tests/test_validators.py
Normal file
35
tests/test_validators.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
from django.db import models
|
||||||
|
from django.test import TestCase
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleModel(models.Model):
|
||||||
|
username = models.CharField(unique=True, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = ExampleModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestUniquenessValidation(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.instance = ExampleModel.objects.create(username='existing')
|
||||||
|
|
||||||
|
def test_is_not_unique(self):
|
||||||
|
data = {'username': 'existing'}
|
||||||
|
serializer = ExampleSerializer(data=data)
|
||||||
|
assert not serializer.is_valid()
|
||||||
|
assert serializer.errors == {'username': ['This field must be unique.']}
|
||||||
|
|
||||||
|
def test_is_unique(self):
|
||||||
|
data = {'username': 'other'}
|
||||||
|
serializer = ExampleSerializer(data=data)
|
||||||
|
assert serializer.is_valid()
|
||||||
|
assert serializer.validated_data == {'username': 'other'}
|
||||||
|
|
||||||
|
def test_updated_instance_excluded(self):
|
||||||
|
data = {'username': 'existing'}
|
||||||
|
serializer = ExampleSerializer(self.instance, data=data)
|
||||||
|
assert serializer.is_valid()
|
||||||
|
assert serializer.validated_data == {'username': 'existing'}
|
Loading…
Reference in New Issue
Block a user