diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e8e6735a5..f8da58792 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -93,7 +93,7 @@ class SerializerOptions(object): self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField): class Meta(object): pass @@ -118,6 +118,7 @@ class BaseSerializer(Field): self._data = None self._files = None self._errors = None + self._siblings = [] ##### # Methods to determine which fields to use when (de)serializing objects. @@ -276,7 +277,11 @@ class BaseSerializer(Field): """ if hasattr(data, '__iter__') and not isinstance(data, dict): # TODO: error data when deserializing lists - return (self.from_native(item) for item in data) + for item in data: + sibling = copy.deepcopy(self) + self._siblings.append(sibling) + sibling.object = sibling.from_native(item, None) + return [sibling.object for sibling in self._siblings] self._errors = {} if data is not None or files is not None: @@ -361,6 +366,24 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions + def field_from_native(self, data, files, field_name, into): + if self.read_only: + return + + # TODO handle partial option + # TODO handle errors + + # deserialize the nested object + try: + native = data[field_name] + except KeyError: + if self.required: + raise ValidationError(self.error_messages['required']) + return + + self.object = self.from_native(native, files) + into[self.source or field_name] = self + def get_default_fields(self): """ Return all the fields that should be serialized for the model. @@ -498,28 +521,27 @@ class ModelSerializer(Serializer): self.m2m_data = {} self.related_data = {} + # Reverse fk relations + for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.related_data[field_name] = attrs.pop(field_name) + + # Reverse m2m relations + for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.m2m_data[field_name] = attrs.pop(field_name) + + # Forward m2m relations + for field in self.opts.model._meta.many_to_many: + if field.name in attrs: + self.m2m_data[field.name] = attrs.pop(field.name) + if instance is not None: for key, val in attrs.items(): setattr(instance, key, val) - else: - # Reverse fk relations - for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): - field_name = obj.field.related_query_name() - if field_name in attrs: - self.related_data[field_name] = attrs.pop(field_name) - - # Reverse m2m relations - for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): - field_name = obj.field.related_query_name() - if field_name in attrs: - self.m2m_data[field_name] = attrs.pop(field_name) - - # Forward m2m relations - for field in self.opts.model._meta.many_to_many: - if field.name in attrs: - self.m2m_data[field.name] = attrs.pop(field.name) - instance = self.opts.model(**attrs) try: @@ -530,10 +552,17 @@ class ModelSerializer(Serializer): return instance - def save(self, save_m2m=True): + def save(self, save_m2m=True, parent=None, fk_field=None): """ Save the deserialized object and return it. """ + if hasattr(self.object, '__iter__'): + for obj in self._siblings: + obj.save(parent=parent, fk_field=fk_field) + return self.object + + if parent and fk_field: + setattr(self.object, fk_field, parent) self.object.save() if getattr(self, 'm2m_data', None) and save_m2m: @@ -543,7 +572,11 @@ class ModelSerializer(Serializer): if getattr(self, 'related_data', None): for accessor_name, object_list in self.related_data.items(): - setattr(self.object, accessor_name, object_list) + if isinstance(object_list, ModelSerializer): + fk_field = self.object._meta.get_field_by_name(accessor_name)[0].field.name + object_list.save(parent=self.object, fk_field=fk_field) + else: + setattr(self.object, accessor_name, object_list) self.related_data = {} return self.object diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index 240394105..4e42324ff 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -113,9 +113,9 @@ class HyperlinkedManyToManyTests(TestCase): data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} instance = ManyToManySource.objects.get(pk=1) serializer = ManyToManySourceSerializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertTrue(serializer.is_valid()) serializer.save() + self.assertEquals(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ManyToManySource.objects.all() @@ -132,8 +132,8 @@ class HyperlinkedManyToManyTests(TestCase): instance = ManyToManyTarget.objects.get(pk=1) serializer = ManyToManyTargetSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEquals(serializer.data, data) # Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all() @@ -238,9 +238,9 @@ class HyperlinkedForeignKeyTests(TestCase): data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} instance = ForeignKeyTarget.objects.get(pk=2) serializer = ForeignKeyTargetSerializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertTrue(serializer.is_valid()) serializer.save() + self.assertEquals(serializer.data, data) # Ensure target 2 is update, and everything else is as expected queryset = ForeignKeyTarget.objects.all() diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index b11473780..e6b50447d 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -80,6 +80,21 @@ class ReverseForeignKeyTests(TestCase): ] self.assertEquals(serializer.data, expected) + def test_reverse_foreign_key_create(self): + target = ForeignKeyTarget.objects.get(name='target-2') + data = {'sources': [{'name': u'source-4', 'target': 2}], 'name': u'target-2a'} + expected = {'id': 2, 'name': u'target-2a', 'sources': [{'id': 4, 'name': u'source-4', 'target': 2}]} + serializer = ForeignKeyTargetSerializer(target, data=data, partial=True) + # serializer.is_valid() + # print serializer.errors + self.assertTrue(serializer.is_valid()) + serializer.save() + # Ensure target 2 has new source and everything else is as expected + target = ForeignKeyTarget.objects.get(name='target-2a') + serializer = ForeignKeyTargetSerializer(target) + self.assertEquals(serializer.data, expected) + + class NestedNullableForeignKeyTests(TestCase): def setUp(self): diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index 01109ef95..b832fc585 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -99,8 +99,8 @@ class PKManyToManyTests(TestCase): instance = ManyToManySource.objects.get(pk=1) serializer = ManyToManySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEquals(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ManyToManySource.objects.all() @@ -117,8 +117,8 @@ class PKManyToManyTests(TestCase): instance = ManyToManyTarget.objects.get(pk=1) serializer = ManyToManyTargetSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEquals(serializer.data, data) # Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all() @@ -221,8 +221,8 @@ class PKForeignKeyTests(TestCase): instance = ForeignKeyTarget.objects.get(pk=2) serializer = ForeignKeyTargetSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEquals(serializer.data, data) # Ensure target 2 is update, and everything else is as expected queryset = ForeignKeyTarget.objects.all()