Switch to using schema = None

This commit is contained in:
Carlton Gibson 2017-09-14 16:23:27 +02:00
parent f346118a8e
commit 9d84e0c290
6 changed files with 24 additions and 15 deletions

View File

@ -75,7 +75,12 @@ def api_view(http_method_names=None, exclude_from_schema=False):
WrappedAPIView.schema = getattr(func, 'schema', WrappedAPIView.schema = getattr(func, 'schema',
APIView.schema) APIView.schema)
if exclude_from_schema:
# This won't catch an explicit `exclude_from_schema=False`
# but it should be good enough.
# TODO: DeprecationWarning
WrappedAPIView.exclude_from_schema = exclude_from_schema WrappedAPIView.exclude_from_schema = exclude_from_schema
return WrappedAPIView.as_view() return WrappedAPIView.as_view()
return decorator return decorator

View File

@ -291,7 +291,7 @@ class APIRootView(views.APIView):
The default basic root view for DefaultRouter The default basic root view for DefaultRouter
""" """
_ignore_model_permissions = True _ignore_model_permissions = True
exclude_from_schema = True schema = None # exclude from schema
api_root_dict = None api_root_dict = None
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):

View File

@ -148,9 +148,14 @@ class EndpointEnumerator(object):
if not is_api_view(callback): if not is_api_view(callback):
return False # Ignore anything except REST framework views. return False # Ignore anything except REST framework views.
if hasattr(callback.cls, 'exclude_from_schema'):
# TODO: deprecation warning
if getattr(callback.cls, 'exclude_from_schema', False): if getattr(callback.cls, 'exclude_from_schema', False):
return False return False
if callback.cls.schema is None:
return False
if path.endswith('.{format}') or path.endswith('.{format}/'): if path.endswith('.{format}') or path.endswith('.{format}/'):
return False # Ignore .json style URLs. return False # Ignore .json style URLs.

View File

@ -11,7 +11,7 @@ from rest_framework.views import APIView
class SchemaView(APIView): class SchemaView(APIView):
_ignore_model_permissions = True _ignore_model_permissions = True
exclude_from_schema = True schema = None # exclude from schema
renderer_classes = None renderer_classes = None
schema_generator = None schema_generator = None
public = False public = False

View File

@ -112,8 +112,6 @@ class APIView(View):
# Allow dependency injection of other settings to make testing easier. # Allow dependency injection of other settings to make testing easier.
settings = api_settings settings = api_settings
# Mark the view as being included or excluded from schema generation.
exclude_from_schema = False
schema = AutoSchema() schema = AutoSchema()
@classmethod @classmethod

View File

@ -8,7 +8,9 @@ from django.test import TestCase, override_settings
from rest_framework import filters, pagination, permissions, serializers from rest_framework import filters, pagination, permissions, serializers
from rest_framework.compat import coreapi, coreschema from rest_framework.compat import coreapi, coreschema
from rest_framework.decorators import detail_route, list_route, api_view, schema from rest_framework.decorators import (
api_view, detail_route, list_route, schema
)
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.schemas import ( from rest_framework.schemas import (
@ -618,13 +620,14 @@ def test_docstring_is_not_stripped_by_get_description():
# Views for SchemaGenerationExclusionTests # Views for SchemaGenerationExclusionTests
class ExcludedAPIView(APIView): class ExcludedAPIView(APIView):
exclude_from_schema = True schema = None
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
pass pass
@api_view(['GET'], exclude_from_schema=True) @api_view(['GET'])
@schema(None)
def excluded_fbv(request): def excluded_fbv(request):
pass pass
@ -670,15 +673,13 @@ class SchemaGenerationExclusionTests(TestCase):
path, method, callback = endpoints[0] path, method, callback = endpoints[0]
assert path == '/included-fbv/' assert path == '/included-fbv/'
def test_should_include_endpoint_excludes_correctly(self): def test_should_include_endpoint_excludes_correctly(self):
"""This is the specific method that should handle the exclusion""" """This is the specific method that should handle the exclusion"""
inspector = EndpointEnumerator(self.patterns) inspector = EndpointEnumerator(self.patterns)
pairs = [ # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
(inspector.get_path_from_regex(pattern.regex.pattern), pattern.callback) pairs = [(inspector.get_path_from_regex(pattern.regex.pattern), pattern.callback)
for pattern in self.patterns for pattern in self.patterns]
]
should_include = [ should_include = [
inspector.should_include_endpoint(*pair) for pair in pairs inspector.should_include_endpoint(*pair) for pair in pairs