From acc8c1faa4f85dda00723d755e56bb3c980dbc75 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 13 Mar 2013 20:40:39 +0000 Subject: [PATCH] force_insert, force_update arguments. Closes #484. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Confirmed by `assertNumQueries(…)` in tests. --- docs/topics/release-notes.md | 4 ++++ rest_framework/mixins.py | 6 ++++-- rest_framework/serializers.py | 14 +++++++------- rest_framework/tests/generics.py | 10 +++++----- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 5a96c09cc..c45fff880 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,10 @@ You can determine your currently installed version using `pip freeze`: ## 2.2.x series +### Master + +* `Serializer.save()` now supports arbitrary keyword args which are passed through to the object `.save()` method. Mixins use `force_insert` and `force_update` where appropriate, resulting in one less database query. + ### 2.2.4 **Date**: 13th March 2013 diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 8e4012049..7d9a6e654 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -44,7 +44,7 @@ class CreateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + self.object = serializer.save(force_insert=True) self.post_save(self.object, created=True) headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, @@ -119,9 +119,11 @@ class UpdateModelMixin(object): # we have relevant permissions, as if this was a POST request. self.check_permissions(clone_request(request, 'POST')) created = True + save_kwargs = {'force_insert': True} success_status_code = status.HTTP_201_CREATED else: created = False + save_kwargs = {'force_update': True} success_status_code = status.HTTP_200_OK serializer = self.get_serializer(self.object, data=request.DATA, @@ -129,7 +131,7 @@ class UpdateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + self.object = serializer.save(**save_kwargs) self.post_save(self.object, created=created) return Response(serializer.data, status=success_status_code) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index cd2bb8f1f..4fe857a61 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -391,17 +391,17 @@ class BaseSerializer(Field): return self._data - def save_object(self, obj): - obj.save() + def save_object(self, obj, **kwargs): + obj.save(**kwargs) - def save(self): + def save(self, **kwargs): """ Save the deserialized object and return it. """ if isinstance(self.object, list): - [self.save_object(item) for item in self.object] + [self.save_object(item, **kwargs) for item in self.object] else: - self.save_object(self.object) + self.save_object(self.object, **kwargs) return self.object @@ -621,11 +621,11 @@ class ModelSerializer(Serializer): if instance: return self.full_clean(instance) - def save_object(self, obj): + def save_object(self, obj, **kwargs): """ Save the deserialized object and return it. """ - obj.save() + obj.save(**kwargs) if getattr(self, 'm2m_data', None): for accessor_name, object_list in self.m2m_data.items(): diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 1837898b5..f564890cc 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -184,7 +184,7 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.view(request, pk='1').render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) @@ -199,7 +199,7 @@ class TestInstanceView(TestCase): request = factory.patch('/1', json.dumps(content), content_type='application/json') - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) @@ -248,7 +248,7 @@ class TestInstanceView(TestCase): content = {'id': 999, 'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) @@ -264,7 +264,7 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) @@ -280,7 +280,7 @@ class TestInstanceView(TestCase): # pk fields can not be created on demand, only the database can set the pk for a new object request = factory.put('/5', json.dumps(content), content_type='application/json') - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.view(request, pk=5).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) new_obj = self.objects.get(pk=5)