diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md index 0170256f2..94292b746 100644 --- a/docs/api-guide/generic-views.md +++ b/docs/api-guide/generic-views.md @@ -222,6 +222,16 @@ If an object is updated this returns a `200 OK` response, with a serialized repr If the request data provided for updating the object was invalid, a `400 Bad Request` response will be returned, with the error details as the body of the response. +## PartialUpdateModelMixin + +Similar to `UpdateModelMixin` except that it only includes the `partial_update` +(i.e. `PATCH`) capability and not the `update` (i.e. `PUT`) capability. + +## FullUpdateModelMixin + +Similar to `UpdateModelMixin` except that it only includes the `update` +(i.e. `PUT`) capability and not the `partial_update` (i.e. `PATCH`) capability. + ## DestroyModelMixin Provides a `.destroy(request, *args, **kwargs)` method, that implements deletion of an existing model instance. diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index f3695e665..04364f322 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -58,11 +58,8 @@ class RetrieveModelMixin(object): return Response(serializer.data) -class UpdateModelMixin(object): - """ - Update a model instance. - """ - def update(self, request, *args, **kwargs): +class _UpdateModelMixin(object): + def do_update(self, request, *args, **kwargs): partial = kwargs.pop('partial', False) instance = self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) @@ -79,9 +76,30 @@ class UpdateModelMixin(object): def perform_update(self, serializer): serializer.save() + +class PartialUpdateModelMixin(_UpdateModelMixin): + """ + Partial update a model instance. + """ def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.do_update(request, *args, **kwargs) + + +class FullUpdateModelMixin(object): + """ + Full update a model instance. + """ + def update(self, request, *args, **kwargs): + kwargs['partial'] = False + return self.do_update(request, *args, **kwargs) + + +class UpdateModelMixin(FullUpdateModelMixin, PartialUpdateModelMixin): + """ + Update a model instance fully or partially. + """ + pass class DestroyModelMixin(object): diff --git a/tests/test_routers.py b/tests/test_routers.py index fee39b2b3..d3f92b9fc 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -9,7 +9,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db import models from django.test import TestCase, override_settings -from rest_framework import permissions, serializers, viewsets +from rest_framework import mixins, permissions, serializers, viewsets from rest_framework.compat import include from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response @@ -132,10 +132,6 @@ class BasicViewSet(viewsets.ViewSet): return Response({'method': 'link2'}) -class TestSimpleRouter(TestCase): - def setUp(self): - self.router = SimpleRouter() - def test_link_and_action_decorator(self): routes = self.router.get_routes(BasicViewSet) decorator_routes = routes[2:] @@ -155,6 +151,45 @@ class TestSimpleRouter(TestCase): assert route.mapping[method] == endpoint +class TestUpdateViewSets(TestCase): + """ + Verify that the *UpdateModelMixin mixin classes expose the correct methods + for a router to pick up. + """ + def test_update_viewsets(self): + class PartialUpdateViewSet( + mixins.PartialUpdateModelMixin, + viewsets.GenericViewSet, + ): + pass + + + class FullUpdateViewSet( + mixins.FullUpdateModelMixin, + viewsets.GenericViewSet, + ): + pass + + + class UpdateViewSet( + mixins.FullUpdateModelMixin, + viewsets.GenericViewSet, + ): + pass + + + for cls, actions in ( + (PartialUpdateViewSet, {'patch', 'partial_update'}), + (FullUpdateViewSet, {'put', 'update'}), + (UpdateViewSet, {'patch': 'partial_update', 'put', 'update'}), + ): + router = SimpleRouter() + router.register('test', cls, 'basename') + urls = router.get_urls() + assert len(urls) == 1 + assert urls[0].callback.actions == actions + + @override_settings(ROOT_URLCONF='tests.test_routers') class TestRootView(TestCase): def test_retrieve_namespaced_root(self):