From d99550e54c3ebcafe5c9a67e1736d4062a13a867 Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Thu, 9 Aug 2018 09:19:23 +0200 Subject: [PATCH] Add Open API get_schema(). --- rest_framework/schemas/generators.py | 26 +++++++++++++++++++++++--- tests/schemas/test_openapi.py | 19 +++++++++++++++---- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 92cec8aa1..18693bed0 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -283,9 +283,7 @@ class SchemaGenerator(object): """ Generate a `coreapi.Document` representing the API schema. """ - if self.endpoints is None: - inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) - self.endpoints = inspector.get_api_endpoints() + self._initialise_endpoints() links = self.get_links(None if public else request) if not links: @@ -301,6 +299,11 @@ class SchemaGenerator(object): url=url, content=links ) + def _initialise_endpoints(self): + if self.endpoints is None: + inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) + self.endpoints = inspector.get_api_endpoints() + def get_links(self, request=None): """ Return a dictionary containing all the links that should be @@ -491,3 +494,20 @@ class OpenAPISchemaGenerator(SchemaGenerator): result[subpath][method.lower()] = operation return result + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + self._initialise_endpoints() + + paths = self.get_paths(None if public else request) + if not paths: + return None + + schema = { + 'basePath': self.url, + 'paths': paths, + } + + return schema diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index bdd5dc9ba..3b4882f31 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -47,14 +47,12 @@ class TestGenerator(TestCase): assert isinstance(views.ExampleListView.schema, OpenAPIAutoSchema) def test_paths_construction(self): + """Construction of the `paths` key.""" patterns = [ url(r'^example/?$', views.ExampleListView.as_view()), ] generator = OpenAPISchemaGenerator(patterns=patterns) - - # This happens in get_schema() - inspector = generator.endpoint_inspector_cls(generator.patterns, generator.urlconf) - generator.endpoints = inspector.get_api_endpoints() + generator._initialise_endpoints() paths = generator.get_paths() @@ -63,3 +61,16 @@ class TestGenerator(TestCase): assert len(example_operations) == 2 assert 'get' in example_operations assert 'post' in example_operations + + def test_schema_construction(self): + """Construction of the top level dictionary.""" + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = OpenAPISchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + assert 'basePath' in schema + assert 'paths' in schema