diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py index 55e27ea8f..02ca6c408 100644 --- a/rest_framework/management/commands/generateschema.py +++ b/rest_framework/management/commands/generateschema.py @@ -1,32 +1,60 @@ from django.core.management.base import BaseCommand -from rest_framework.compat import yaml +from rest_framework import renderers +from rest_framework.schemas import coreapi from rest_framework.schemas.openapi import SchemaGenerator -from rest_framework.utils import json +from rest_framework.settings import api_settings + +OPENAPI_MODE = 'openapi' +COREAPI_MODE = 'coreapi' class Command(BaseCommand): help = "Generates configured API schema for project." + def get_mode(self): + default_schema_class = api_settings.DEFAULT_SCHEMA_CLASS + if issubclass(default_schema_class, coreapi.AutoSchema): + return COREAPI_MODE + return OPENAPI_MODE + def add_arguments(self, parser): parser.add_argument('--title', dest="title", default='', type=str) parser.add_argument('--url', dest="url", default=None, type=str) parser.add_argument('--description', dest="description", default=None, type=str) - parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str) + if self.get_mode() == COREAPI_MODE: + parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) + else: + parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str) def handle(self, *args, **options): - generator = SchemaGenerator( + generator_class = self.get_generator_class() + generator = generator_class( url=options['url'], title=options['title'], description=options['description'] ) - schema = generator.get_schema(request=None, public=True) + renderer = self.get_renderer(options['format']) + output = renderer.render(schema, renderer_context={}) + self.stdout.write(output.decode('utf-8')) - # TODO: Handle via renderer? More options? - if options['format'] == 'openapi': - output = yaml.dump(schema, default_flow_style=False) - else: - output = json.dumps(schema, indent=2) + def get_renderer(self, format): + if self.get_mode() == COREAPI_MODE: + renderer_cls = { + 'corejson': renderers.CoreJSONRenderer, + 'openapi': renderers.CoreAPIOpenAPIRenderer, + 'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer, + }[format] + return renderer_cls() - self.stdout.write(output) + renderer_cls = { + 'openapi': renderers.OpenAPIRenderer, + 'openapi-json': renderers.JSONOpenAPIRenderer, + }[format] + return renderer_cls() + + def get_generator_class(self): + if self.get_mode() == COREAPI_MODE: + return coreapi.SchemaGenerator + return SchemaGenerator diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index f043e6327..4fa2099e4 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -1024,28 +1024,49 @@ class _BaseOpenAPIRenderer: } -class OpenAPIRenderer(_BaseOpenAPIRenderer): +class CoreAPIOpenAPIRenderer(_BaseOpenAPIRenderer): media_type = 'application/vnd.oai.openapi' charset = None format = 'openapi' def __init__(self): - assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.' - assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.' + assert coreapi, 'Using CoreAPIOpenAPIRenderer, but `coreapi` is not installed.' + assert yaml, 'Using CoreAPIOpenAPIRenderer, but `pyyaml` is not installed.' def render(self, data, media_type=None, renderer_context=None): structure = self.get_structure(data) return yaml.dump(structure, default_flow_style=False).encode('utf-8') -class JSONOpenAPIRenderer(_BaseOpenAPIRenderer): +class CoreAPIJSONOpenAPIRenderer(_BaseOpenAPIRenderer): media_type = 'application/vnd.oai.openapi+json' charset = None format = 'openapi-json' def __init__(self): - assert coreapi, 'Using JSONOpenAPIRenderer, but `coreapi` is not installed.' + assert coreapi, 'Using CoreAPIJSONOpenAPIRenderer, but `coreapi` is not installed.' def render(self, data, media_type=None, renderer_context=None): structure = self.get_structure(data) return json.dumps(structure, indent=4).encode('utf-8') + + +class OpenAPIRenderer(BaseRenderer): + media_type = 'application/vnd.oai.openapi' + charset = None + format = 'openapi' + + def __init__(self): + assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.' + + def render(self, data, media_type=None, renderer_context=None): + return yaml.dump(data, default_flow_style=False).encode('utf-8') + + +class JSONOpenAPIRenderer(BaseRenderer): + media_type = 'application/vnd.oai.openapi+json' + charset = None + format = 'openapi-json' + + def render(self, data, media_type=None, renderer_context=None): + return json.dumps(data, indent=2).encode('utf-8') diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py index f5e327a94..77f9ad27d 100644 --- a/rest_framework/schemas/views.py +++ b/rest_framework/schemas/views.py @@ -20,7 +20,7 @@ class SchemaView(APIView): super(SchemaView, self).__init__(*args, **kwargs) if self.renderer_classes is None: self.renderer_classes = [ - renderers.OpenAPIRenderer, + renderers.CoreAPIOpenAPIRenderer, renderers.CoreJSONRenderer ] if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: diff --git a/tests/schemas/test_managementcommand.py b/tests/schemas/test_managementcommand.py index a4d038e55..9757f382d 100644 --- a/tests/schemas/test_managementcommand.py +++ b/tests/schemas/test_managementcommand.py @@ -8,7 +8,8 @@ from django.test.utils import override_settings from django.utils import six from rest_framework.compat import uritemplate, yaml -from rest_framework.utils import json +from rest_framework.management.commands import generateschema +from rest_framework.utils import formatting, json from rest_framework.views import APIView @@ -30,6 +31,13 @@ class GenerateSchemaTests(TestCase): def setUp(self): self.out = six.StringIO() + def test_command_detects_schema_generation_mode(self): + """Switching between CoreAPI & OpenAPI""" + command = generateschema.Command() + assert command.get_mode() == generateschema.OPENAPI_MODE + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): + assert command.get_mode() == generateschema.COREAPI_MODE + @pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.') @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') def test_renders_default_schema_with_custom_title_url_and_description(self): @@ -49,3 +57,64 @@ class GenerateSchemaTests(TestCase): # Check valid JSON was output. out_json = json.loads(self.out.getvalue()) assert out_json['openapi'] == '3.0.2' + + @pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.') + @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') + @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) + def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self): + expected_out = """info: + description: Sample description + title: SampleAPI + version: '' + openapi: 3.0.0 + paths: + /: + get: + operationId: list + servers: + - url: http://api.sample.com/ + """ + call_command('generateschema', + '--title=SampleAPI', + '--url=http://api.sample.com', + '--description=Sample description', + stdout=self.out) + + self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) + + @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) + def test_coreapi_renders_openapi_json_schema(self): + expected_out = { + "openapi": "3.0.0", + "info": { + "version": "", + "title": "", + "description": "" + }, + "servers": [ + { + "url": "" + } + ], + "paths": { + "/": { + "get": { + "operationId": "list" + } + } + } + } + call_command('generateschema', + '--format=openapi-json', + stdout=self.out) + out_json = json.loads(self.out.getvalue()) + + self.assertDictEqual(out_json, expected_out) + + @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) + def test_renders_corejson_schema(self): + expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" + call_command('generateschema', + '--format=corejson', + stdout=self.out) + self.assertIn(expected_out, self.out.getvalue())