diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 6eb9c3e11..e8e6735a5 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -496,19 +496,26 @@ class ModelSerializer(Serializer): Restore the model instance. """ self.m2m_data = {} + self.related_data = {} if instance is not None: for key, val in attrs.items(): setattr(instance, key, val) else: - # Reverse relations + # Reverse fk relations + for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.related_data[field_name] = attrs.pop(field_name) + + # Reverse m2m relations for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: self.m2m_data[field_name] = attrs.pop(field_name) - # Forward relations + # Forward m2m relations for field in self.opts.model._meta.many_to_many: if field.name in attrs: self.m2m_data[field.name] = attrs.pop(field.name) @@ -534,6 +541,11 @@ class ModelSerializer(Serializer): setattr(self.object, accessor_name, object_list) self.m2m_data = {} + if getattr(self, 'related_data', None): + for accessor_name, object_list in self.related_data.items(): + setattr(self.object, accessor_name, object_list) + self.related_data = {} + return self.object diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index e5391f1b4..c2e61279a 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -38,7 +38,7 @@ class ForeignKeySource(models.Model): class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.ManyPrimaryKeyRelatedField(read_only=True) + sources = serializers.ManyPrimaryKeyRelatedField() class Meta: model = ForeignKeyTarget @@ -235,24 +235,23 @@ class PKForeignKeyTests(TestCase): ] self.assertEquals(serializer.data, expected) - # TODO: See #511 - # def test_reverse_foreign_key_create(self): - # data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]} - # serializer = ForeignKeyTargetSerializer(data=data) - # self.assertTrue(serializer.is_valid()) - # obj = serializer.save() - # self.assertEquals(serializer.data, data) - # self.assertEqual(obj.name, u'target-3') + def test_reverse_foreign_key_create(self): + data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-3') - # # Ensure target 4 is added, and everything else is as expected - # queryset = ForeignKeyTarget.objects.all() - # serializer = ForeignKeyTargetSerializer(queryset) - # expected = [ - # {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - # {'id': 2, 'name': u'target-2', 'sources': []}, - # {'id': 3, 'name': u'target-3', 'sources': [1, 3]}, - # ] - # self.assertEquals(serializer.data, expected) + # Ensure target 4 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [2]}, + {'id': 2, 'name': u'target-2', 'sources': []}, + {'id': 3, 'name': u'target-3', 'sources': [1, 3]}, + ] + self.assertEquals(serializer.data, expected) def test_foreign_key_update_with_invalid_null(self): data = {'id': 1, 'name': u'source-1', 'target': None}