unique_for_date/unique_for_month/unique_for_year

This commit is contained in:
Tom Christie 2014-10-28 16:21:49 +00:00
parent 702f47700d
commit 9ebaabd6eb
4 changed files with 194 additions and 8 deletions

View File

@ -92,13 +92,35 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value dictionary[keys[-1]] = value
class CreateOnlyDefault:
"""
This class may be used to provide default values that are only used
for create operations, but that do not return any value for update
operations.
"""
def __init__(self, default):
self.default = default
def set_context(self, serializer_field):
self.is_update = serializer_field.parent.instance is not None
def __call__(self):
if self.is_update:
raise SkipField()
if callable(self.default):
return self.default()
return self.default
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, repr(self.default))
class SkipField(Exception): class SkipField(Exception):
pass pass
NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
MISSING_ERROR_MESSAGE = ( MISSING_ERROR_MESSAGE = (
@ -132,7 +154,6 @@ class Field(object):
# Some combinations of keyword arguments do not make sense. # Some combinations of keyword arguments do not make sense.
assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
assert not (read_only and required), NOT_READ_ONLY_REQUIRED assert not (read_only and required), NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
assert not (read_only and self.__class__ == Field), USE_READONLYFIELD assert not (read_only and self.__class__ == Field), USE_READONLYFIELD
@ -230,7 +251,9 @@ class Field(object):
""" """
if self.default is empty: if self.default is empty:
raise SkipField() raise SkipField()
if is_simple_callable(self.default): if callable(self.default):
if hasattr(self.default, 'set_context'):
self.default.set_context(self)
return self.default() return self.default()
return self.default return self.default
@ -244,6 +267,9 @@ class Field(object):
May raise `SkipField` if the field should not be included in the May raise `SkipField` if the field should not be included in the
validated data. validated data.
""" """
if self.read_only:
return self.get_default()
if data is empty: if data is empty:
if getattr(self.root, 'partial', False): if getattr(self.root, 'partial', False):
raise SkipField() raise SkipField()
@ -1033,6 +1059,28 @@ class ReadOnlyField(Field):
return value return value
class HiddenField(Field):
"""
A hidden field does not take input from the user, or present any output,
but it does populate a field in `validated_data`, based on its default
value. This is particularly useful when we have a `unique_for_date`
constrain on a pair of fields, as we need some way to include the date in
the validated data.
"""
def __init__(self, **kwargs):
assert 'default' in kwargs, 'default is a required argument.'
kwargs['write_only'] = True
super(HiddenField, self).__init__(**kwargs)
def get_value(self, dictionary):
# We always use the default value for `HiddenField`.
# User input is never provided or accepted.
return empty
def to_internal_value(self, data):
return data
class SerializerMethodField(Field): class SerializerMethodField(Field):
""" """
A read-only field that get its representation from calling a method on the A read-only field that get its representation from calling a method on the

View File

@ -12,6 +12,7 @@ response content is handled by parsers and renderers.
""" """
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import models from django.db import models
from django.db.models.fields import FieldDoesNotExist
from django.utils import six from django.utils import six
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
@ -368,7 +369,10 @@ class Serializer(BaseSerializer):
""" """
ret = {} ret = {}
errors = ReturnDict(serializer=self) errors = ReturnDict(serializer=self)
fields = [field for field in self.fields.values() if not field.read_only] fields = [
field for field in self.fields.values()
if (not field.read_only) or (field.default is not empty)
]
for field in fields: for field in fields:
validate_method = getattr(self, 'validate_' + field.field_name, None) validate_method = getattr(self, 'validate_' + field.field_name, None)
@ -517,7 +521,7 @@ class ModelSerializer(Serializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ModelSerializer, self).__init__(*args, **kwargs) super(ModelSerializer, self).__init__(*args, **kwargs)
if 'validators' not in kwargs: if 'validators' not in kwargs:
validators = self.get_unique_together_validators() validators = self.get_default_validators()
if validators: if validators:
self.validators.extend(validators) self.validators.extend(validators)
self._kwargs['validators'] = validators self._kwargs['validators'] = validators
@ -572,7 +576,7 @@ class ModelSerializer(Serializer):
instance.save() instance.save()
return instance return instance
def get_unique_together_validators(self): def get_default_validators(self):
field_names = set([ field_names = set([
field.source for field in self.fields.values() field.source for field in self.fields.values()
if (field.source != '*') and ('.' not in field.source) if (field.source != '*') and ('.' not in field.source)
@ -592,6 +596,7 @@ class ModelSerializer(Serializer):
) )
validators.append(validator) validators.append(validator)
# Add any unique_for_date/unique_for_month/unique_for_year constraints.
info = model_meta.get_field_info(model_class) info = model_meta.get_field_info(model_class)
for field_name, field in info.fields_and_pk.items(): for field_name, field in info.fields_and_pk.items():
if field.unique_for_date and field_name in field_names: if field.unique_for_date and field_name in field_names:
@ -637,7 +642,7 @@ class ModelSerializer(Serializer):
# Retrieve metadata about fields & relationships on the model class. # Retrieve metadata about fields & relationships on the model class.
info = model_meta.get_field_info(model) info = model_meta.get_field_info(model)
# Use the default set of fields if none is supplied explicitly. # Use the default set of field names if none is supplied explicitly.
if fields is None: if fields is None:
fields = self._get_default_field_names(declared_fields, info) fields = self._get_default_field_names(declared_fields, info)
exclude = getattr(self.Meta, 'exclude', None) exclude = getattr(self.Meta, 'exclude', None)
@ -645,6 +650,72 @@ class ModelSerializer(Serializer):
for field_name in exclude: for field_name in exclude:
fields.remove(field_name) fields.remove(field_name)
# Determine the set of model fields, and the fields that they map to.
# We actually only need this to deal with the slightly awkward case
# of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`.
model_field_mapping = {}
for field_name in fields:
if field_name in declared_fields:
field = declared_fields[field_name]
source = field.source or field_name
else:
try:
source = extra_kwargs[field_name]['source']
except KeyError:
source = field_name
# Model fields will always have a simple source mapping,
# they can't be nested attribute lookups.
if '.' not in source and source != '*':
model_field_mapping[source] = field_name
# Determine if we need any additional `HiddenField` or extra keyword
# arguments to deal with `unique_for` dates that are required to
# be in the input data in order to validate it.
unique_fields = {}
for model_field_name, field_name in model_field_mapping.items():
try:
model_field = model._meta.get_field(model_field_name)
except FieldDoesNotExist:
continue
# Deal with each of the `unique_for_*` cases.
for date_field_name in (
model_field.unique_for_date,
model_field.unique_for_month,
model_field.unique_for_year
):
if date_field_name is None:
continue
# Get the model field that is refered too.
date_field = model._meta.get_field(date_field_name)
if date_field.auto_now_add:
default = CreateOnlyDefault(timezone.now)
elif date_field.auto_now:
default = timezone.now
elif date_field.has_default():
default = model_field.default
else:
default = empty
if date_field_name in model_field_mapping:
# The corresponding date field is present in the serializer
if date_field_name not in extra_kwargs:
extra_kwargs[date_field_name] = {}
if default is empty:
if 'required' not in extra_kwargs[date_field_name]:
extra_kwargs[date_field_name]['required'] = True
else:
if 'default' not in extra_kwargs[date_field_name]:
extra_kwargs[date_field_name]['default'] = default
else:
# The corresponding date field is not present in the,
# serializer. We have a default to use for the date, so
# add in a hidden field that populates it.
unique_fields[date_field_name] = HiddenField(default=default)
# Now determine the fields that should be included on the serializer.
for field_name in fields: for field_name in fields:
if field_name in declared_fields: if field_name in declared_fields:
# Field is explicitly declared on the class, use that. # Field is explicitly declared on the class, use that.
@ -723,6 +794,9 @@ class ModelSerializer(Serializer):
# Create the serializer field. # Create the serializer field.
ret[field_name] = field_cls(**kwargs) ret[field_name] = field_cls(**kwargs)
for field_name, field in unique_fields.items():
ret[field_name] = field
return ret return ret
def _include_additional_options(self, extra_kwargs): def _include_additional_options(self, extra_kwargs):

View File

@ -215,6 +215,33 @@ class TestBooleanHTMLInput:
assert serializer.validated_data == {'archived': False} assert serializer.validated_data == {'archived': False}
class TestCreateOnlyDefault:
def setup(self):
default = serializers.CreateOnlyDefault('2001-01-01')
class TestSerializer(serializers.Serializer):
published = serializers.HiddenField(default=default)
text = serializers.CharField()
self.Serializer = TestSerializer
def test_create_only_default_is_provided(self):
serializer = self.Serializer(data={'text': 'example'})
assert serializer.is_valid()
assert serializer.validated_data == {
'text': 'example', 'published': '2001-01-01'
}
def test_create_only_default_is_not_provided_on_update(self):
instance = {
'text': 'example', 'published': '2001-01-01'
}
serializer = self.Serializer(instance, data={'text': 'example'})
assert serializer.is_valid()
assert serializer.validated_data == {
'text': 'example',
}
# Tests for field input and output values. # Tests for field input and output values.
# ---------------------------------------- # ----------------------------------------

View File

@ -176,7 +176,7 @@ class TestUniquenessForDateValidation(TestCase):
UniqueForDateSerializer(validators=[<UniqueForDateValidator(queryset=UniqueForDateModel.objects.all(), field='slug', date_field='published')>]): UniqueForDateSerializer(validators=[<UniqueForDateValidator(queryset=UniqueForDateModel.objects.all(), field='slug', date_field='published')>]):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
slug = CharField(max_length=100) slug = CharField(max_length=100)
published = DateField() published = DateField(required=True)
""") """)
assert repr(serializer) == expected assert repr(serializer) == expected
@ -215,3 +215,40 @@ class TestUniquenessForDateValidation(TestCase):
'slug': 'existing', 'slug': 'existing',
'published': datetime.date(2000, 1, 1) 'published': datetime.date(2000, 1, 1)
} }
class HiddenFieldUniqueForDateModel(models.Model):
slug = models.CharField(max_length=100, unique_for_date='published')
published = models.DateTimeField(auto_now_add=True)
class TestHiddenFieldUniquenessForDateValidation(TestCase):
def test_repr_date_field_not_included(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = HiddenFieldUniqueForDateModel
fields = ('id', 'slug')
serializer = TestSerializer()
expected = dedent("""
TestSerializer(validators=[<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>]):
id = IntegerField(label='ID', read_only=True)
slug = CharField(max_length=100)
published = HiddenField(default=CreateOnlyDefault(<function now>))
""")
assert repr(serializer) == expected
def test_repr_date_field_included(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = HiddenFieldUniqueForDateModel
fields = ('id', 'slug', 'published')
serializer = TestSerializer()
expected = dedent("""
TestSerializer(validators=[<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>]):
id = IntegerField(label='ID', read_only=True)
slug = CharField(max_length=100)
published = DateTimeField(default=CreateOnlyDefault(<function now>), read_only=True)
""")
assert repr(serializer) == expected