mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-23 15:54:16 +03:00
Uniqueness validation
This commit is contained in:
parent
ce04d59a53
commit
43fd5a8730
|
@ -150,6 +150,10 @@ class Field(object):
|
|||
messages.update(error_messages or {})
|
||||
self.error_messages = messages
|
||||
|
||||
for validator in validators:
|
||||
if getattr(validator, 'requires_context', False):
|
||||
validator.serializer_field = self
|
||||
|
||||
def bind(self, field_name, parent):
|
||||
"""
|
||||
Initializes the field name and parent for the field instance.
|
||||
|
|
|
@ -6,6 +6,7 @@ from django.core import validators
|
|||
from django.db import models
|
||||
from django.utils.text import capfirst
|
||||
from rest_framework.compat import clean_manytomany_helptext
|
||||
from rest_framework.validators import UniqueValidator
|
||||
import inspect
|
||||
|
||||
|
||||
|
@ -156,6 +157,10 @@ def get_field_kwargs(field_name, model_field):
|
|||
if validator is not validators.validate_slug
|
||||
]
|
||||
|
||||
if getattr(model_field, 'unique', False):
|
||||
validator = UniqueValidator(queryset=model_field.model._default_manager)
|
||||
validator_kwarg.append(validator)
|
||||
|
||||
max_digits = getattr(model_field, 'max_digits', None)
|
||||
if max_digits is not None:
|
||||
kwargs['max_digits'] = max_digits
|
||||
|
|
57
rest_framework/validators.py
Normal file
57
rest_framework/validators.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from django.core.exceptions import ValidationError
|
||||
|
||||
|
||||
class UniqueValidator:
|
||||
# Validators with `requires_context` will have the field instance
|
||||
# passed to them when the field is instantiated.
|
||||
requires_context = True
|
||||
|
||||
def __init__(self, queryset):
|
||||
self.queryset = queryset
|
||||
self.serializer_field = None
|
||||
|
||||
def get_queryset(self):
|
||||
return self.queryset.all()
|
||||
|
||||
def __call__(self, value):
|
||||
field = self.serializer_field
|
||||
|
||||
# Determine the model field name that the serializer field corresponds to.
|
||||
field_name = field.source_attrs[0] if field.source_attrs else field.field_name
|
||||
|
||||
# Determine the existing instance, if this is an update operation.
|
||||
instance = getattr(field.parent, 'instance', None)
|
||||
|
||||
# Ensure uniqueness.
|
||||
filter_kwargs = {field_name: value}
|
||||
queryset = self.get_queryset().filter(**filter_kwargs)
|
||||
if instance:
|
||||
queryset = queryset.exclude(pk=instance.pk)
|
||||
if queryset.exists():
|
||||
raise ValidationError('This field must be unique.')
|
||||
|
||||
|
||||
class UniqueTogetherValidator:
|
||||
requires_context = True
|
||||
|
||||
def __init__(self, queryset, fields):
|
||||
self.queryset = queryset
|
||||
self.fields = fields
|
||||
self.serializer_field = None
|
||||
|
||||
def __call__(self, value):
|
||||
serializer = self.serializer_field
|
||||
|
||||
# Determine the existing instance, if this is an update operation.
|
||||
instance = getattr(serializer, 'instance', None)
|
||||
|
||||
# Ensure uniqueness.
|
||||
filter_kwargs = dict([
|
||||
(field_name, value[field_name]) for field_name in self.fields
|
||||
])
|
||||
queryset = self.get_queryset().filter(**filter_kwargs)
|
||||
if instance:
|
||||
queryset = queryset.exclude(pk=instance.pk)
|
||||
if queryset.exists():
|
||||
field_names = ' and '.join(self.fields)
|
||||
raise ValidationError('The fields %s must make a unique set.' % field_names)
|
35
tests/test_validators.py
Normal file
35
tests/test_validators.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
from django.db import models
|
||||
from django.test import TestCase
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class ExampleModel(models.Model):
|
||||
username = models.CharField(unique=True, max_length=100)
|
||||
|
||||
|
||||
class ExampleSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ExampleModel
|
||||
|
||||
|
||||
class TestUniquenessValidation(TestCase):
|
||||
def setUp(self):
|
||||
self.instance = ExampleModel.objects.create(username='existing')
|
||||
|
||||
def test_is_not_unique(self):
|
||||
data = {'username': 'existing'}
|
||||
serializer = ExampleSerializer(data=data)
|
||||
assert not serializer.is_valid()
|
||||
assert serializer.errors == {'username': ['This field must be unique.']}
|
||||
|
||||
def test_is_unique(self):
|
||||
data = {'username': 'other'}
|
||||
serializer = ExampleSerializer(data=data)
|
||||
assert serializer.is_valid()
|
||||
assert serializer.validated_data == {'username': 'other'}
|
||||
|
||||
def test_updated_instance_excluded(self):
|
||||
data = {'username': 'existing'}
|
||||
serializer = ExampleSerializer(self.instance, data=data)
|
||||
assert serializer.is_valid()
|
||||
assert serializer.validated_data == {'username': 'existing'}
|
Loading…
Reference in New Issue
Block a user