Schemas should be able to handle multiple methods on custom action

This commit is contained in:
Tom Christie 2016-09-30 15:01:13 +01:00
parent 0844cb5114
commit f7b89a8dad
2 changed files with 53 additions and 7 deletions

View File

@ -173,6 +173,16 @@ class SchemaGenerator(object):
if self.endpoints is None: if self.endpoints is None:
self.endpoints = self.endpoint_inspector.get_api_endpoints() self.endpoints = self.endpoint_inspector.get_api_endpoints()
links = self.get_links(request)
if not links:
return None
return coreapi.Document(title=self.title, url=self.url, content=links)
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = {} links = {}
for path, method, callback in self.endpoints: for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request) view = self.create_view(callback, method, request)
@ -181,11 +191,7 @@ class SchemaGenerator(object):
link = self.get_link(path, method, view) link = self.get_link(path, method, view)
keys = self.get_keys(path, method, view) keys = self.get_keys(path, method, view)
insert_into(links, keys, link) insert_into(links, keys, link)
return links
if not links:
return None
return coreapi.Document(title=self.title, url=self.url, content=links)
# Methods used when we generate a view instance from the raw callback... # Methods used when we generate a view instance from the raw callback...
@ -200,6 +206,7 @@ class SchemaGenerator(object):
view.kwargs = {} view.kwargs = {}
view.format_kwarg = None view.format_kwarg = None
view.request = None view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None) actions = getattr(callback, 'actions', None)
if actions is not None: if actions is not None:
@ -393,7 +400,11 @@ class SchemaGenerator(object):
if is_custom_action(action): if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/" # Custom action, eg "/users/{pk}/activate/", "/users/active/"
return named_path_components[:-1] + [action] if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
# Default action, eg "/users/", "/users/{pk}/" # Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action] return named_path_components + [action]

View File

@ -22,6 +22,10 @@ class ExamplePagination(pagination.PageNumberPagination):
page_size = 100 page_size = 100
class EmptySerializer(serializers.Serializer):
pass
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
a = serializers.CharField(required=True, help_text='A field description') a = serializers.CharField(required=True, help_text='A field description')
b = serializers.CharField(required=False) b = serializers.CharField(required=False)
@ -48,6 +52,10 @@ class ExampleViewSet(ModelViewSet):
def custom_list_action(self, request): def custom_list_action(self, request):
return super(ExampleViewSet, self).list(self, request) return super(ExampleViewSet, self).list(self, request)
@list_route(methods=['post', 'get'], serializer_class=EmptySerializer)
def custom_list_action_multiple_methods(self, request):
return super(ExampleViewSet, self).list(self, request)
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
assert self.request assert self.request
assert self.action assert self.action
@ -85,6 +93,12 @@ class TestRouterGeneratedSchema(TestCase):
url='/example/custom_list_action/', url='/example/custom_list_action/',
action='get' action='get'
), ),
'custom_list_action_multiple_methods': {
'read': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='get'
)
},
'read': coreapi.Link( 'read': coreapi.Link(
url='/example/{pk}/', url='/example/{pk}/',
action='get', action='get',
@ -95,6 +109,7 @@ class TestRouterGeneratedSchema(TestCase):
} }
} }
) )
print response.data
self.assertEqual(response.data, expected) self.assertEqual(response.data, expected)
def test_authenticated_request(self): def test_authenticated_request(self):
@ -145,6 +160,16 @@ class TestRouterGeneratedSchema(TestCase):
url='/example/custom_list_action/', url='/example/custom_list_action/',
action='get' action='get'
), ),
'custom_list_action_multiple_methods': {
'read': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='get'
),
'create': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='post'
)
},
'update': coreapi.Link( 'update': coreapi.Link(
url='/example/{pk}/', url='/example/{pk}/',
action='put', action='put',
@ -201,6 +226,7 @@ class TestSchemaGenerator(TestCase):
self.patterns = [ self.patterns = [
url('^example/?$', ExampleListView.as_view()), url('^example/?$', ExampleListView.as_view()),
url('^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()), url('^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
url('^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
] ]
def test_schema_for_regular_views(self): def test_schema_for_regular_views(self):
@ -230,7 +256,16 @@ class TestSchemaGenerator(TestCase):
fields=[ fields=[
coreapi.Field('pk', required=True, location='path') coreapi.Field('pk', required=True, location='path')
] ]
) ),
'sub': {
'list': coreapi.Link(
url='/example/{pk}/sub/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
]
)
}
} }
} }
) )