From 3fd1052ef6a5108521cb6f1bf231d85303d74c6d Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Mon, 13 May 2019 15:21:29 -0400 Subject: [PATCH] add OpenAPI schema initialization --- docs/api-guide/schemas.md | 27 +++++++++++++++++++++++++++ rest_framework/schemas/__init__.py | 17 ++++++++++++----- rest_framework/schemas/generators.py | 2 +- rest_framework/schemas/openapi.py | 15 ++++++++++++++- tests/schemas/test_get_schema_view.py | 5 +++++ tests/schemas/test_openapi.py | 23 +++++++++++++++++++++++ 6 files changed, 82 insertions(+), 7 deletions(-) diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index b09b1606e..72afd94be 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -327,6 +327,33 @@ May be used to pass a canonical URL for the schema. url='https://www.example.org/api/' ) +#### `openapi_schema` + +May be used to pass a static initial OpenAPI schema document, typically +containing top-level OpenAPI fields. The schema document will +be added to by the AutoSchema generator. + + schema_view = get_schema_view( + openapi_schema = { + 'info': { + 'title': 'my title', + 'version': '1.0', + 'contact': { + 'name': 'API Support', + 'url': 'http://www.example.com/support', + 'email': 'support@example.com' + }, + 'license': { + 'name': 'Apache 2.0', + 'url': 'https://www.apache.org/licenses/LICENSE-2.0.html' + }. + 'servers': [ + {'url': 'https://api.example.com'} + ] + } + ) + + #### `urlconf` A string representing the import path to the URL conf that you want diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index 8fdb2d86a..1b307586b 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -28,7 +28,7 @@ from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa def get_schema_view( - title=None, url=None, description=None, urlconf=None, renderer_classes=None, + title=None, url=None, description=None, openapi_schema=None, urlconf=None, renderer_classes=None, public=False, patterns=None, generator_class=None, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): @@ -41,10 +41,17 @@ def get_schema_view( else: generator_class = openapi.SchemaGenerator - generator = generator_class( - title=title, url=url, description=description, - urlconf=urlconf, patterns=patterns, - ) + if isinstance(generator_class, openapi.SchemaGenerator): + generator = generator_class( + title=title, url=url, description=description, + urlconf=urlconf, patterns=patterns, + openapi_schema=openapi_schema, + ) + else: + generator = generator_class( + title=title, url=url, description=description, + urlconf=urlconf, patterns=patterns, + ) # Avoid import cycle on APIView from .views import SchemaView diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index ecb07f935..780229361 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -163,7 +163,7 @@ class BaseSchemaGenerator(object): # Set by 'SCHEMA_COERCE_PATH_PK'. coerce_path_pk = None - def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, **kwargs): if url and not url.endswith('/'): url += '/' diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 44b281be8..a5cdba69b 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -14,6 +14,10 @@ from .utils import get_pk_description, is_list_view class SchemaGenerator(BaseSchemaGenerator): + def __init__(self, **kwargs): + super().__init__(**kwargs) + #: the openapi schema document: + self.openapi_schema = {} def get_info(self): info = { @@ -43,6 +47,9 @@ class SchemaGenerator(BaseSchemaGenerator): subpath = '/' + path[len(prefix):] result.setdefault(subpath, {}) result[subpath][method.lower()] = operation + if hasattr(view.schema, 'openapi_schema'): + # TODO: shallow or deep merge? + self.openapi_schema = {**self.openapi_schema, **view.schema.openapi_schema} return result @@ -61,13 +68,19 @@ class SchemaGenerator(BaseSchemaGenerator): 'info': self.get_info(), 'paths': paths, } + # TODO: shallow or deep merge? + self.openapi_schema = {**schema, **self.openapi_schema} - return schema + return self.openapi_schema # View Inspectors class AutoSchema(ViewInspector): + def __init__(self, openapi_schema={}): + super().__init__() + # TODO: call this manual_fields ala coreapi? + self.openapi_schema = openapi_schema content_types = ['application/json'] method_mapping = { diff --git a/tests/schemas/test_get_schema_view.py b/tests/schemas/test_get_schema_view.py index f582c6495..84b1eaf7c 100644 --- a/tests/schemas/test_get_schema_view.py +++ b/tests/schemas/test_get_schema_view.py @@ -12,6 +12,11 @@ class GetSchemaViewTests(TestCase): assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator) assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes + def test_openapi_initialized(self): + schema_view = get_schema_view(openapi_schema={'info': {'title': 'With OpenAPI'}}) + assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator) + assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes + @pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed') def test_coreapi(self): with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 2ddf54f01..dc563e98f 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -226,6 +226,29 @@ class TestGenerator(TestCase): assert 'openapi' in schema assert 'paths' in schema + assert 'info' in schema + assert 'title' in schema['info'] + assert 'version' in schema['info'] + assert schema['info']['title'] is None + assert schema['info']['version'] == 'TODO' + + def test_schema_initializer(self): + """Construction of top-level dictionary with an initializer.""" + class MyListView(views.ExampleListView): + schema = AutoSchema(openapi_schema={'info': {'title': 'mytitle', 'version': 'myversion'}}) + + patterns = [ + url(r'^example/?$', MyListView.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + assert 'info' in schema + assert 'title' in schema['info'] + assert 'version' in schema['info'] + assert schema['info']['title'] == 'mytitle' and schema['info']['version'] == 'myversion' def test_serializer_datefield(self): patterns = [