diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index 5c9eaa12c..a8abd2a63 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -94,6 +94,12 @@ A content negotiation class, that determines how a renderer is selected for the Default: `'rest_framework.negotiation.DefaultContentNegotiation'` +#### DEFAULT_SCHEMA_CLASS + +A view inspector class that will be used for schema generation. + +Default: `'rest_framework.schemas.AutoSchema'` + --- ## Generic view settings diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index 1af0b9fc5..ba0ec6536 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -23,7 +23,7 @@ Other access should target the submodules directly from rest_framework.settings import api_settings from .generators import SchemaGenerator -from .inspectors import AutoSchema, ManualSchema # noqa +from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa def get_schema_view( diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 008d7c091..b2a5320bd 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """ inspectors.py # Per-endpoint view introspection @@ -456,3 +457,13 @@ class ManualSchema(ViewInspector): ) return self._link + + +class DefaultSchema(object): + """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" + def __get__(self, instance, owner): + inspector_class = api_settings.DEFAULT_SCHEMA_CLASS + assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" + inspector = inspector_class() + inspector.view = instance + return inspector diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 478a5229f..db92b7a7b 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -55,6 +55,9 @@ DEFAULTS = { 'DEFAULT_PAGINATION_CLASS': None, 'DEFAULT_FILTER_BACKENDS': (), + # Schema + 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', + # Throttling 'DEFAULT_THROTTLE_RATES': { 'user': None, @@ -140,6 +143,7 @@ IMPORT_STRINGS = ( 'DEFAULT_VERSIONING_CLASS', 'DEFAULT_PAGINATION_CLASS', 'DEFAULT_FILTER_BACKENDS', + 'DEFAULT_SCHEMA_CLASS', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', diff --git a/rest_framework/views.py b/rest_framework/views.py index 3140bb9a3..f9ee7fb53 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -18,7 +18,7 @@ from django.views.generic import View from rest_framework import exceptions, status from rest_framework.request import Request from rest_framework.response import Response -from rest_framework.schemas import AutoSchema +from rest_framework.schemas import DefaultSchema from rest_framework.settings import api_settings from rest_framework.utils import formatting @@ -117,7 +117,7 @@ class APIView(View): # Allow dependency injection of other settings to make testing easier. settings = api_settings - schema = AutoSchema() + schema = DefaultSchema() @classmethod def as_view(cls, **initkwargs): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index ba561a959..fa91bac03 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -516,6 +516,11 @@ class Test4605Regression(TestCase): assert prefix == '/' +class CustomViewInspector(AutoSchema): + """A dummy AutoSchema subclass""" + pass + + class TestAutoSchema(TestCase): def test_apiview_schema_descriptor(self): @@ -523,6 +528,18 @@ class TestAutoSchema(TestCase): assert hasattr(view, 'schema') assert isinstance(view.schema, AutoSchema) + def test_set_custom_inspector_class_on_view(self): + class CustomView(APIView): + schema = CustomViewInspector() + + view = CustomView() + assert isinstance(view.schema, CustomViewInspector) + + def test_set_custom_inspector_class_via_settings(self): + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): + view = APIView() + assert isinstance(view.schema, CustomViewInspector) + def test_get_link_requires_instance(self): descriptor = APIView.schema # Accessed from class with pytest.raises(AssertionError):