Support nullable FKs, with blank=True

This commit is contained in:
Tom Christie 2012-12-07 21:32:39 +00:00
parent a5178e9a36
commit 303bc7cf95
4 changed files with 85 additions and 15 deletions

View File

@ -350,7 +350,13 @@ class RelatedField(WritableField):
return return
value = data.get(field_name) value = data.get(field_name)
into[(self.source or field_name)] = self.from_native(value)
if value is None and not self.blank:
raise ValidationError('Value may not be null')
elif value is None and self.blank:
into[(self.source or field_name)] = None
else:
into[(self.source or field_name)] = self.from_native(value)
class ManyRelatedMixin(object): class ManyRelatedMixin(object):

View File

@ -431,10 +431,14 @@ class ModelSerializer(Serializer):
""" """
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
queryset = model_field.rel.to._default_manager kwargs = {
'blank': model_field.blank,
'queryset': model_field.rel.to._default_manager
}
if to_many: if to_many:
return ManyPrimaryKeyRelatedField(queryset=queryset) return ManyPrimaryKeyRelatedField(**kwargs)
return PrimaryKeyRelatedField(queryset=queryset) return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field): def get_field(self, model_field):
""" """
@ -572,9 +576,9 @@ class HyperlinkedModelSerializer(ModelSerializer):
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to rel = model_field.rel.to
queryset = rel._default_manager
kwargs = { kwargs = {
'queryset': queryset, 'blank': model_field.blank,
'queryset': rel._default_manager,
'view_name': self._get_default_view_name(rel) 'view_name': self._get_default_view_name(rel)
} }
if to_many: if to_many:

View File

@ -1,6 +1,7 @@
from django.conf.urls.defaults import patterns, url from django.conf.urls.defaults import patterns, url
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import simplejson as json
from rest_framework import generics, status, serializers from rest_framework import generics, status, serializers
from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
@ -54,10 +55,12 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView):
model = BlogPostComment model = BlogPostComment
serializer_class = BlogPostCommentSerializer serializer_class = BlogPostCommentSerializer
class BlogPostCommentDetail(generics.RetrieveAPIView): class BlogPostCommentDetail(generics.RetrieveAPIView):
model = BlogPostComment model = BlogPostComment
serializer_class = BlogPostCommentSerializer serializer_class = BlogPostCommentSerializer
class BlogPostDetail(generics.RetrieveAPIView): class BlogPostDetail(generics.RetrieveAPIView):
model = BlogPost model = BlogPost
@ -71,7 +74,7 @@ class AlbumDetail(generics.RetrieveAPIView):
model = Album model = Album
class OptionalRelationDetail(generics.RetrieveAPIView): class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
model = OptionalRelationModel model = OptionalRelationModel
model_serializer_class = serializers.HyperlinkedModelSerializer model_serializer_class = serializers.HyperlinkedModelSerializer
@ -162,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/manytomany/') request = factory.get('/manytomany/')
response = self.list_view(request).render() response = self.list_view(request)
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
@ -171,7 +174,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/manytomany/1/') request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1).render() response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0]) self.assertEquals(response.data, self.data[0])
@ -194,7 +197,7 @@ class TestCreateWithForeignKeys(TestCase):
} }
request = factory.post('/comments/', data=data) request = factory.post('/comments/', data=data)
response = self.create_view(request).render() response = self.create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response['Location'], 'http://testserver/comments/1/') self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1) self.assertEqual(self.post.blogpostcomment_set.count(), 1)
@ -219,7 +222,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
} }
request = factory.post('/photos/', data=data) request = factory.post('/photos/', data=data)
response = self.list_create_view(request).render() response = self.list_create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.count(), 1)
@ -244,6 +247,16 @@ class TestOptionalRelationHyperlinkedView(TestCase):
for non existing relations. for non existing relations.
""" """
request = factory.get('/optionalrelationmodel-detail/1') request = factory.get('/optionalrelationmodel-detail/1')
response = self.detail_view(request, pk=1).render() response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
def test_put_detail_view(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView with optional relations
should accept None for non existing relations.
"""
response = self.client.put('/optionalrelation/1/',
data=json.dumps(self.data),
content_type='application/json')
self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@ -49,9 +49,22 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer):
model = ForeignKeySource model = ForeignKeySource
# Nullable ForeignKey
class NullableForeignKeySource(models.Model):
name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
related_name='nullable_sources')
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
class PrimaryKeyManyToManyTests(TestCase): class PKManyToManyTests(TestCase):
def setUp(self): def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) target = ManyToManyTarget(name='target-%d' % idx)
@ -137,7 +150,7 @@ class PrimaryKeyManyToManyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
class PrimaryKeyForeignKeyTests(TestCase): class PKForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
@ -174,7 +187,7 @@ class PrimaryKeyForeignKeyTests(TestCase):
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
serializer.save() serializer.save()
# # Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
@ -184,6 +197,40 @@ class PrimaryKeyForeignKeyTests(TestCase):
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
class PKNullableForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1}
]
self.assertEquals(serializer.data, expected)
# reverse foreign keys MUST be read_only # reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear() # In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set. # and cannot be arbitrarily set.