Check extra action func.__name__ (#7098)

This commit is contained in:
Ryan P Kilby 2020-08-05 21:29:47 -07:00 committed by GitHub
parent 0d2bbd3177
commit 1e383f103a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 1 deletions

View File

@ -33,6 +33,15 @@ def _is_extra_action(attr):
return hasattr(attr, 'mapping') and isinstance(attr.mapping, MethodMapper) return hasattr(attr, 'mapping') and isinstance(attr.mapping, MethodMapper)
def _check_attr_name(func, name):
assert func.__name__ == name, (
'Expected function (`{func.__name__}`) to match its attribute name '
'(`{name}`). If using a decorator, ensure the inner function is '
'decorated with `functools.wraps`, or that `{func.__name__}.__name__` '
'is otherwise set to `{name}`.').format(func=func, name=name)
return func
class ViewSetMixin: class ViewSetMixin:
""" """
This is the magic. This is the magic.
@ -164,7 +173,9 @@ class ViewSetMixin:
""" """
Get the methods that are marked as an extra ViewSet `@action`. Get the methods that are marked as an extra ViewSet `@action`.
""" """
return [method for _, method in getmembers(cls, _is_extra_action)] return [_check_attr_name(method, name)
for name, method
in getmembers(cls, _is_extra_action)]
def get_extra_action_url_map(self): def get_extra_action_url_map(self):
""" """

View File

@ -1,4 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from functools import wraps
import pytest import pytest
from django.conf.urls import include, url from django.conf.urls import include, url
@ -33,6 +34,13 @@ class Action(models.Model):
pass pass
def decorate(fn):
@wraps(fn)
def wrapper(self, request, *args, **kwargs):
return fn(self, request, *args, **kwargs)
return wrapper
class ActionViewSet(GenericViewSet): class ActionViewSet(GenericViewSet):
queryset = Action.objects.all() queryset = Action.objects.all()
@ -68,6 +76,16 @@ class ActionViewSet(GenericViewSet):
def unresolvable_detail_action(self, request, *args, **kwargs): def unresolvable_detail_action(self, request, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
@action(detail=False)
@decorate
def wrapped_list_action(self, request, *args, **kwargs):
raise NotImplementedError
@action(detail=True)
@decorate
def wrapped_detail_action(self, request, *args, **kwargs):
raise NotImplementedError
class ActionNamesViewSet(GenericViewSet): class ActionNamesViewSet(GenericViewSet):
@ -191,6 +209,8 @@ class GetExtraActionsTests(TestCase):
'detail_action', 'detail_action',
'list_action', 'list_action',
'unresolvable_detail_action', 'unresolvable_detail_action',
'wrapped_detail_action',
'wrapped_list_action',
] ]
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@ -204,9 +224,35 @@ class GetExtraActionsTests(TestCase):
'detail_action', 'detail_action',
'list_action', 'list_action',
'unresolvable_detail_action', 'unresolvable_detail_action',
'wrapped_detail_action',
'wrapped_list_action',
] ]
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
def test_attr_name_check(self):
def decorate(fn):
def wrapper(self, request, *args, **kwargs):
return fn(self, request, *args, **kwargs)
return wrapper
class ActionViewSet(GenericViewSet):
queryset = Action.objects.all()
@action(detail=False)
@decorate
def wrapped_list_action(self, request, *args, **kwargs):
raise NotImplementedError
view = ActionViewSet()
with pytest.raises(AssertionError) as excinfo:
view.get_extra_actions()
assert str(excinfo.value) == (
'Expected function (`wrapper`) to match its attribute name '
'(`wrapped_list_action`). If using a decorator, ensure the inner '
'function is decorated with `functools.wraps`, or that '
'`wrapper.__name__` is otherwise set to `wrapped_list_action`.')
@override_settings(ROOT_URLCONF='tests.test_viewsets') @override_settings(ROOT_URLCONF='tests.test_viewsets')
class GetExtraActionUrlMapTests(TestCase): class GetExtraActionUrlMapTests(TestCase):
@ -218,6 +264,7 @@ class GetExtraActionUrlMapTests(TestCase):
expected = OrderedDict([ expected = OrderedDict([
('Custom list action', 'http://testserver/api/actions/custom_list_action/'), ('Custom list action', 'http://testserver/api/actions/custom_list_action/'),
('List action', 'http://testserver/api/actions/list_action/'), ('List action', 'http://testserver/api/actions/list_action/'),
('Wrapped list action', 'http://testserver/api/actions/wrapped_list_action/'),
]) ])
self.assertEqual(view.get_extra_action_url_map(), expected) self.assertEqual(view.get_extra_action_url_map(), expected)
@ -229,6 +276,7 @@ class GetExtraActionUrlMapTests(TestCase):
expected = OrderedDict([ expected = OrderedDict([
('Custom detail action', 'http://testserver/api/actions/1/custom_detail_action/'), ('Custom detail action', 'http://testserver/api/actions/1/custom_detail_action/'),
('Detail action', 'http://testserver/api/actions/1/detail_action/'), ('Detail action', 'http://testserver/api/actions/1/detail_action/'),
('Wrapped detail action', 'http://testserver/api/actions/1/wrapped_detail_action/'),
# "Unresolvable detail action" excluded, since it's not resolvable # "Unresolvable detail action" excluded, since it's not resolvable
]) ])