From 908f91d8ef13649b6d658981e28ff52296b19f9f Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Mon, 9 Mar 2020 02:43:02 -0700 Subject: [PATCH] Set action for HEAD requests (#7223) * Test viewset action attr * Add 'head' to viewset actions map --- rest_framework/viewsets.py | 7 ++++--- tests/test_viewsets.py | 24 ++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 244c14d39..cad032dd9 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -93,6 +93,10 @@ class ViewSetMixin: def view(request, *args, **kwargs): self = cls(**initkwargs) + + if 'get' in actions and 'head' not in actions: + actions['head'] = actions['get'] + # We also store the mapping of request methods to actions, # so that we can later set the action attribute. # eg. `self.action = 'list'` on an incoming GET request. @@ -104,9 +108,6 @@ class ViewSetMixin: handler = getattr(self, action) setattr(self, method, handler) - if hasattr(self, 'get') and not hasattr(self, 'head'): - self.head = self.get - self.request = request self.args = args self.kwargs = kwargs diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py index f9468f448..1a621c518 100644 --- a/tests/test_viewsets.py +++ b/tests/test_viewsets.py @@ -37,14 +37,18 @@ class ActionViewSet(GenericViewSet): queryset = Action.objects.all() def list(self, request, *args, **kwargs): - return Response() + response = Response() + response.view = self + return response def retrieve(self, request, *args, **kwargs): return Response() @action(detail=False) def list_action(self, request, *args, **kwargs): - raise NotImplementedError + response = Response() + response.view = self + return response @action(detail=False, url_name='list-custom') def custom_list_action(self, request, *args, **kwargs): @@ -155,6 +159,22 @@ class InitializeViewSetsTestCase(TestCase): self.assertNotIn(attribute, dir(bare_view)) self.assertIn(attribute, dir(view)) + def test_viewset_action_attr(self): + view = ActionViewSet.as_view(actions={'get': 'list'}) + + get = view(factory.get('/')) + head = view(factory.head('/')) + assert get.view.action == 'list' + assert head.view.action == 'list' + + def test_viewset_action_attr_for_extra_action(self): + view = ActionViewSet.as_view(actions=dict(ActionViewSet.list_action.mapping)) + + get = view(factory.get('/')) + head = view(factory.head('/')) + assert get.view.action == 'list_action' + assert head.view.action == 'list_action' + class GetExtraActionsTests(TestCase):