Use either PrimaryKeyRelatedField or ManyPrimaryKeyRelatedField as appropriate (fixes test)

This commit is contained in:
Tom Christie 2012-10-03 12:16:30 +01:00
parent cab3b2f3f8
commit 58c1263267
3 changed files with 9 additions and 6 deletions

View File

@ -221,13 +221,13 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
try:
obj = obj.serializable_value(self.source or field_name)
pk = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedObject (reverse relationship)
obj = getattr(obj, self.source or field_name)
return self.to_native(obj.pk)
# Forward relationship
return self.to_native(obj)
return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
@ -237,13 +237,13 @@ class PrimaryKeyRelatedField(RelatedField):
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
def field_to_native(self, obj, field_name):
try:
obj = obj.serializable_value(self.source or field_name)
queryset = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedManager (reverse relationship)
obj = getattr(obj, self.source or field_name)
return [self.to_native(item.pk) for item in obj.all()]
queryset = getattr(obj, self.source or field_name)
return [self.to_native(item.pk) for item in queryset.all()]
# Forward relationship
return [self.to_native(item.pk) for item in obj.all()]
return [self.to_native(item.pk) for item in queryset.all()]
def field_from_native(self, data, field_name, into):
try:

View File

@ -351,6 +351,8 @@ class ModelSerializer(RelatedField, Serializer):
"""
Creates a default instance of a flat relational field.
"""
if isinstance(model_field, models.fields.related.ManyToManyField):
return ManyPrimaryKeyRelatedField()
return PrimaryKeyRelatedField()
def get_field(self, model_field):

View File

@ -201,6 +201,7 @@ class ManyToManyTests(TestCase):
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
# def test_deserialization_for_update(self):
# serializer = self.serializer_class(self.data, instance=self.instance)
# expected = self.instance