diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 28f5dc8a1..a6ebf3ca8 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -182,15 +182,21 @@ class SchemaGenerator(object): Return a tuple of strings, indicating the identity to use for a given endpoint. eg. ('users', 'list'). """ - category = None - for item in path.strip('/').split('/'): - if '{' in item: - break - category = item - actions = getattr(callback, 'actions', self.default_mapping) action = actions[method.lower()] + list_route = None + if hasattr(callback.cls, action): + list_route_action = getattr(callback.cls, action) + if hasattr(list_route_action, 'bind_to_methods') and not getattr(list_route_action, 'detail', None): + list_route = list_route_action.kwargs.get('url_path', action) + + category = None + for item in path.strip('/').split('/'): + if '{' in item or item == list_route: + break + category = item + if category: return (category, action) return (action,) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index d8c0f2209..03fb76e13 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -5,7 +5,7 @@ from django.test import TestCase, override_settings from rest_framework import filters, pagination, permissions, serializers from rest_framework.compat import coreapi -from rest_framework.decorators import detail_route +from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator @@ -40,8 +40,12 @@ class ExampleViewSet(ModelViewSet): serializer_class = ExampleSerializer @detail_route(methods=['post'], serializer_class=AnotherSerializer) - def custom_action(self, request, pk): - return super(ExampleSerializer, self).retrieve(self, request) + def custom_detail_action(self, request, pk): + return super(ExampleViewSet, self).retrieve(self, request) + + @list_route() + def custom_list_action(self, request): + return super(ExampleViewSet, self).list(self, request) def get_serializer(self, *args, **kwargs): assert self.request @@ -94,6 +98,10 @@ class TestRouterGeneratedSchema(TestCase): fields=[ coreapi.Field('pk', required=True, location='path') ] + ), + 'custom_list_action': coreapi.Link( + url='/example/custom_list_action/', + action='get' ) } } @@ -134,8 +142,8 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('pk', required=True, location='path') ] ), - 'custom_action': coreapi.Link( - url='/example/{pk}/custom_action/', + 'custom_detail_action': coreapi.Link( + url='/example/{pk}/custom_detail_action/', action='post', encoding='application/json', fields=[ @@ -144,6 +152,10 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('d', required=False, location='form'), ] ), + 'custom_list_action': coreapi.Link( + url='/example/custom_list_action/', + action='get' + ), 'update': coreapi.Link( url='/example/{pk}/', action='put',