diff --git a/rest_framework/management/commands/generate_schema.py b/rest_framework/management/commands/generate_schema.py index 54a062176..9ac17ec9c 100644 --- a/rest_framework/management/commands/generate_schema.py +++ b/rest_framework/management/commands/generate_schema.py @@ -1,101 +1,8 @@ -import urllib.parse as urlparse - from django.core.management.base import BaseCommand -from rest_framework.compat import coreapi, coreschema -from rest_framework.renderers import CoreJSONRenderer, JSONRenderer -from rest_framework.schemas.generators import OpenAPISchemaGenerator +from rest_framework.compat import coreapi +from rest_framework.renderers import CoreJSONRenderer, OpenAPIRenderer from rest_framework.settings import api_settings -from rest_framework.utils import json - - -class OpenAPICodec: - CLASS_TO_TYPENAME = { - coreschema.Object: 'object', - coreschema.Array: 'array', - coreschema.Number: 'number', - coreschema.Integer: 'integer', - coreschema.String: 'string', - coreschema.Boolean: 'boolean', - } - - def get_schema(self, instance): - schema = {} - if instance.__class__ in self.CLASS_TO_TYPENAME: - schema['type'] = self.CLASS_TO_TYPENAME[instance.__class__] - schema['title'] = instance.title, - schema['description'] = instance.description - if hasattr(instance, 'enum'): - schema['enum'] = instance.enum - return schema - - def get_parameters(self, link): - parameters = [] - for field in link.fields: - if field.location not in ['path', 'query']: - continue - parameter = { - 'name': field.name, - 'in': field.location, - } - if field.required: - parameter['required'] = True - if field.description: - parameter['description'] = field.description - if field.schema: - parameter['schema'] = self.get_schema(field.schema) - parameters.append(parameter) - return parameters - - def get_operation(self, link, name, tag): - operation_id = "%s_%s" % (tag, name) if tag else name - parameters = self.get_parameters(link) - - operation = { - 'operationId': operation_id, - } - if link.title: - operation['summary'] = link.title - if link.description: - operation['description'] = link.description - if parameters: - operation['parameters'] = parameters - if tag: - operation['tags'] = [tag] - return operation - - def get_paths(self, document): - paths = {} - - tag = None - for name, link in document.links.items(): - path = urlparse.urlparse(link.url).path - method = link.action.lower() - paths.setdefault(path, {}) - paths[path][method] = self.get_operation(link, name, tag=tag) - - for tag, section in document.data.items(): - for name, link in section.links.items(): - path = urlparse.urlparse(link.url).path - method = link.action.lower() - paths.setdefault(path, {}) - paths[path][method] = self.get_operation(link, name, tag=tag) - - return paths - - def encode(self, document): - return json.dumps({ - 'openapi': '3.0.0', - 'info': { - 'version': '', - 'title': document.title, - 'description': document.description - }, - 'servers': [{ - 'url': document.url - }], - 'paths': self.get_paths(document) - }, indent=4) class Command(BaseCommand): @@ -122,20 +29,18 @@ class Command(BaseCommand): def handle(self, *args, **options): assert coreapi is not None, 'coreapi must be installed.' - generator_class = self._get_generator_class() + generator_class = api_settings.DEFAULT_SCHEMA_GENERATOR_CLASS() generator = generator_class() schema = generator.get_schema(request=None, public=True) - codec = OpenAPICodec() - output = codec.encode(schema) + + renderer = self.get_renderer('openapi') + output = renderer.render(schema) self.stdout.write(output) - def _get_generator_class(self): - return api_settings.DEFAULT_SCHEMA_GENERATOR_CLASS - - def _get_renderer(self, generator): - if isinstance(generator, OpenAPISchemaGenerator): - return JSONRenderer() - else: - return CoreJSONRenderer() + def get_renderer(self, format): + return { + 'corejson': CoreJSONRenderer(), + 'openapi': OpenAPIRenderer() + } diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a9645cc89..1b850db92 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -9,6 +9,7 @@ REST framework also provides an HTML renderer that renders the browsable API. from __future__ import unicode_literals import base64 +import urllib.parse as urlparse from collections import OrderedDict from django import forms @@ -24,7 +25,7 @@ from django.utils.html import mark_safe from rest_framework import VERSION, exceptions, serializers, status from rest_framework.compat import ( - INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, + INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema, pygments_css ) from rest_framework.exceptions import ParseError @@ -932,3 +933,95 @@ class CoreJSONRenderer(BaseRenderer): indent = bool(renderer_context.get('indent', 0)) codec = coreapi.codecs.CoreJSONCodec() return codec.dump(data, indent=indent) + + +class OpenAPIRenderer: + CLASS_TO_TYPENAME = { + coreschema.Object: 'object', + coreschema.Array: 'array', + coreschema.Number: 'number', + coreschema.Integer: 'integer', + coreschema.String: 'string', + coreschema.Boolean: 'boolean', + } + + def __init__(self): + assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.' + + def get_schema(self, instance): + schema = {} + if instance.__class__ in self.CLASS_TO_TYPENAME: + schema['type'] = self.CLASS_TO_TYPENAME[instance.__class__] + schema['title'] = instance.title, + schema['description'] = instance.description + if hasattr(instance, 'enum'): + schema['enum'] = instance.enum + return schema + + def get_parameters(self, link): + parameters = [] + for field in link.fields: + if field.location not in ['path', 'query']: + continue + parameter = { + 'name': field.name, + 'in': field.location, + } + if field.required: + parameter['required'] = True + if field.description: + parameter['description'] = field.description + if field.schema: + parameter['schema'] = self.get_schema(field.schema) + parameters.append(parameter) + return parameters + + def get_operation(self, link, name, tag): + operation_id = "%s_%s" % (tag, name) if tag else name + parameters = self.get_parameters(link) + + operation = { + 'operationId': operation_id, + } + if link.title: + operation['summary'] = link.title + if link.description: + operation['description'] = link.description + if parameters: + operation['parameters'] = parameters + if tag: + operation['tags'] = [tag] + return operation + + def get_paths(self, document): + paths = {} + + tag = None + for name, link in document.links.items(): + path = urlparse.urlparse(link.url).path + method = link.action.lower() + paths.setdefault(path, {}) + paths[path][method] = self.get_operation(link, name, tag=tag) + + for tag, section in document.data.items(): + for name, link in section.links.items(): + path = urlparse.urlparse(link.url).path + method = link.action.lower() + paths.setdefault(path, {}) + paths[path][method] = self.get_operation(link, name, tag=tag) + + return paths + + def render(self, data, media_type=None, renderer_context=None): + return json.dumps({ + 'openapi': '3.0.0', + 'info': { + 'version': '', + 'title': data.title, + 'description': data.description + }, + 'servers': [{ + 'url': data.url + }], + 'paths': self.get_paths(data) + }, indent=4)