From f32398953d563fb8c57b8e1a57323eb2990d5302 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 16 May 2018 09:29:42 -0400 Subject: [PATCH] Add method mapping to ViewSet actions --- rest_framework/decorators.py | 63 ++++++++++++++++++++++++++++++++++- rest_framework/routers.py | 3 +- rest_framework/viewsets.py | 2 +- tests/test_decorators.py | 64 +++++++++++++++++++++++++++++++++--- tests/test_routers.py | 34 +++++++++++++++++-- tests/test_schemas.py | 35 ++++++++++++++------ 6 files changed, 180 insertions(+), 21 deletions(-) diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 9f6b8101c..60078947f 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -146,7 +146,8 @@ def action(methods=None, detail=None, name=None, url_path=None, url_name=None, * ) def decorator(func): - func.bind_to_methods = methods + func.mapping = MethodMapper(func, methods) + func.detail = detail func.name = name if name else pretty_name(func.__name__) func.url_path = url_path if url_path else func.__name__ @@ -156,10 +157,70 @@ def action(methods=None, detail=None, name=None, url_path=None, url_name=None, * 'name': func.name, 'description': func.__doc__ or None }) + return func return decorator +class MethodMapper(dict): + """ + Enables mapping HTTP methods to different ViewSet methods for a single, + logical action. + + Example usage: + + class MyViewSet(ViewSet): + + @action(detail=False) + def example(self, request, **kwargs): + ... + + @example.mapping.post + def create_example(self, request, **kwargs): + ... + """ + + def __init__(self, action, methods): + self.action = action + for method in methods: + self[method] = self.action.__name__ + + def _map(self, method, func): + assert method not in self, ( + "Method '%s' has already been mapped to '.%s'." % (method, self[method])) + assert func.__name__ != self.action.__name__, ( + "Method mapping does not behave like the property decorator. You " + "cannot use the same method name for each mapping declaration.") + + self[method] = func.__name__ + + return func + + def get(self, func): + return self._map('get', func) + + def post(self, func): + return self._map('post', func) + + def put(self, func): + return self._map('put', func) + + def patch(self, func): + return self._map('patch', func) + + def delete(self, func): + return self._map('delete', func) + + def head(self, func): + return self._map('head', func) + + def options(self, func): + return self._map('options', func) + + def trace(self, func): + return self._map('trace', func) + + def detail_route(methods=None, **kwargs): """ Used to mark a method on a ViewSet that should be routed for detail requests. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 281bbde8a..52b2b7cc6 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -208,8 +208,7 @@ class SimpleRouter(BaseRouter): return Route( url=route.url.replace('{url_path}', url_path), - mapping={http_method: action.__name__ - for http_method in action.bind_to_methods}, + mapping=action.mapping, name=route.name.replace('{url_name}', action.url_name), detail=route.detail, initkwargs=initkwargs, diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index c4cd85d0b..412475351 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -31,7 +31,7 @@ from rest_framework.reverse import reverse def _is_extra_action(attr): - return hasattr(attr, 'bind_to_methods') + return hasattr(attr, 'mapping') class ViewSetMixin(object): diff --git a/tests/test_decorators.py b/tests/test_decorators.py index c24287600..7568513f3 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -177,7 +177,7 @@ class ActionDecoratorTestCase(TestCase): def test_action(request): """Description""" - assert test_action.bind_to_methods == ['get'] + assert test_action.mapping == {'get': 'test_action'} assert test_action.detail is True assert test_action.name == 'Test action' assert test_action.url_path == 'test_action' @@ -191,15 +191,69 @@ class ActionDecoratorTestCase(TestCase): with pytest.raises(AssertionError) as excinfo: @action() def test_action(request): - pass + raise NotImplementedError assert str(excinfo.value) == "@action() missing required argument: 'detail'" + def test_method_mapping_http_methods(self): + # All HTTP methods should be mappable + @action(detail=False, methods=[]) + def test_action(): + raise NotImplementedError + + for name in APIView.http_method_names: + def method(): + raise NotImplementedError + + # Python 2.x compatibility - cast __name__ to str + method.__name__ = str(name) + getattr(test_action.mapping, name)(method) + + # ensure the mapping returns the correct method name + for name in APIView.http_method_names: + assert test_action.mapping[name] == name + + def test_method_mapping(self): + @action(detail=False) + def test_action(request): + raise NotImplementedError + + @test_action.mapping.post + def test_action_post(request): + raise NotImplementedError + + # The secondary handler methods should not have the action attributes + for name in ['mapping', 'detail', 'name', 'url_path', 'url_name', 'kwargs']: + assert hasattr(test_action, name) and not hasattr(test_action_post, name) + + def test_method_mapping_already_mapped(self): + @action(detail=True) + def test_action(request): + raise NotImplementedError + + msg = "Method 'get' has already been mapped to '.test_action'." + with self.assertRaisesMessage(AssertionError, msg): + @test_action.mapping.get + def test_action_get(request): + raise NotImplementedError + + def test_method_mapping_overwrite(self): + @action(detail=True) + def test_action(): + raise NotImplementedError + + msg = ("Method mapping does not behave like the property decorator. You " + "cannot use the same method name for each mapping declaration.") + with self.assertRaisesMessage(AssertionError, msg): + @test_action.mapping.post + def test_action(): + raise NotImplementedError + def test_detail_route_deprecation(self): with pytest.warns(PendingDeprecationWarning) as record: @detail_route() def view(request): - pass + raise NotImplementedError assert len(record) == 1 assert str(record[0].message) == ( @@ -212,7 +266,7 @@ class ActionDecoratorTestCase(TestCase): with pytest.warns(PendingDeprecationWarning) as record: @list_route() def view(request): - pass + raise NotImplementedError assert len(record) == 1 assert str(record[0].message) == ( @@ -226,7 +280,7 @@ class ActionDecoratorTestCase(TestCase): with pytest.warns(PendingDeprecationWarning): @list_route(url_path='foo_bar') def view(request): - pass + raise NotImplementedError assert view.url_path == 'foo_bar' assert view.url_name == 'foo-bar' diff --git a/tests/test_routers.py b/tests/test_routers.py index 2189d1c2b..8f52d217f 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -7,7 +7,7 @@ from django.conf.urls import include, url from django.core.exceptions import ImproperlyConfigured from django.db import models from django.test import TestCase, override_settings -from django.urls import resolve +from django.urls import resolve, reverse from rest_framework import permissions, serializers, viewsets from rest_framework.compat import get_regex_pattern @@ -107,8 +107,23 @@ class BasicViewSet(viewsets.ViewSet): def action2(self, request, *args, **kwargs): return Response({'method': 'action2'}) + @action(methods=['post'], detail=True) + def action3(self, request, pk, *args, **kwargs): + return Response({'post': pk}) + + @action3.mapping.delete + def action3_delete(self, request, pk, *args, **kwargs): + return Response({'delete': pk}) + + +class TestSimpleRouter(URLPatternsTestCase, TestCase): + router = SimpleRouter() + router.register('basics', BasicViewSet, base_name='basic') + + urlpatterns = [ + url(r'^api/', include(router.urls)), + ] -class TestSimpleRouter(TestCase): def setUp(self): self.router = SimpleRouter() @@ -127,6 +142,21 @@ class TestSimpleRouter(TestCase): 'delete': 'action2', } + assert routes[2].url == '^{prefix}/{lookup}/action3{trailing_slash}$' + assert routes[2].mapping == { + 'post': 'action3', + 'delete': 'action3_delete', + } + + def test_multiple_action_handlers(self): + # Standard action + response = self.client.post(reverse('basic-action3', args=[1])) + assert response.data == {'post': '1'} + + # Additional handler registered with MethodMapper + response = self.client.delete(reverse('basic-action3', args=[1])) + assert response.data == {'delete': '1'} + class TestRootView(URLPatternsTestCase, TestCase): urlpatterns = [ diff --git a/tests/test_schemas.py b/tests/test_schemas.py index f929fece5..e4a7c8646 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -75,29 +75,35 @@ class ExampleViewSet(ModelViewSet): """ A description of custom action. """ - return super(ExampleSerializer, self).retrieve(self, request) + raise NotImplementedError @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField) def custom_action_with_dict_field(self, request, pk): """ A custom action using a dict field in the serializer. """ - return super(ExampleSerializer, self).retrieve(self, request) + raise NotImplementedError @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields) def custom_action_with_list_fields(self, request, pk): """ A custom action using both list field and list serializer in the serializer. """ - return super(ExampleSerializer, self).retrieve(self, request) + raise NotImplementedError @action(detail=False) def custom_list_action(self, request): - return super(ExampleViewSet, self).list(self, request) + raise NotImplementedError @action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer) def custom_list_action_multiple_methods(self, request): - return super(ExampleViewSet, self).list(self, request) + """Custom description.""" + raise NotImplementedError + + @custom_list_action_multiple_methods.mapping.delete + def custom_list_action_multiple_methods_delete(self, request): + """Deletion description.""" + raise NotImplementedError def get_serializer(self, *args, **kwargs): assert self.request @@ -147,7 +153,8 @@ class TestRouterGeneratedSchema(TestCase): 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', - action='get' + action='get', + description='Custom description.', ) }, 'read': coreapi.Link( @@ -238,12 +245,19 @@ class TestRouterGeneratedSchema(TestCase): 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', - action='get' + action='get', + description='Custom description.', ), 'create': coreapi.Link( url='/example/custom_list_action_multiple_methods/', - action='post' - ) + action='post', + description='Custom description.', + ), + 'delete': coreapi.Link( + url='/example/custom_list_action_multiple_methods/', + action='delete', + description='Deletion description.', + ), }, 'update': coreapi.Link( url='/example/{id}/', @@ -526,7 +540,8 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example1/custom_list_action_multiple_methods/', - action='get' + action='get', + description='Custom description.', ) }, 'read': coreapi.Link(