From c49bb59c37a97d46235b85ae02d294eb935875a9 Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Mon, 29 Apr 2019 15:08:04 +0200 Subject: [PATCH] Allow SchemaView to handle both CoreAPI & OpenAPI. --- .../management/commands/generateschema.py | 6 +----- rest_framework/schemas/__init__.py | 14 +++++++++++--- rest_framework/schemas/coreapi.py | 5 +++++ rest_framework/schemas/views.py | 15 +++++++++++---- tests/schemas/test_coreapi.py | 11 ++++++----- tests/schemas/test_get_schema_view.py | 18 ++++++++++++++++++ 6 files changed, 52 insertions(+), 17 deletions(-) create mode 100644 tests/schemas/test_get_schema_view.py diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py index 02ca6c408..5a324c897 100644 --- a/rest_framework/management/commands/generateschema.py +++ b/rest_framework/management/commands/generateschema.py @@ -3,7 +3,6 @@ from django.core.management.base import BaseCommand from rest_framework import renderers from rest_framework.schemas import coreapi from rest_framework.schemas.openapi import SchemaGenerator -from rest_framework.settings import api_settings OPENAPI_MODE = 'openapi' COREAPI_MODE = 'coreapi' @@ -13,10 +12,7 @@ class Command(BaseCommand): help = "Generates configured API schema for project." def get_mode(self): - default_schema_class = api_settings.DEFAULT_SCHEMA_CLASS - if issubclass(default_schema_class, coreapi.AutoSchema): - return COREAPI_MODE - return OPENAPI_MODE + return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE def add_arguments(self, parser): parser.add_argument('--title', dest="title", default='', type=str) diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index b1f37987a..8fdb2d86a 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -22,24 +22,32 @@ Other access should target the submodules directly """ from rest_framework.settings import api_settings +from . import coreapi, openapi from .inspectors import DefaultSchema # noqa from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa def get_schema_view( title=None, url=None, description=None, urlconf=None, renderer_classes=None, - public=False, patterns=None, generator_class=SchemaGenerator, + public=False, patterns=None, generator_class=None, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): """ Return a schema view. """ - # Avoid import cycle on APIView - from .views import SchemaView + if generator_class is None: + if coreapi.is_enabled(): + generator_class = coreapi.SchemaGenerator + else: + generator_class = openapi.SchemaGenerator + generator = generator_class( title=title, url=url, description=description, urlconf=urlconf, patterns=patterns, ) + + # Avoid import cycle on APIView + from .views import SchemaView return SchemaView.as_view( renderer_classes=renderer_classes, schema_generator=generator, diff --git a/rest_framework/schemas/coreapi.py b/rest_framework/schemas/coreapi.py index 895ea0efd..b178d7c40 100644 --- a/rest_framework/schemas/coreapi.py +++ b/rest_framework/schemas/coreapi.py @@ -609,3 +609,8 @@ class ManualSchema(ViewInspector): fields=self._fields, description=self._description ) + + +def is_enabled(): + """Is CoreAPI Mode enabled?""" + return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema) diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py index 77f9ad27d..e2be59725 100644 --- a/rest_framework/schemas/views.py +++ b/rest_framework/schemas/views.py @@ -5,6 +5,7 @@ See schemas.__init__.py for package overview. """ from rest_framework import exceptions, renderers from rest_framework.response import Response +from rest_framework.schemas import coreapi from rest_framework.settings import api_settings from rest_framework.views import APIView @@ -19,10 +20,16 @@ class SchemaView(APIView): def __init__(self, *args, **kwargs): super(SchemaView, self).__init__(*args, **kwargs) if self.renderer_classes is None: - self.renderer_classes = [ - renderers.CoreAPIOpenAPIRenderer, - renderers.CoreJSONRenderer - ] + if coreapi.is_enabled(): + self.renderer_classes = [ + renderers.CoreAPIOpenAPIRenderer, + renderers.CoreJSONRenderer + ] + else: + self.renderer_classes = [ + renderers.OpenAPIRenderer, + renderers.JSONOpenAPIRenderer, + ] if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: self.renderer_classes += [renderers.BrowsableAPIRenderer] diff --git a/tests/schemas/test_coreapi.py b/tests/schemas/test_coreapi.py index db0d5e4c8..028a1630a 100644 --- a/tests/schemas/test_coreapi.py +++ b/tests/schemas/test_coreapi.py @@ -134,11 +134,12 @@ class ExampleViewSet(ModelViewSet): pass -if coreapi: - schema_view = get_schema_view(title='Example API') -else: - def schema_view(request): - pass +with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): + if coreapi: + schema_view = get_schema_view(title='Example API') + else: + def schema_view(request): + pass router = DefaultRouter() router.register('example', ExampleViewSet, basename='example') diff --git a/tests/schemas/test_get_schema_view.py b/tests/schemas/test_get_schema_view.py new file mode 100644 index 000000000..873480218 --- /dev/null +++ b/tests/schemas/test_get_schema_view.py @@ -0,0 +1,18 @@ +from django.test import TestCase, override_settings + +from rest_framework import renderers +from rest_framework.schemas import coreapi, get_schema_view, openapi + + +class GetSchemaViewTests(TestCase): + """For the get_schema_view() helper.""" + def test_openapi(self): + schema_view = get_schema_view(title="With OpenAPI") + assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator) + assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes + + def test_coreapi(self): + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): + schema_view = get_schema_view(title="With CoreAPI") + assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator) + assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes