diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index d391755c7..e3c7cf037 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -20,8 +20,7 @@ class CreateModelMixin(object): def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA) if serializer.is_valid(): - self.object = serializer.object - self.object.save() + self.object = serializer.save() return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -75,8 +74,7 @@ class UpdateModelMixin(object): self.object = self.get_object() serializer = self.get_serializer(data=request.DATA, instance=self.object) if serializer.is_valid(): - self.object = serializer.object - self.object.save() + self.object = serializer.save() return Response(serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5935bce53..683b9efc8 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -275,6 +275,13 @@ class BaseSerializer(Field): self._data = self.to_native(self.object) return self._data + def save(self): + """ + Save the deserialized object and return it. + """ + self.object.save() + return self.object + class Serializer(BaseSerializer): __metaclass__ = SerializerMetaclass @@ -379,3 +386,10 @@ class ModelSerializer(RelatedField, Serializer): if field.name in attrs: m2m_data[field.name] = attrs.pop(field.name) return DeserializedObject(self.opts.model(**attrs), m2m_data) + + def save(self): + """ + Save the deserialized object and return it. + """ + self.object.save() + return self.object.object diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 7de79f957..b7a9ae99a 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -135,10 +135,10 @@ class ManyToManyTests(TestCase): data = {'rel': [self.anchor.id]} serializer = self.serializer_class(data) self.assertEquals(serializer.is_valid(), True) - serializer.object.save() - obj = serializer.object.object - self.assertEquals(obj.pk, 2) - self.assertEquals(list(obj.rel.all()), [self.anchor]) + instance = serializer.save() + self.assertEquals(len(ManyToManyModel.objects.all()), 2) + self.assertEquals(instance.pk, 2) + self.assertEquals(list(instance.rel.all()), [self.anchor]) # self.assertFalse(serializer.object is expected) # def test_deserialization_for_update(self):