Implement OpenAPI Components (#7124)

This commit is contained in:
Martin Desrumaux 2020-03-02 19:35:27 +01:00 committed by GitHub
parent 797518af6d
commit 8aa8be7653
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 347 additions and 86 deletions

View File

@ -316,6 +316,65 @@ class CustomSchema(AutoSchema):
def get_operation_id(self, path, method):
pass
class MyView(APIView):
schema = AutoSchema(component_name="Ulysses")
```
### Components
Since DRF 3.12, Schema uses the [OpenAPI Components](openapi-components). This method defines components in the schema and [references them](openapi-reference) inside request and response objects. By default, the component's name is deduced from the Serializer's name.
Using OpenAPI's components provides the following advantages:
* The schema is more readable and lightweight.
* If you use the schema to generate an SDK (using [openapi-generator](openapi-generator) or [swagger-codegen](swagger-codegen)). The generator can name your SDK's models.
### Handling component's schema errors
You may get the following error while generating the schema:
```
"Serializer" is an invalid class name for schema generation.
Serializer's class name should be unique and explicit. e.g. "ItemSerializer".
```
This error occurs when the Serializer name is "Serializer". You should choose a component's name unique across your schema and different than "Serializer".
You may also get the following warning:
```
Schema component "ComponentName" has been overriden with a different value.
```
This warning occurs when different components have the same name in one schema. Your component name should be unique across your project. This is likely an error that may lead to an invalid schema.
You have two ways to solve the previous issues:
* You can rename your serializer with a unique name and another name than "Serializer".
* You can set the `component_name` kwarg parameter of the AutoSchema constructor (see below).
* You can override the `get_component_name` method of the AutoSchema class (see below).
#### Set a custom component's name for your view
To override the component's name in your view, you can use the `component_name` parameter of the AutoSchema constructor:
```python
from rest_framework.schemas.openapi import AutoSchema
class MyView(APIView):
schema = AutoSchema(component_name="Ulysses")
```
#### Override the default implementation
If you want to have more control and customization about how the schema's components are generated, you can override the `get_component_name` and `get_components` method from the AutoSchema class.
```python
from rest_framework.schemas.openapi import AutoSchema
class CustomSchema(AutoSchema):
def get_components(self, path, method):
# Implement your custom implementation
def get_component_name(self, serializer):
# Implement your custom implementation
class CustomView(APIView):
"""APIView subclass with custom schema introspection."""
schema = CustomSchema()
@ -326,3 +385,7 @@ class CustomView(APIView):
[openapi-operation]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#operationObject
[openapi-tags]: https://swagger.io/specification/#tagObject
[openapi-operationid]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#fixed-fields-17
[openapi-components]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#componentsObject
[openapi-reference]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#referenceObject
[openapi-generator]: https://github.com/OpenAPITools/openapi-generator
[swagger-codegen]: https://github.com/swagger-api/swagger-codegen

View File

@ -1,3 +1,4 @@
import re
import warnings
from collections import OrderedDict
from decimal import Decimal
@ -65,9 +66,9 @@ class SchemaGenerator(BaseSchemaGenerator):
Generate a OpenAPI schema.
"""
self._initialise_endpoints()
components_schemas = {}
# Iterate endpoints generating per method path operations.
# TODO: …and reference components.
paths = {}
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
for path, method, view in view_endpoints:
@ -75,6 +76,16 @@ class SchemaGenerator(BaseSchemaGenerator):
continue
operation = view.schema.get_operation(path, method)
components = view.schema.get_components(path, method)
for k in components.keys():
if k not in components_schemas:
continue
if components_schemas[k] == components[k]:
continue
warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k))
components_schemas.update(components)
# Normalise path for any provided mount url.
if path.startswith('/'):
path = path[1:]
@ -92,6 +103,11 @@ class SchemaGenerator(BaseSchemaGenerator):
'paths': paths,
}
if len(components_schemas) > 0:
schema['components'] = {
'schemas': components_schemas
}
return schema
# View Inspectors
@ -99,14 +115,16 @@ class SchemaGenerator(BaseSchemaGenerator):
class AutoSchema(ViewInspector):
def __init__(self, operation_id_base=None, tags=None):
def __init__(self, tags=None, operation_id_base=None, component_name=None):
"""
:param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name.
:param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name.
"""
if tags and not all(isinstance(tag, str) for tag in tags):
raise ValueError('tags must be a list or tuple of string.')
self._tags = tags
self.operation_id_base = operation_id_base
self.component_name = component_name
super().__init__()
request_media_types = []
@ -140,6 +158,43 @@ class AutoSchema(ViewInspector):
return operation
def get_component_name(self, serializer):
"""
Compute the component's name from the serializer.
Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
"""
if self.component_name is not None:
return self.component_name
# use the serializer's class name as the component name.
component_name = serializer.__class__.__name__
# We remove the "serializer" string from the class name.
pattern = re.compile("serializer", re.IGNORECASE)
component_name = pattern.sub("", component_name)
if component_name == "":
raise Exception(
'"{}" is an invalid class name for schema generation. '
'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
.format(serializer.__class__.__name__)
)
return component_name
def get_components(self, path, method):
"""
Return components with their properties from the serializer.
"""
serializer = self._get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
return {}
component_name = self.get_component_name(serializer)
content = self._map_serializer(serializer)
return {component_name: content}
def get_operation_id_base(self, path, method, action):
"""
Compute the base part for operation ID from the model, serializer or view name.
@ -434,10 +489,6 @@ class AutoSchema(ViewInspector):
def _map_serializer(self, serializer):
# Assuming we have a valid serializer instance.
# TODO:
# - field is Nested or List serializer.
# - Handle read_only/write_only for request/response differences.
# - could do this with readOnly/writeOnly and then filter dict.
required = []
properties = {}
@ -542,6 +593,9 @@ class AutoSchema(ViewInspector):
.format(view.__class__.__name__, method, path))
return None
def _get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
def _get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'):
return {}
@ -551,20 +605,13 @@ class AutoSchema(ViewInspector):
serializer = self._get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
return {}
content = self._map_serializer(serializer)
# No required fields for PATCH
if method == 'PATCH':
content.pop('required', None)
# No read_only fields for request.
for name, schema in content['properties'].copy().items():
if 'readOnly' in schema:
del content['properties'][name]
item_schema = {}
else:
item_schema = self._get_reference(serializer)
return {
'content': {
ct: {'schema': content}
ct: {'schema': item_schema}
for ct in self.request_media_types
}
}
@ -580,17 +627,12 @@ class AutoSchema(ViewInspector):
self.response_media_types = self.map_renderers(path, method)
item_schema = {}
serializer = self._get_serializer(path, method)
if isinstance(serializer, serializers.Serializer):
item_schema = self._map_serializer(serializer)
# No write_only fields for response.
for name, schema in item_schema['properties'].copy().items():
if 'writeOnly' in schema:
del item_schema['properties'][name]
if 'required' in item_schema:
item_schema['required'] = [f for f in item_schema['required'] if f != name]
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self._get_reference(serializer)
if is_list_view(path, method, self.view):
response_schema = {

View File

@ -85,12 +85,12 @@ class TestFieldMapping(TestCase):
assert inspector._map_field(field) == mapping
def test_lazy_string_field(self):
class Serializer(serializers.Serializer):
class ItemSerializer(serializers.Serializer):
text = serializers.CharField(help_text=_('lazy string'))
inspector = AutoSchema()
data = inspector._map_serializer(Serializer())
data = inspector._map_serializer(ItemSerializer())
assert isinstance(data['properties']['text']['description'], str), "description must be str"
def test_boolean_default_field(self):
@ -186,6 +186,33 @@ class TestOperationIntrospection(TestCase):
path = '/'
method = 'POST'
class ItemSerializer(serializers.Serializer):
text = serializers.CharField()
read_only = serializers.CharField(read_only=True)
class View(generics.GenericAPIView):
serializer_class = ItemSerializer
view = create_view(
View,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
print(request_body)
assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
components = inspector.get_components(path, method)
assert components['Item']['required'] == ['text']
assert sorted(list(components['Item']['properties'].keys())) == ['read_only', 'text']
def test_invalid_serializer_class_name(self):
path = '/'
method = 'POST'
class Serializer(serializers.Serializer):
text = serializers.CharField()
read_only = serializers.CharField(read_only=True)
@ -201,20 +228,22 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
assert request_body['content']['application/json']['schema']['required'] == ['text']
assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text']
serializer = inspector._get_serializer(path, method)
with pytest.raises(Exception) as exc:
inspector.get_component_name(serializer)
assert "is an invalid class name for schema generation" in str(exc.value)
def test_empty_required(self):
path = '/'
method = 'POST'
class Serializer(serializers.Serializer):
class ItemSerializer(serializers.Serializer):
read_only = serializers.CharField(read_only=True)
write_only = serializers.CharField(write_only=True, required=False)
class View(generics.GenericAPIView):
serializer_class = Serializer
serializer_class = ItemSerializer
view = create_view(
View,
@ -224,23 +253,24 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
components = inspector.get_components(path, method)
component = components['Item']
# there should be no empty 'required' property, see #6834
assert 'required' not in request_body['content']['application/json']['schema']
assert 'required' not in component
for response in inspector._get_responses(path, method).values():
assert 'required' not in response['content']['application/json']['schema']
assert 'required' not in component
def test_empty_required_with_patch_method(self):
path = '/'
method = 'PATCH'
class Serializer(serializers.Serializer):
class ItemSerializer(serializers.Serializer):
read_only = serializers.CharField(read_only=True)
write_only = serializers.CharField(write_only=True, required=False)
class View(generics.GenericAPIView):
serializer_class = Serializer
serializer_class = ItemSerializer
view = create_view(
View,
@ -250,22 +280,23 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
components = inspector.get_components(path, method)
component = components['Item']
# there should be no empty 'required' property, see #6834
assert 'required' not in request_body['content']['application/json']['schema']
assert 'required' not in component
for response in inspector._get_responses(path, method).values():
assert 'required' not in response['content']['application/json']['schema']
assert 'required' not in component
def test_response_body_generation(self):
path = '/'
method = 'POST'
class Serializer(serializers.Serializer):
class ItemSerializer(serializers.Serializer):
text = serializers.CharField()
write_only = serializers.CharField(write_only=True)
class View(generics.GenericAPIView):
serializer_class = Serializer
serializer_class = ItemSerializer
view = create_view(
View,
@ -276,9 +307,11 @@ class TestOperationIntrospection(TestCase):
inspector.view = view
responses = inspector._get_responses(path, method)
assert '201' in responses
assert responses['201']['content']['application/json']['schema']['required'] == ['text']
assert list(responses['201']['content']['application/json']['schema']['properties'].keys()) == ['text']
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
components = inspector.get_components(path, method)
assert sorted(components['Item']['required']) == ['text', 'write_only']
assert sorted(list(components['Item']['properties'].keys())) == ['text', 'write_only']
assert 'description' in responses['201']
def test_response_body_nested_serializer(self):
@ -288,12 +321,12 @@ class TestOperationIntrospection(TestCase):
class NestedSerializer(serializers.Serializer):
number = serializers.IntegerField()
class Serializer(serializers.Serializer):
class ItemSerializer(serializers.Serializer):
text = serializers.CharField()
nested = NestedSerializer()
class View(generics.GenericAPIView):
serializer_class = Serializer
serializer_class = ItemSerializer
view = create_view(
View,
@ -304,7 +337,11 @@ class TestOperationIntrospection(TestCase):
inspector.view = view
responses = inspector._get_responses(path, method)
schema = responses['201']['content']['application/json']['schema']
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
components = inspector.get_components(path, method)
assert components['Item']
schema = components['Item']
assert sorted(schema['required']) == ['nested', 'text']
assert sorted(list(schema['properties'].keys())) == ['nested', 'text']
assert schema['properties']['nested']['type'] == 'object'
@ -339,19 +376,25 @@ class TestOperationIntrospection(TestCase):
'schema': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'text': {
'type': 'string',
},
},
'required': ['text'],
'$ref': '#/components/schemas/Item'
},
},
},
},
},
}
components = inspector.get_components(path, method)
assert components == {
'Item': {
'type': 'object',
'properties': {
'text': {
'type': 'string',
},
},
'required': ['text'],
}
}
def test_paginated_list_response_body_generation(self):
"""Test that pagination properties are added for a paginated list view."""
@ -391,13 +434,7 @@ class TestOperationIntrospection(TestCase):
'item': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'text': {
'type': 'string',
},
},
'required': ['text'],
'$ref': '#/components/schemas/Item'
},
},
},
@ -405,6 +442,18 @@ class TestOperationIntrospection(TestCase):
},
},
}
components = inspector.get_components(path, method)
assert components == {
'Item': {
'type': 'object',
'properties': {
'text': {
'type': 'string',
},
},
'required': ['text'],
}
}
def test_delete_response_body_generation(self):
"""Test that a view's delete method generates a proper response body schema."""
@ -508,10 +557,10 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
mp_media = request_body['content']['multipart/form-data']
attachment = mp_media['schema']['properties']['attachment']
assert attachment['format'] == 'binary'
components = inspector.get_components(path, method)
component = components['Item']
properties = component['properties']
assert properties['attachment']['format'] == 'binary'
def test_retrieve_response_body_generation(self):
"""
@ -551,19 +600,26 @@ class TestOperationIntrospection(TestCase):
'content': {
'application/json': {
'schema': {
'type': 'object',
'properties': {
'text': {
'type': 'string',
},
},
'required': ['text'],
'$ref': '#/components/schemas/Item'
},
},
},
},
}
components = inspector.get_components(path, method)
assert components == {
'Item': {
'type': 'object',
'properties': {
'text': {
'type': 'string',
},
},
'required': ['text'],
}
}
def test_operation_id_generation(self):
path = '/'
method = 'GET'
@ -689,9 +745,9 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
response_schema = responses['200']['content']['application/json']['schema']
properties = response_schema['items']['properties']
components = inspector.get_components(path, method)
component = components['Example']
properties = component['properties']
assert properties['date']['type'] == properties['datetime']['type'] == 'string'
assert properties['date']['format'] == 'date'
assert properties['datetime']['format'] == 'date-time'
@ -707,9 +763,9 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
response_schema = responses['200']['content']['application/json']['schema']
properties = response_schema['items']['properties']
components = inspector.get_components(path, method)
component = components['Example']
properties = component['properties']
assert properties['hstore']['type'] == 'object'
def test_serializer_callable_default(self):
@ -723,9 +779,9 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
response_schema = responses['200']['content']['application/json']['schema']
properties = response_schema['items']['properties']
components = inspector.get_components(path, method)
component = components['Example']
properties = component['properties']
assert 'default' not in properties['uuid_field']
def test_serializer_validators(self):
@ -739,9 +795,9 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
response_schema = responses['200']['content']['application/json']['schema']
properties = response_schema['items']['properties']
components = inspector.get_components(path, method)
component = components['ExampleValidated']
properties = component['properties']
assert properties['integer']['type'] == 'integer'
assert properties['integer']['maximum'] == 99
@ -819,6 +875,7 @@ class TestOperationIntrospection(TestCase):
def test_auto_generated_apiview_tags(self):
class RestaurantAPIView(views.ExampleGenericAPIView):
schema = AutoSchema(operation_id_base="restaurant")
pass
class BranchAPIView(views.ExampleGenericAPIView):
@ -932,3 +989,54 @@ class TestGenerator(TestCase):
assert schema['info']['title'] == ''
assert schema['info']['version'] == ''
def test_serializer_model(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
print(schema)
assert 'components' in schema
assert 'schemas' in schema['components']
assert 'ExampleModel' in schema['components']['schemas']
def test_component_name(self):
patterns = [
url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
print(schema)
assert 'components' in schema
assert 'schemas' in schema['components']
assert 'Ulysses' in schema['components']['schemas']
def test_duplicate_component_name(self):
patterns = [
url(r'^duplicate1/?$', views.ExampleAutoSchemaDuplicate1.as_view()),
url(r'^duplicate2/?$', views.ExampleAutoSchemaDuplicate2.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
schema = generator.get_schema(request=request)
assert len(w) == 1
assert issubclass(w[-1].category, UserWarning)
assert 'has been overriden with a different value.' in str(w[-1].message)
assert 'components' in schema
assert 'schemas' in schema['components']
assert 'Duplicate' in schema['components']['schemas']

View File

@ -9,6 +9,7 @@ from django.db import models
from rest_framework import generics, permissions, serializers
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.schemas.openapi import AutoSchema
from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet
@ -167,3 +168,50 @@ class ExampleOperationIdDuplicate2(generics.GenericAPIView):
def get(self, *args, **kwargs):
pass
class ExampleGenericAPIViewModel(generics.GenericAPIView):
serializer_class = ExampleSerializerModel
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)
class ExampleAutoSchemaComponentName(generics.GenericAPIView):
serializer_class = ExampleSerializerModel
schema = AutoSchema(component_name="Ulysses")
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)
class ExampleAutoSchemaDuplicate1(generics.GenericAPIView):
serializer_class = ExampleValidatedSerializer
schema = AutoSchema(component_name="Duplicate")
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)
class ExampleAutoSchemaDuplicate2(generics.GenericAPIView):
serializer_class = ExampleSerializerModel
schema = AutoSchema(component_name="Duplicate")
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)