From e82890202336d76f5b2f31d556d658abb11bc443 Mon Sep 17 00:00:00 2001 From: Ludwig Kraatz Date: Sat, 8 Dec 2012 12:56:34 +0100 Subject: [PATCH] improved mixins saving objects when creating, with force_insert when updating, with force_update --- rest_framework/mixins.py | 7 +++++-- rest_framework/serializers.py | 10 +++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 1edcfa5c9..3c969eb39 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -18,7 +18,7 @@ class CreateModelMixin(object): serializer = self.get_serializer(data=request.DATA, files=request.FILES) if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + self.object = serializer.save(force_insert=True) headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -93,7 +93,10 @@ class UpdateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + if created: + self.object = serializer.save(force_insert=True) + else: + self.object = serializer.save(force_update=True) status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK return Response(serializer.data, status=status_code) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 13c41a4bd..6d9389c11 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -339,11 +339,11 @@ class BaseSerializer(Field): self._data = self.to_native(self.object) return self._data - def save(self): + def save(self,**kwargs): """ Save the deserialized object and return it. """ - self.object.save() + self.object.save(**kwargs) return self.object @@ -519,18 +519,18 @@ class ModelSerializer(Serializer): self.m2m_data[field.name] = attrs.pop(field.name) return self.opts.model(**attrs) - def save(self, save_m2m=True): + def save(self, save_m2m=True, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + obj = super(ModelSerializer, self).save(**kwargs) if getattr(self, 'm2m_data', None) and save_m2m: for accessor_name, object_list in self.m2m_data.items(): setattr(self.object, accessor_name, object_list) self.m2m_data = {} - return self.object + return obj class HyperlinkedModelSerializerOptions(ModelSerializerOptions):