Handle multiple methods on custom action (#4529)

This commit is contained in:
Tom Christie 2016-09-30 15:23:47 +01:00 committed by GitHub
parent c3a9538ad9
commit 4ad5256e88
2 changed files with 52 additions and 7 deletions

View File

@ -173,6 +173,16 @@ class SchemaGenerator(object):
if self.endpoints is None:
self.endpoints = self.endpoint_inspector.get_api_endpoints()
links = self.get_links(request)
if not links:
return None
return coreapi.Document(title=self.title, url=self.url, content=links)
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = {}
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
@ -181,11 +191,7 @@ class SchemaGenerator(object):
link = self.get_link(path, method, view)
keys = self.get_keys(path, method, view)
insert_into(links, keys, link)
if not links:
return None
return coreapi.Document(title=self.title, url=self.url, content=links)
return links
# Methods used when we generate a view instance from the raw callback...
@ -200,6 +206,7 @@ class SchemaGenerator(object):
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
@ -393,7 +400,11 @@ class SchemaGenerator(object):
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
return named_path_components[:-1] + [action]
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]

View File

@ -22,6 +22,10 @@ class ExamplePagination(pagination.PageNumberPagination):
page_size = 100
class EmptySerializer(serializers.Serializer):
pass
class ExampleSerializer(serializers.Serializer):
a = serializers.CharField(required=True, help_text='A field description')
b = serializers.CharField(required=False)
@ -48,6 +52,10 @@ class ExampleViewSet(ModelViewSet):
def custom_list_action(self, request):
return super(ExampleViewSet, self).list(self, request)
@list_route(methods=['post', 'get'], serializer_class=EmptySerializer)
def custom_list_action_multiple_methods(self, request):
return super(ExampleViewSet, self).list(self, request)
def get_serializer(self, *args, **kwargs):
assert self.request
assert self.action
@ -85,6 +93,12 @@ class TestRouterGeneratedSchema(TestCase):
url='/example/custom_list_action/',
action='get'
),
'custom_list_action_multiple_methods': {
'read': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='get'
)
},
'read': coreapi.Link(
url='/example/{pk}/',
action='get',
@ -145,6 +159,16 @@ class TestRouterGeneratedSchema(TestCase):
url='/example/custom_list_action/',
action='get'
),
'custom_list_action_multiple_methods': {
'read': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='get'
),
'create': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='post'
)
},
'update': coreapi.Link(
url='/example/{pk}/',
action='put',
@ -201,6 +225,7 @@ class TestSchemaGenerator(TestCase):
self.patterns = [
url('^example/?$', ExampleListView.as_view()),
url('^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
url('^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@ -230,7 +255,16 @@ class TestSchemaGenerator(TestCase):
fields=[
coreapi.Field('pk', required=True, location='path')
]
)
),
'sub': {
'list': coreapi.Link(
url='/example/{pk}/sub/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
]
)
}
}
}
)