diff --git a/rest_framework/management/commands/generate_schema.py b/rest_framework/management/commands/generate_schema.py index bc4af6a94..ed123a20d 100644 --- a/rest_framework/management/commands/generate_schema.py +++ b/rest_framework/management/commands/generate_schema.py @@ -1,7 +1,8 @@ from django.core.management.base import BaseCommand -from rest_framework.renderers import CoreJSONRenderer -from rest_framework.schemas import SchemaGenerator +from rest_framework.renderers import CoreJSONRenderer, JSONRenderer +from rest_framework.schemas.generators import OpenAPISchemaGenerator +from rest_framework.settings import api_settings class Command(BaseCommand): @@ -26,11 +27,21 @@ class Command(BaseCommand): pass def handle(self, *args, **options): + generator_class = self._get_generator_class() + generator = generator_class() - renderer = CoreJSONRenderer() - generator = SchemaGenerator() schema = generator.get_schema(request=None, public=True) + renderer = self._get_renderer(generator) rendered_schema = renderer.render(schema, renderer_context={}).decode('utf8') self.stdout.write(rendered_schema) + + def _get_generator_class(self): + return api_settings.DEFAULT_SCHEMA_GENERATOR_CLASS + + def _get_renderer(self, generator): + if isinstance(generator, OpenAPISchemaGenerator): + return JSONRenderer() + else: + return CoreJSONRenderer() diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 6c581f8e8..61f3d3b8a 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -57,6 +57,7 @@ DEFAULTS = { # Schema 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', + 'DEFAULT_SCHEMA_GENERATOR_CLASS': 'rest_framework.schemas.generators.SchemaGenerator', # Throttling 'DEFAULT_THROTTLE_RATES': { @@ -144,6 +145,7 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_CLASS', 'DEFAULT_FILTER_BACKENDS', 'DEFAULT_SCHEMA_CLASS', + 'DEFAULT_SCHEMA_GENERATOR_CLASS', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER',