diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1297f96b4..f5a0c3579 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -75,7 +75,12 @@ def api_view(http_method_names=None, exclude_from_schema=False): WrappedAPIView.schema = getattr(func, 'schema', APIView.schema) - WrappedAPIView.exclude_from_schema = exclude_from_schema + if exclude_from_schema: + # This won't catch an explicit `exclude_from_schema=False` + # but it should be good enough. + # TODO: DeprecationWarning + WrappedAPIView.exclude_from_schema = exclude_from_schema + return WrappedAPIView.as_view() return decorator diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 01daa7e7d..3b5ef46d8 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -291,7 +291,7 @@ class APIRootView(views.APIView): The default basic root view for DefaultRouter """ _ignore_model_permissions = True - exclude_from_schema = True + schema = None # exclude from schema api_root_dict = None def get(self, request, *args, **kwargs): diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index cc1ffb31b..21726ef26 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -148,7 +148,12 @@ class EndpointEnumerator(object): if not is_api_view(callback): return False # Ignore anything except REST framework views. - if getattr(callback.cls, 'exclude_from_schema', False): + if hasattr(callback.cls, 'exclude_from_schema'): + # TODO: deprecation warning + if getattr(callback.cls, 'exclude_from_schema', False): + return False + + if callback.cls.schema is None: return False if path.endswith('.{format}') or path.endswith('.{format}/'): diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py index 932b5a487..b13eadea9 100644 --- a/rest_framework/schemas/views.py +++ b/rest_framework/schemas/views.py @@ -11,7 +11,7 @@ from rest_framework.views import APIView class SchemaView(APIView): _ignore_model_permissions = True - exclude_from_schema = True + schema = None # exclude from schema renderer_classes = None schema_generator = None public = False diff --git a/rest_framework/views.py b/rest_framework/views.py index ccc2047ee..dfed15888 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -112,8 +112,6 @@ class APIView(View): # Allow dependency injection of other settings to make testing easier. settings = api_settings - # Mark the view as being included or excluded from schema generation. - exclude_from_schema = False schema = AutoSchema() @classmethod diff --git a/tests/test_schemas.py b/tests/test_schemas.py index f67dfda4f..1ecd31314 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -8,7 +8,9 @@ from django.test import TestCase, override_settings from rest_framework import filters, pagination, permissions, serializers from rest_framework.compat import coreapi, coreschema -from rest_framework.decorators import detail_route, list_route, api_view, schema +from rest_framework.decorators import ( + api_view, detail_route, list_route, schema +) from rest_framework.request import Request from rest_framework.routers import DefaultRouter from rest_framework.schemas import ( @@ -618,13 +620,14 @@ def test_docstring_is_not_stripped_by_get_description(): # Views for SchemaGenerationExclusionTests class ExcludedAPIView(APIView): - exclude_from_schema = True + schema = None def get(self, request, *args, **kwargs): pass -@api_view(['GET'], exclude_from_schema=True) +@api_view(['GET']) +@schema(None) def excluded_fbv(request): pass @@ -670,15 +673,13 @@ class SchemaGenerationExclusionTests(TestCase): path, method, callback = endpoints[0] assert path == '/included-fbv/' - def test_should_include_endpoint_excludes_correctly(self): """This is the specific method that should handle the exclusion""" inspector = EndpointEnumerator(self.patterns) - pairs = [ - (inspector.get_path_from_regex(pattern.regex.pattern), pattern.callback) - for pattern in self.patterns - ] + # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test + pairs = [(inspector.get_path_from_regex(pattern.regex.pattern), pattern.callback) + for pattern in self.patterns] should_include = [ inspector.should_include_endpoint(*pair) for pair in pairs @@ -689,4 +690,4 @@ class SchemaGenerationExclusionTests(TestCase): assert should_include == expected def test_deprecations(self): - pass \ No newline at end of file + pass