diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 39dd8d910..1c2f1a546 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -173,6 +173,16 @@ class SchemaGenerator(object): if self.endpoints is None: self.endpoints = self.endpoint_inspector.get_api_endpoints() + links = self.get_links(request) + if not links: + return None + return coreapi.Document(title=self.title, url=self.url, content=links) + + def get_links(self, request=None): + """ + Return a dictionary containing all the links that should be + included in the API schema. + """ links = {} for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) @@ -181,11 +191,7 @@ class SchemaGenerator(object): link = self.get_link(path, method, view) keys = self.get_keys(path, method, view) insert_into(links, keys, link) - - if not links: - return None - - return coreapi.Document(title=self.title, url=self.url, content=links) + return links # Methods used when we generate a view instance from the raw callback... @@ -200,6 +206,7 @@ class SchemaGenerator(object): view.kwargs = {} view.format_kwarg = None view.request = None + view.action_map = getattr(callback, 'actions', None) actions = getattr(callback, 'actions', None) if actions is not None: @@ -393,7 +400,11 @@ class SchemaGenerator(object): if is_custom_action(action): # Custom action, eg "/users/{pk}/activate/", "/users/active/" - return named_path_components[:-1] + [action] + if len(view.action_map) > 1: + action = self.default_mapping[method.lower()] + return named_path_components + [action] + else: + return named_path_components[:-1] + [action] # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 1d98a4618..588124fd7 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -22,6 +22,10 @@ class ExamplePagination(pagination.PageNumberPagination): page_size = 100 +class EmptySerializer(serializers.Serializer): + pass + + class ExampleSerializer(serializers.Serializer): a = serializers.CharField(required=True, help_text='A field description') b = serializers.CharField(required=False) @@ -48,6 +52,10 @@ class ExampleViewSet(ModelViewSet): def custom_list_action(self, request): return super(ExampleViewSet, self).list(self, request) + @list_route(methods=['post', 'get'], serializer_class=EmptySerializer) + def custom_list_action_multiple_methods(self, request): + return super(ExampleViewSet, self).list(self, request) + def get_serializer(self, *args, **kwargs): assert self.request assert self.action @@ -85,6 +93,12 @@ class TestRouterGeneratedSchema(TestCase): url='/example/custom_list_action/', action='get' ), + 'custom_list_action_multiple_methods': { + 'read': coreapi.Link( + url='/example/custom_list_action_multiple_methods/', + action='get' + ) + }, 'read': coreapi.Link( url='/example/{pk}/', action='get', @@ -95,6 +109,7 @@ class TestRouterGeneratedSchema(TestCase): } } ) + print response.data self.assertEqual(response.data, expected) def test_authenticated_request(self): @@ -145,6 +160,16 @@ class TestRouterGeneratedSchema(TestCase): url='/example/custom_list_action/', action='get' ), + 'custom_list_action_multiple_methods': { + 'read': coreapi.Link( + url='/example/custom_list_action_multiple_methods/', + action='get' + ), + 'create': coreapi.Link( + url='/example/custom_list_action_multiple_methods/', + action='post' + ) + }, 'update': coreapi.Link( url='/example/{pk}/', action='put', @@ -201,6 +226,7 @@ class TestSchemaGenerator(TestCase): self.patterns = [ url('^example/?$', ExampleListView.as_view()), url('^example/(?P\d+)/?$', ExampleDetailView.as_view()), + url('^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -230,7 +256,16 @@ class TestSchemaGenerator(TestCase): fields=[ coreapi.Field('pk', required=True, location='path') ] - ) + ), + 'sub': { + 'list': coreapi.Link( + url='/example/{pk}/sub/', + action='get', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] + ) + } } } )