foreign key tests

This commit is contained in:
Tom Christie 2012-11-02 20:53:33 +00:00
parent e84ce60a0d
commit 6eaec7a0ec
2 changed files with 132 additions and 20 deletions

View File

@ -383,7 +383,8 @@ class PrimaryKeyRelatedField(RelatedField):
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
except ObjectDoesNotExist: except ObjectDoesNotExist:
raise ValidationError('Invalid hyperlink - object does not exist.') msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
raise ValidationError(msg)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
try: try:
@ -430,6 +431,16 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
# Forward relationship # Forward relationship
return [self.to_native(item.pk) for item in queryset.all()] return [self.to_native(item.pk) for item in queryset.all()]
def from_native(self, data):
if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
return self.queryset.get(pk=data)
except ObjectDoesNotExist:
msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
raise ValidationError(msg)
### Slug relationships ### Slug relationships

View File

@ -3,26 +3,50 @@ from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
class Target(models.Model): # ManyToMany
class ManyToManyTarget(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
class Source(models.Model): class ManyToManySource(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
targets = models.ManyToManyField(Target, related_name='sources') targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
class TargetSerializer(serializers.ModelSerializer): class ManyToManyTargetSerializer(serializers.ModelSerializer):
sources = serializers.ManyPrimaryKeyRelatedField() sources = serializers.ManyPrimaryKeyRelatedField(queryset=ManyToManySource.objects.all())
class Meta: class Meta:
fields = ('id', 'name', 'sources') model = ManyToManyTarget
model = Target
class SourceSerializer(serializers.ModelSerializer): class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Source model = ManyToManySource
# ForeignKey
class ForeignKeyTarget(models.Model):
name = models.CharField(max_length=100)
class ForeignKeySource(models.Model):
name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.ManyPrimaryKeyRelatedField(queryset=ForeignKeySource.objects.all())
class Meta:
model = ForeignKeyTarget
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
@ -30,15 +54,16 @@ class SourceSerializer(serializers.ModelSerializer):
class PrimaryKeyManyToManyTests(TestCase): class PrimaryKeyManyToManyTests(TestCase):
def setUp(self): def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
target = Target(name='target-%d' % idx) target = ManyToManyTarget(name='target-%d' % idx)
target.save() target.save()
source = Source(name='source-%d' % idx) source = ManyToManySource(name='source-%d' % idx)
source.save() source.save()
for target in Target.objects.all(): for target in ManyToManyTarget.objects.all():
source.targets.add(target) source.targets.add(target)
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
serializer = SourceSerializer(instance=Source.objects.all()) queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(instance=queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'targets': [1]}, {'id': 1, 'name': u'source-1', 'targets': [1]},
{'id': 2, 'name': u'source-2', 'targets': [1, 2]}, {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
@ -47,7 +72,8 @@ class PrimaryKeyManyToManyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
serializer = TargetSerializer(instance=Target.objects.all()) queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(instance=queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
@ -57,12 +83,15 @@ class PrimaryKeyManyToManyTests(TestCase):
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}
serializer = SourceSerializer(data, instance=Source.objects.get(pk=1)) instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(data, instance=instance)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
serializer = SourceSerializer(instance=Source.objects.all()) queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(instance=queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}, {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': u'source-2', 'targets': [1, 2]}, {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
@ -71,16 +100,88 @@ class PrimaryKeyManyToManyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'id': 1, 'name': u'target-0', 'sources': [1]} data = {'id': 1, 'name': u'target-1', 'sources': [1]}
serializer = TargetSerializer(data, instance=Target.objects.get(pk=1)) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(data, instance=instance)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
serializer.save()
# Ensure target 1 is updated, and everything else is as expected # Ensure target 1 is updated, and everything else is as expected
serializer = TargetSerializer(instance=Target.objects.all()) queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(instance=queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1]}, {'id': 1, 'name': u'target-1', 'sources': [1]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
{'id': 3, 'name': u'target-3', 'sources': [3]} {'id': 3, 'name': u'target-3', 'sources': [3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
class PrimaryKeyForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(instance=queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1}
]
self.assertEquals(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(instance=queryset)
expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': []},
]
self.assertEquals(serializer.data, expected)
def test_foreign_key_update(self):
data = {'id': 1, 'name': u'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(data, instance=instance)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data)
serializer.save()
# # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(instance=queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': 2},
{'id': 2, 'name': u'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1}
]
self.assertEquals(serializer.data, expected)
# reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self):
# data = {'id': 1, 'name': u'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(data, instance=instance)
# self.assertTrue(serializer.is_valid())
# self.assertEquals(serializer.data, data)
# serializer.save()
# # Ensure target 1 is updated, and everything else is as expected
# queryset = ForeignKeyTarget.objects.all()
# serializer = ForeignKeyTargetSerializer(instance=queryset)
# expected = [
# {'id': 1, 'name': u'target-1', 'sources': [1]},
# {'id': 2, 'name': u'target-2', 'sources': []},
# ]
# self.assertEquals(serializer.data, expected)