Tests and fix for #666.

Closes #666.
This commit is contained in:
Tom Christie 2013-02-22 22:02:42 +00:00
parent d44eb20942
commit bc87bf13b4
3 changed files with 64 additions and 6 deletions

View File

@ -12,6 +12,28 @@ from rest_framework.response import Response
from rest_framework.request import clone_request from rest_framework.request import clone_request
def _get_validation_exclusions(obj, pk=None, slug_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
For use when performing full_clean on a model instance,
so we only clean the required fields.
"""
include = []
if pk:
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name)
if slug_field:
include.append(slug_field)
return [field.name for field in obj._meta.fields if field.name not in include]
class CreateModelMixin(object): class CreateModelMixin(object):
""" """
Create a model instance. Create a model instance.
@ -117,18 +139,20 @@ class UpdateModelMixin(object):
""" """
# pk and/or slug attributes are implicit in the URL. # pk and/or slug attributes are implicit in the URL.
pk = self.kwargs.get(self.pk_url_kwarg, None) pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
slug_field = slug and self.get_slug_field() or None
if pk: if pk:
setattr(obj, 'pk', pk) setattr(obj, 'pk', pk)
slug = self.kwargs.get(self.slug_url_kwarg, None)
if slug: if slug:
slug_field = self.get_slug_field()
setattr(obj, slug_field, slug) setattr(obj, slug_field, slug)
# Ensure we clean the attributes so that we don't eg return integer # Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg. # pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'): if hasattr(obj, 'full_clean'):
obj.full_clean() exclude = _get_validation_exclusions(obj, pk, slug_field)
obj.full_clean(exclude)
class DestroyModelMixin(object): class DestroyModelMixin(object):

View File

@ -158,6 +158,7 @@ class BaseSerializer(Field):
# If 'fields' is specified, use those fields, in that order. # If 'fields' is specified, use those fields, in that order.
if self.opts.fields: if self.opts.fields:
assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple'
new = SortedDict() new = SortedDict()
for key in self.opts.fields: for key in self.opts.fields:
new[key] = ret[key] new[key] = ret[key]
@ -165,6 +166,7 @@ class BaseSerializer(Field):
# Remove anything in 'exclude' # Remove anything in 'exclude'
if self.opts.exclude: if self.opts.exclude:
assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude: for key in self.opts.exclude:
ret.pop(key, None) ret.pop(key, None)
@ -421,8 +423,8 @@ class ModelSerializer(Serializer):
cls = self.opts.model cls = self.opts.model
opts = get_concrete_model(cls)._meta opts = get_concrete_model(cls)._meta
pk_field = opts.pk pk_field = opts.pk
while pk_field.rel: # while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk # pk_field = pk_field.rel.to._meta.pk
fields = [pk_field] fields = [pk_field]
fields += [field for field in opts.fields if field.serialize] fields += [field for field in opts.fields if field.serialize]
fields += [field for field in opts.many_to_many if field.serialize] fields += [field for field in opts.many_to_many if field.serialize]

View File

@ -319,7 +319,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
self.assertEquals(created.content, 'foobar') self.assertEquals(created.content, 'foobar')
# Test for particularly ugly reression with m2m in browseable API # Test for particularly ugly regression with m2m in browseable API
class ClassB(models.Model): class ClassB(models.Model):
name = models.CharField(max_length=255) name = models.CharField(max_length=255)
@ -350,3 +350,35 @@ class TestM2MBrowseableAPI(TestCase):
view = ExampleView().as_view() view = ExampleView().as_view()
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
# Regression for #666
class ValidationModel(models.Model):
blank_validated_field = models.CharField(max_length=255)
class ValidationModelSerializer(serializers.ModelSerializer):
class Meta:
model = ValidationModel
fields = ('blank_validated_field',)
read_only_fields = ('blank_validated_field',)
class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
model = ValidationModel
serializer_class = ValidationModelSerializer
class TestPreSaveValidationExclusions(TestCase):
def test_pre_save_validation_exclusions(self):
"""
Somewhat weird test case to ensure that we don't perform model
validation on read only fields.
"""
obj = ValidationModel.objects.create(blank_validated_field='')
request = factory.put('/', json.dumps({}),
content_type='application/json')
view = UpdateValidationModel().as_view()
response = view(request, pk=obj.pk).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)