diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 688deec88..842a0fc64 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -86,9 +86,10 @@ class SchemaGenerator(object): endpoints = [] for key, link, callback in self.endpoints: method = link.action.upper() - view = callback.cls() + view = self.get_view(callback) view.request = clone_request(request, method) view.format_kwarg = None + try: view.check_permissions(view.request) except exceptions.APIException: @@ -135,6 +136,15 @@ class SchemaGenerator(object): return api_endpoints + def get_view(self, callback): + """ + Return constructed view with respect of overrided attributes by detail_route and list_route + """ + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).iteritems(): + setattr(view, attr, val) + return view + def get_path(self, path_regex): """ Given a URL conf regex, return a URI template string. @@ -165,9 +175,10 @@ class SchemaGenerator(object): if hasattr(callback, 'actions'): return [method.upper() for method in callback.actions.keys()] + view = self.get_view(callback) return [ method for method in - callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') + view.allowed_methods if method not in ('OPTIONS', 'HEAD') ] def get_key(self, path, method, callback): @@ -194,9 +205,7 @@ class SchemaGenerator(object): """ Return a `coreapi.Link` instance for the given endpoint. """ - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) + view = self.get_view(callback) fields = self.get_path_fields(path, method, callback, view) fields += self.get_serializer_fields(path, method, callback, view) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 6c02c9d23..5deb5fed8 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -33,15 +33,25 @@ class AnotherSerializer(serializers.Serializer): d = serializers.CharField(required=False) +class ForbidAll(permissions.BasePermission): + def has_permission(self, request, view): + return False + + class ExampleViewSet(ModelViewSet): pagination_class = ExamplePagination permission_classes = [permissions.IsAuthenticatedOrReadOnly] filter_backends = [filters.OrderingFilter] serializer_class = ExampleSerializer - @detail_route(methods=['post'], serializer_class=AnotherSerializer) + @detail_route(methods=['put', 'post'], + serializer_class=AnotherSerializer) def custom_action(self, request, pk): - return super(ExampleSerializer, self).retrieve(self, request) + return super(ExampleSerializer, self).update(self, request) + + @detail_route(permission_classes=[ForbidAll]) + def forbidden_action(self, request, pk): + return super(ExampleSerializer, self).update(self, request) class ExampleView(APIView): @@ -130,6 +140,16 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('pk', required=True, location='path') ] ), + 'custom_action': coreapi.Link( + url='/example/{pk}/custom_action/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('c', required=True, location='form'), + coreapi.Field('d', required=False, location='form'), + ] + ), 'custom_action': coreapi.Link( url='/example/{pk}/custom_action/', action='post',