From bb12651a523be8aba8b738c5d24964506b4ceb57 Mon Sep 17 00:00:00 2001 From: Thorsten Franzel Date: Mon, 16 Dec 2019 22:22:30 +0100 Subject: [PATCH] added OpenAPI3 schemas factored out as components - minimal subset of features from PR #7089 - adapted tests --- rest_framework/schemas/openapi.py | 174 ++++++++++------ tests/schemas/test_openapi.py | 322 ++++++++++++++++++------------ 2 files changed, 309 insertions(+), 187 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 58788bc23..4bd1b31e4 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -18,7 +18,20 @@ from .inspectors import ViewInspector from .utils import get_pk_description, is_list_view +class ComponentRegistry: + def __init__(self): + self.schemas = {} + + def get_components(self): + return { + 'schemas': self.schemas, + } + + class SchemaGenerator(BaseSchemaGenerator): + def __init__(self, *args, **kwargs): + self.registry = ComponentRegistry() + super().__init__(*args, **kwargs) def get_info(self): # Title and version are required by openapi specification 3.x @@ -32,7 +45,7 @@ class SchemaGenerator(BaseSchemaGenerator): return info - def get_paths(self, request=None): + def parse(self, request=None): result = {} paths, view_endpoints = self._get_paths_and_endpoints(request) @@ -44,7 +57,10 @@ class SchemaGenerator(BaseSchemaGenerator): for path, method, view in view_endpoints: if not self.has_view_permissions(path, method, view): continue - operation = view.schema.get_operation(path, method) + # keep reference to schema as every access yields a fresh object (descriptor protocol) + schema = view.schema + schema.init(self.registry) + operation = schema.get_operation(path, method) # Normalise path for any provided mount url. if path.startswith('/'): path = path[1:] @@ -61,7 +77,7 @@ class SchemaGenerator(BaseSchemaGenerator): """ self._initialise_endpoints() - paths = self.get_paths(None if public else request) + paths = self.parse(None if public else request) if not paths: return None @@ -69,6 +85,7 @@ class SchemaGenerator(BaseSchemaGenerator): 'openapi': '3.0.2', 'info': self.get_info(), 'paths': paths, + 'components': self.registry.get_components(), } return schema @@ -89,6 +106,9 @@ class AutoSchema(ViewInspector): 'delete': 'Destroy', } + def init(self, registry): + self.registry = registry + def get_operation(self, path, method): operation = {} @@ -104,10 +124,18 @@ class AutoSchema(ViewInspector): request_body = self._get_request_body(path, method) if request_body: operation['requestBody'] = request_body - operation['responses'] = self._get_responses(path, method) + operation['responses'] = self._get_response_bodies(path, method) return operation + def get_request_serializer(self, path, method): + """ override this for custom behaviour """ + return self._get_serializer(path, method) + + def get_response_serializer(self, path, method): + """ override this for custom behaviour """ + return self._get_serializer(path, method) + def _get_operation_id(self, path, method): """ Compute an operation ID from the model, serializer or view name. @@ -218,16 +246,16 @@ class AutoSchema(ViewInspector): return paginator.get_schema_operation_parameters(view) - def _map_field(self, field): + def _map_field(self, method, field): # Nested Serializers, `many` or not. if isinstance(field, serializers.ListSerializer): return { 'type': 'array', - 'items': self._map_serializer(field.child) + 'items': self.resolve_serializer(method, field.child) } if isinstance(field, serializers.Serializer): - data = self._map_serializer(field) + data = self.resolve_serializer(method, field) data['type'] = 'object' return data @@ -268,7 +296,7 @@ class AutoSchema(ViewInspector): 'items': {}, } if not isinstance(field.child, _UnvalidatedField): - map_field = self._map_field(field.child) + map_field = self._map_field(method, field.child) items = { "type": map_field.get('type') } @@ -370,7 +398,7 @@ class AutoSchema(ViewInspector): if field.min_value: content['minimum'] = field.min_value - def _map_serializer(self, serializer): + def _map_serializer(self, method, serializer): # Assuming we have a valid serializer instance. # TODO: # - field is Nested or List serializer. @@ -386,7 +414,7 @@ class AutoSchema(ViewInspector): if field.required: required.append(field.field_name) - schema = self._map_field(field) + schema = self._map_field(method, field) if field.read_only: schema['readOnly'] = True if field.write_only: @@ -404,7 +432,7 @@ class AutoSchema(ViewInspector): result = { 'properties': properties } - if required: + if required and method != 'PATCH': result['required'] = required return result @@ -485,70 +513,98 @@ class AutoSchema(ViewInspector): self.request_media_types = self.map_parsers(path, method) - serializer = self._get_serializer(path, method) + serializer = self.get_request_serializer(path, method) - if not isinstance(serializer, serializers.Serializer): + if isinstance(serializer, serializers.Serializer): + schema = self.resolve_serializer(method, serializer) + else: + schema = { + 'type': 'object', + 'additionalProperties': {}, # https://github.com/swagger-api/swagger-codegen/issues/1318 + 'description': 'Unspecified request body', + } + + # serializer has no fields so skip content enumeration + if not schema: return {} - content = self._map_serializer(serializer) - # No required fields for PATCH - if method == 'PATCH': - content.pop('required', None) - # No read_only fields for request. - for name, schema in content['properties'].copy().items(): - if 'readOnly' in schema: - del content['properties'][name] - return { 'content': { - ct: {'schema': content} - for ct in self.request_media_types + mt: {'schema': schema} for mt in self.request_media_types } } - def _get_responses(self, path, method): - # TODO: Handle multiple codes and pagination classes. - if method == 'DELETE': - return { - '204': { - 'description': '' - } - } - - self.response_media_types = self.map_renderers(path, method) - - item_schema = {} - serializer = self._get_serializer(path, method) + def _get_response_bodies(self, path, method): + serializer = self.get_response_serializer(path, method) if isinstance(serializer, serializers.Serializer): - item_schema = self._map_serializer(serializer) - # No write_only fields for response. - for name, schema in item_schema['properties'].copy().items(): - if 'writeOnly' in schema: - del item_schema['properties'][name] - if 'required' in item_schema: - item_schema['required'] = [f for f in item_schema['required'] if f != name] + if method == 'DELETE': + return {'204': {'description': 'No response body'}} + return {'200': self._get_response_for_code(path, method, serializer)} + else: + schema = { + 'type': 'object', + 'description': 'Unspecified response body', + } + return {'200': self._get_response_for_code(path, method, schema)} + + def _get_response_for_code(self, path, method, serializer): + # TODO: Handle multiple codes and pagination classes. + if not serializer: + return {'description': 'No response body'} + elif isinstance(serializer, serializers.Serializer): + schema = self.resolve_serializer(method, serializer) + if not schema: + return {'description': 'No response body'} + elif isinstance(serializer, dict): + # bypass processing and use given schema directly + schema = serializer + else: + raise ValueError('Serializer type unsupported') if is_list_view(path, method, self.view): - response_schema = { + schema = { 'type': 'array', - 'items': item_schema, + 'items': schema, } paginator = self._get_paginator() if paginator: - response_schema = paginator.get_paginated_response_schema(response_schema) - else: - response_schema = item_schema + schema = paginator.get_paginated_response_schema(schema) return { - '200': { - 'content': { - ct: {'schema': response_schema} - for ct in self.response_media_types - }, - # description is a mandatory property, - # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject - # TODO: put something meaningful into it - 'description': "" - } + 'content': { + mt: {'schema': schema} for mt in self.map_renderers(path, method) + }, + # description is a mandatory property, + # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject + # TODO: put something meaningful into it + 'description': "" } + + def _get_serializer_name(self, method, serializer): + name = serializer.__class__.__name__ + + if name.endswith('Serializer'): + name = name[:-10] + if method == 'PATCH' and not serializer.read_only: + name = 'Patched' + name + + return name + + def resolve_serializer(self, method, serializer): + name = self._get_serializer_name(method, serializer) + + if name not in self.registry.schemas: + # add placeholder to prevent recursion loop + self.registry.schemas[name] = None + + mapped = self._map_serializer(method, serializer) + # empty serializer - usually a transactional serializer. + # no need to put it explicitly in the spec + if not mapped['properties']: + del self.registry.schemas[name] + return {} + else: + self.registry.schemas[name] = mapped + + return {'$ref': '#/components/schemas/{}'.format(name)} diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 03eb9de7a..8e7dffc5e 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -1,14 +1,19 @@ import pytest -from django.conf.urls import url +from django.conf.urls import include, url +from django.db import models from django.test import RequestFactory, TestCase, override_settings from django.utils.translation import gettext_lazy as _ -from rest_framework import filters, generics, pagination, routers, serializers +from rest_framework import ( + filters, generics, pagination, routers, serializers, viewsets +) from rest_framework.compat import uritemplate from rest_framework.parsers import JSONParser, MultiPartParser from rest_framework.renderers import JSONRenderer from rest_framework.request import Request -from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator +from rest_framework.schemas.openapi import ( + AutoSchema, ComponentRegistry, SchemaGenerator +) from . import views @@ -57,7 +62,7 @@ class TestFieldMapping(TestCase): ] for field, mapping in cases: with self.subTest(field=field): - assert inspector._map_field(field) == mapping + assert inspector._map_field('GET', field) == mapping def test_lazy_string_field(self): class Serializer(serializers.Serializer): @@ -65,7 +70,7 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - data = inspector._map_serializer(Serializer()) + data = inspector._map_serializer('GET', Serializer()) assert isinstance(data['properties']['text']['description'], str), "description must be str" @@ -83,6 +88,7 @@ class TestOperationIntrospection(TestCase): ) inspector = AutoSchema() inspector.view = view + inspector.init(ComponentRegistry()) operation = inspector.get_operation(path, method) assert operation == { @@ -91,17 +97,20 @@ class TestOperationIntrospection(TestCase): 'parameters': [], 'responses': { '200': { - 'description': '', 'content': { 'application/json': { 'schema': { 'type': 'array', - 'items': {}, - }, + 'items': { + 'type': 'object', + 'description': 'Unspecified response body' + } + } }, }, + 'description': '' }, - }, + } } def test_path_with_id_parameter(self): @@ -114,131 +123,154 @@ class TestOperationIntrospection(TestCase): create_request(path) ) inspector = AutoSchema() + inspector.init(ComponentRegistry()) inspector.view = view operation = inspector.get_operation(path, method) assert operation == { 'operationId': 'RetrieveDocStringExampleDetail', 'description': 'A description of my GET operation.', - 'parameters': [{ - 'description': '', - 'in': 'path', - 'name': 'id', - 'required': True, - 'schema': { - 'type': 'string', - }, - }], + 'parameters': [ + { + 'name': 'id', + 'in': 'path', + 'required': True, + 'description': '', + 'schema': { + 'type': 'string' + } + } + ], 'responses': { '200': { - 'description': '', 'content': { 'application/json': { 'schema': { - }, - }, + 'type': 'object', + 'description': 'Unspecified response body' + } + } }, - }, - }, + 'description': '' + } + } } def test_request_body(self): path = '/' method = 'POST' - class Serializer(serializers.Serializer): + class ExampleSerializer(serializers.Serializer): text = serializers.CharField() read_only = serializers.CharField(read_only=True) - class View(generics.GenericAPIView): - serializer_class = Serializer + class View(generics.CreateAPIView): + serializer_class = ExampleSerializer view = create_view( View, method, create_request(path) ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) + inspector.get_operation(path, method) - request_body = inspector._get_request_body(path, method) - assert request_body['content']['application/json']['schema']['required'] == ['text'] - assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text'] + schema = registry.schemas['Example'] + assert schema['required'] == ['text'] + assert schema['properties']['read_only']['readOnly'] is True def test_empty_required(self): path = '/' method = 'POST' - class Serializer(serializers.Serializer): + class ExampleSerializer(serializers.Serializer): read_only = serializers.CharField(read_only=True) write_only = serializers.CharField(write_only=True, required=False) - class View(generics.GenericAPIView): - serializer_class = Serializer + class View(generics.CreateAPIView): + serializer_class = ExampleSerializer view = create_view( View, method, create_request(path) ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) + inspector.get_operation(path, method) - request_body = inspector._get_request_body(path, method) + schema = registry.schemas['Example'] # there should be no empty 'required' property, see #6834 - assert 'required' not in request_body['content']['application/json']['schema'] - - for response in inspector._get_responses(path, method).values(): - assert 'required' not in response['content']['application/json']['schema'] + assert 'required' not in schema def test_empty_required_with_patch_method(self): path = '/' method = 'PATCH' - class Serializer(serializers.Serializer): + class ExampleSerializer(serializers.Serializer): read_only = serializers.CharField(read_only=True) write_only = serializers.CharField(write_only=True, required=False) - class View(generics.GenericAPIView): - serializer_class = Serializer + class View(generics.UpdateAPIView): + serializer_class = ExampleSerializer view = create_view( View, method, create_request(path) ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) + inspector.get_operation(path, method) - request_body = inspector._get_request_body(path, method) + schema = registry.schemas['PatchedExample'] # there should be no empty 'required' property, see #6834 - assert 'required' not in request_body['content']['application/json']['schema'] - for response in inspector._get_responses(path, method).values(): - assert 'required' not in response['content']['application/json']['schema'] + assert 'required' not in schema + for field_schema in schema['properties']: + assert 'required' not in field_schema def test_response_body_generation(self): path = '/' method = 'POST' - class Serializer(serializers.Serializer): + class ExampleSerializer(serializers.Serializer): text = serializers.CharField() write_only = serializers.CharField(write_only=True) - class View(generics.GenericAPIView): - serializer_class = Serializer + class View(generics.CreateAPIView): + serializer_class = ExampleSerializer view = create_view( View, method, create_request(path) ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) - responses = inspector._get_responses(path, method) - assert responses['200']['content']['application/json']['schema']['required'] == ['text'] - assert list(responses['200']['content']['application/json']['schema']['properties'].keys()) == ['text'] - assert 'description' in responses['200'] + operation = inspector.get_operation(path, method) + + assert operation['responses'] == { + '200': { + 'content': { + 'application/json': { + 'schema': {'$ref': '#/components/schemas/Example'} + } + }, + 'description': '' + } + } + assert sorted(registry.schemas['Example']['required']) == ['text', 'write_only'] + assert sorted(registry.schemas['Example']['properties'].keys()) == ['text', 'write_only'] def test_response_body_nested_serializer(self): path = '/' @@ -247,28 +279,31 @@ class TestOperationIntrospection(TestCase): class NestedSerializer(serializers.Serializer): number = serializers.IntegerField() - class Serializer(serializers.Serializer): + class ExampleSerializer(serializers.Serializer): text = serializers.CharField() nested = NestedSerializer() - class View(generics.GenericAPIView): - serializer_class = Serializer + class View(generics.CreateAPIView): + serializer_class = ExampleSerializer view = create_view( View, method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) + inspector.get_operation(path, method) + example_schema = registry.schemas['Example'] + nested_schema = registry.schemas['Nested'] - responses = inspector._get_responses(path, method) - schema = responses['200']['content']['application/json']['schema'] - assert sorted(schema['required']) == ['nested', 'text'] - assert sorted(list(schema['properties'].keys())) == ['nested', 'text'] - assert schema['properties']['nested']['type'] == 'object' - assert list(schema['properties']['nested']['properties'].keys()) == ['number'] - assert schema['properties']['nested']['required'] == ['number'] + assert sorted(example_schema['required']) == ['nested', 'text'] + assert sorted(example_schema['properties'].keys()) == ['nested', 'text'] + assert example_schema['properties']['nested']['type'] == 'object' + assert sorted(nested_schema['properties'].keys()) == ['number'] + assert nested_schema['required'] == ['number'] def test_list_response_body_generation(self): """Test that an array schema is returned for list views.""" @@ -278,7 +313,7 @@ class TestOperationIntrospection(TestCase): class ItemSerializer(serializers.Serializer): text = serializers.CharField() - class View(generics.GenericAPIView): + class View(generics.ListAPIView): serializer_class = ItemSerializer view = create_view( @@ -286,29 +321,25 @@ class TestOperationIntrospection(TestCase): method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) - responses = inspector._get_responses(path, method) - assert responses == { + operation = inspector.get_operation(path, method) + + assert operation['responses'] == { '200': { - 'description': '', 'content': { 'application/json': { 'schema': { 'type': 'array', - 'items': { - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], - }, - }, - }, + 'items': {'$ref': '#/components/schemas/Item'}, + } + } }, - }, + 'description': '' + } } def test_paginated_list_response_body_generation(self): @@ -326,7 +357,7 @@ class TestOperationIntrospection(TestCase): class ItemSerializer(serializers.Serializer): text = serializers.CharField() - class View(generics.GenericAPIView): + class View(generics.ListAPIView): serializer_class = ItemSerializer pagination_class = Pagination @@ -337,9 +368,10 @@ class TestOperationIntrospection(TestCase): ) inspector = AutoSchema() inspector.view = view + inspector.init(ComponentRegistry()) - responses = inspector._get_responses(path, method) - assert responses == { + operation = inspector.get_operation(path, method) + assert operation['responses'] == { '200': { 'description': '', 'content': { @@ -348,14 +380,7 @@ class TestOperationIntrospection(TestCase): 'type': 'object', 'item': { 'type': 'array', - 'items': { - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], - }, + 'items': {'$ref': '#/components/schemas/Item'}, }, }, }, @@ -378,11 +403,12 @@ class TestOperationIntrospection(TestCase): ) inspector = AutoSchema() inspector.view = view + inspector.init(ComponentRegistry()) - responses = inspector._get_responses(path, method) - assert responses == { + operation = inspector.get_operation(path, method) + assert operation['responses'] == { '204': { - 'description': '', + 'description': 'No response body', }, } @@ -402,19 +428,20 @@ class TestOperationIntrospection(TestCase): ) inspector = AutoSchema() inspector.view = view + inspector.init(ComponentRegistry()) - request_body = inspector._get_request_body(path, method) - - assert len(request_body['content'].keys()) == 2 - assert 'multipart/form-data' in request_body['content'] - assert 'application/json' in request_body['content'] + operation = inspector.get_operation(path, method) + content = operation['requestBody']['content'] + assert len(content.keys()) == 2 + assert 'multipart/form-data' in content + assert 'application/json' in content def test_renderer_mapping(self): """Test that view's renderers are mapped to OA media types""" path = '/{id}/' method = 'GET' - class View(generics.CreateAPIView): + class View(generics.ListCreateAPIView): serializer_class = views.ExampleSerializer renderer_classes = [JSONRenderer] @@ -423,13 +450,15 @@ class TestOperationIntrospection(TestCase): method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) - responses = inspector._get_responses(path, method) + operation = inspector.get_operation(path, method) # TODO this should be changed once the multiple response # schema support is there - success_response = responses['200'] + success_response = operation['responses']['200'] assert len(success_response['content'].keys()) == 1 assert 'application/json' in success_response['content'] @@ -449,13 +478,15 @@ class TestOperationIntrospection(TestCase): method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) - request_body = inspector._get_request_body(path, method) - mp_media = request_body['content']['multipart/form-data'] - attachment = mp_media['schema']['properties']['attachment'] - assert attachment['format'] == 'binary' + operation = inspector.get_operation(path, method) + + assert 'multipart/form-data' in operation['requestBody']['content'] + assert registry.schemas['Item']['properties']['attachment']['format'] == 'binary' def test_retrieve_response_body_generation(self): """ @@ -476,7 +507,7 @@ class TestOperationIntrospection(TestCase): class ItemSerializer(serializers.Serializer): text = serializers.CharField() - class View(generics.GenericAPIView): + class View(generics.RetrieveAPIView): serializer_class = ItemSerializer pagination_class = Pagination @@ -485,26 +516,30 @@ class TestOperationIntrospection(TestCase): method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) - responses = inspector._get_responses(path, method) - assert responses == { + operation = inspector.get_operation(path, method) + + assert operation['responses'] == { '200': { - 'description': '', 'content': { 'application/json': { - 'schema': { - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], - }, - }, + 'schema': {'$ref': '#/components/schemas/Item'} + } + }, + 'description': '' + } + } + assert registry.schemas['Item'] == { + 'properties': { + 'text': { + 'type': 'string', }, }, + 'required': ['text'], } def test_operation_id_generation(self): @@ -518,6 +553,7 @@ class TestOperationIntrospection(TestCase): ) inspector = AutoSchema() inspector.view = view + inspector.init(ComponentRegistry()) operationId = inspector._get_operation_id(path, method) assert operationId == 'listExamples' @@ -532,7 +568,6 @@ class TestOperationIntrospection(TestCase): request = create_request('/') schema = generator.get_schema(request=request) schema_str = str(schema) - print(schema_str) assert schema_str.count("operationId") == 2 assert schema_str.count("newExample") == 1 assert schema_str.count("oldExample") == 1 @@ -545,12 +580,13 @@ class TestOperationIntrospection(TestCase): method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) + inspector.get_operation(path, method) - responses = inspector._get_responses(path, method) - response_schema = responses['200']['content']['application/json']['schema'] - properties = response_schema['items']['properties'] + properties = registry.schemas['Example']['properties'] assert properties['date']['type'] == properties['datetime']['type'] == 'string' assert properties['date']['format'] == 'date' assert properties['datetime']['format'] == 'date-time' @@ -563,12 +599,13 @@ class TestOperationIntrospection(TestCase): method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) + inspector.get_operation(path, method) - responses = inspector._get_responses(path, method) - response_schema = responses['200']['content']['application/json']['schema'] - properties = response_schema['items']['properties'] + properties = registry.schemas['Example']['properties'] assert properties['hstore']['type'] == 'object' def test_serializer_callable_default(self): @@ -595,12 +632,13 @@ class TestOperationIntrospection(TestCase): method, create_request(path), ) + registry = ComponentRegistry() inspector = AutoSchema() inspector.view = view + inspector.init(registry) + inspector.get_operation(path, method) - responses = inspector._get_responses(path, method) - response_schema = responses['200']['content']['application/json']['schema'] - properties = response_schema['items']['properties'] + properties = registry.schemas['ExampleValidated']['properties'] assert properties['integer']['type'] == 'integer' assert properties['integer']['maximum'] == 99 @@ -643,6 +681,34 @@ class TestOperationIntrospection(TestCase): assert properties['ip']['type'] == 'string' assert 'format' not in properties['ip'] + def test_modelviewset(self): + class ExampleModel(models.Model): + text = models.TextField() + + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = ExampleModel + fields = ['id', 'text'] + + class ExampleViewSet(viewsets.ModelViewSet): + serializer_class = ExampleSerializer + queryset = ExampleModel.objects.none() + + router = routers.DefaultRouter() + router.register(r'example', ExampleViewSet) + + generator = SchemaGenerator(patterns=[ + url(r'api/', include(router.urls)) + ]) + generator._initialise_endpoints() + + schema = generator.get_schema(request=None, public=True) + + assert sorted(schema['paths']['/api/example/'].keys()) == ['get', 'post'] + assert sorted(schema['paths']['/api/example/{id}/'].keys()) == ['delete', 'get', 'patch', 'put'] + assert sorted(schema['components']['schemas'].keys()) == ['Example', 'PatchedExample'] + # TODO do more checks + @pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.') @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'}) @@ -659,7 +725,7 @@ class TestGenerator(TestCase): generator = SchemaGenerator(patterns=patterns) generator._initialise_endpoints() - paths = generator.get_paths() + paths = generator.parse() assert '/example/' in paths example_operations = paths['/example/'] @@ -676,7 +742,7 @@ class TestGenerator(TestCase): generator = SchemaGenerator(patterns=patterns) generator._initialise_endpoints() - paths = generator.get_paths() + paths = generator.parse() assert '/v1/example/' in paths assert '/v1/example/{id}/' in paths @@ -689,7 +755,7 @@ class TestGenerator(TestCase): generator = SchemaGenerator(patterns=patterns, url='/api') generator._initialise_endpoints() - paths = generator.get_paths() + paths = generator.parse() assert '/api/example/' in paths assert '/api/example/{id}/' in paths