diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6ac6366c7..590a607fa 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -58,11 +58,8 @@ class RetrieveModelMixin: return Response(serializer.data) -class UpdateModelMixin: - """ - Update a model instance. - """ - def update(self, request, *args, **kwargs): +class BaseUpdateModelMixin: + def get_update_response(self, request, *args, **kwargs): partial = kwargs.pop('partial', False) instance = self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) @@ -82,6 +79,29 @@ class UpdateModelMixin: def perform_update(self, serializer): serializer.save() + +class FullUpdateModelMixin(BaseUpdateModelMixin): + """ + Update a model instance. Only allowing 'PUT' method. + """ + def update(self, request, *args, **kwargs): + return self.get_update_response(request, *args, **kwargs) + + +class PartialUpdateModelMixin(BaseUpdateModelMixin): + """ + Update a model instance. Only allowing 'PATCH' method. + """ + + def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.get_update_response(request, *args, **kwargs) + + +class UpdateModelMixin(FullUpdateModelMixin, PartialUpdateModelMixin): + """ + Update a model instance. Allowing both 'PATCH' and 'PUT' methods. + """ def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True return self.update(request, *args, **kwargs) diff --git a/tests/test_generics.py b/tests/test_generics.py index 9990389c9..9df4679ca 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -6,8 +6,9 @@ from django.http import Http404 from django.shortcuts import get_object_or_404 from django.test import TestCase -from rest_framework import generics, renderers, serializers, status +from rest_framework import generics, mixins, renderers, serializers, status from rest_framework.exceptions import ErrorDetail +from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework.test import APIRequestFactory from tests.models import ( @@ -63,6 +64,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): serializer_class = BasicSerializer +class PatchOnlyInstanceView(mixins.PartialUpdateModelMixin, GenericAPIView): + queryset = BasicModel.objects.exclude(text='filtered out') + serializer_class = BasicSerializer + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + +class PutOnlyInstanceView(mixins.FullUpdateModelMixin, GenericAPIView): + queryset = BasicModel.objects.exclude(text='filtered out') + serializer_class = BasicSerializer + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): queryset = ForeignKeySource.objects.all() serializer_class = ForeignKeySerializer @@ -188,6 +205,8 @@ class TestInstanceView(TestCase): ] self.view = InstanceView.as_view() self.slug_based_view = SlugBasedInstanceView.as_view() + self.patch_only_view = PatchOnlyInstanceView.as_view() + self.put_only_view = PutOnlyInstanceView.as_view() def test_get_instance_view(self): """ @@ -214,28 +233,30 @@ class TestInstanceView(TestCase): """ PUT requests to RetrieveUpdateDestroyAPIView should update an object. """ - data = {'text': 'foobar'} - request = factory.put('/1', data, format='json') - with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): - response = self.view(request, pk='1').render() - assert response.status_code == status.HTTP_200_OK - assert dict(response.data) == {'id': 1, 'text': 'foobar'} - updated = self.objects.get(id=1) - assert updated.text == 'foobar' + for view in (self.view, self.put_only_view): + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): + response = view(request, pk='1').render() + assert response.status_code == status.HTTP_200_OK + assert dict(response.data) == {'id': 1, 'text': 'foobar'} + updated = self.objects.get(id=1) + assert updated.text == 'foobar' def test_patch_instance_view(self): """ PATCH requests to RetrieveUpdateDestroyAPIView should update an object. """ - data = {'text': 'foobar'} - request = factory.patch('/1', data, format='json') + for view in (self.view, self.patch_only_view): + data = {'text': 'foobar'} + request = factory.patch('/1', data, format='json') - with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): - response = self.view(request, pk=1).render() - assert response.status_code == status.HTTP_200_OK - assert response.data == {'id': 1, 'text': 'foobar'} - updated = self.objects.get(id=1) - assert updated.text == 'foobar' + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): + response = view(request, pk=1).render() + assert response.status_code == status.HTTP_200_OK + assert response.data == {'id': 1, 'text': 'foobar'} + updated = self.objects.get(id=1) + assert updated.text == 'foobar' def test_delete_instance_view(self): """