refactor/extend/improve OpenApi3 spec generation

This commit is contained in:
Thorsten Franzel 2019-12-11 15:01:29 +01:00
parent b135e0fa0a
commit 9486df8d04
3 changed files with 539 additions and 137 deletions

View File

@ -1,6 +1,10 @@
import inspect
import re
import typing
import warnings
from operator import attrgetter
from urllib.parse import urljoin
from uuid import UUID
from django.core.validators import (
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
@ -8,31 +12,69 @@ from django.core.validators import (
)
from django.db import models
from django.utils.encoding import force_str
from django.utils.module_loading import import_string
from rest_framework import exceptions, renderers, serializers
from rest_framework import exceptions, renderers, serializers, permissions
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings
from rest_framework.schemas.openapi_utils import TYPE_MAPPING, PolymorphicResponse
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
class SchemaGenerator(BaseSchemaGenerator):
AUTHENTICATION_SCHEMES = {
cls.authentication_class: cls
for cls in [import_string(cls) for cls in api_settings.SCHEMA_AUTHENTICATION_CLASSES]
}
def get_info(self):
# Title and version are required by openapi specification 3.x
info = {
'title': self.title or '',
'version': self.version or ''
class ComponentRegistry:
    """Collects the reusable components gathered while a schema is generated."""

    def __init__(self):
        # both registries start out empty and are filled by component name
        self.schemas = dict()
        self.security_schemes = dict()

    def get_components(self):
        """Return both registries under their OpenAPI ``components`` keys."""
        components = {}
        components['securitySchemes'] = self.security_schemes
        components['schemas'] = self.schemas
        return components
if self.description is not None:
info['description'] = self.description
return info
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, *args, **kwargs):
    # one component registry per generator; it is handed to every view
    # schema (via schema.init) so all operations register into the same place
    self.registry = ComponentRegistry()
    super().__init__(*args, **kwargs)
def get_paths(self, request=None):
def create_view(self, callback, method, request=None):
    """
    Create the view through the parent implementation and force the schema
    that was attached to the handler serving ``method``.

    For custom routes (@action) the custom AutoSchema is already injected
    through the view's 'initkwargs'. Plain handlers such as ``retrieve`` do
    not run that initialization logic, so a schema stored in the handler's
    ``kwargs`` (as done by @extend_schema) is set on the view here.
    """
    view = super().create_view(callback, method, request)

    # circumvent import issues by locally importing
    from rest_framework.views import APIView
    from rest_framework.viewsets import GenericViewSet, ViewSet

    if isinstance(view, (GenericViewSet, ViewSet)):
        handler_name = view.action
    elif isinstance(view, APIView):
        handler_name = method.lower()
    else:
        raise RuntimeError('not supported subclass. Must inherit from APIView')
    handler = getattr(view, handler_name)

    try:
        handler_kwargs = handler.kwargs
    except AttributeError:
        # handler was neither decorated with @action nor @extend_schema
        return view

    if 'schema' in handler_kwargs:
        # might already be properly set in case of @action but overwrite for all cases
        view.schema = handler_kwargs['schema']
    return view
def parse(self, request=None):
result = {}
paths, view_endpoints = self._get_paths_and_endpoints(request)
@ -44,7 +86,10 @@ class SchemaGenerator(BaseSchemaGenerator):
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
operation = view.schema.get_operation(path, method)
# keep a reference to the schema, as every access yields a fresh object (descriptor)
schema = view.schema
schema.init(self.registry)
operation = schema.get_operation(path, method)
# Normalise path for any provided mount url.
if path.startswith('/'):
path = path[1:]
@ -61,20 +106,21 @@ class SchemaGenerator(BaseSchemaGenerator):
"""
self._initialise_endpoints()
paths = self.get_paths(None if public else request)
if not paths:
return None
schema = {
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
'servers': [
{'url': self.url or 'http://127.0.0.1:8000'},
],
'info': {
'title': self.title or '',
'version': self.version or '0.0.0', # fallback to prevent invalid schema
'description': self.description or '',
},
'paths': self.parse(None if public else request),
'components': self.registry.get_components(),
}
return schema
# View Inspectors
class AutoSchema(ViewInspector):
@ -82,72 +128,117 @@ class AutoSchema(ViewInspector):
response_media_types = []
method_mapping = {
'get': 'Retrieve',
'post': 'Create',
'put': 'Update',
'patch': 'PartialUpdate',
'delete': 'Destroy',
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
def init(self, registry):
    """
    Store the component registry that ``resolve_serializer`` and
    ``resolve_authentication`` register schemas and security schemes into.
    """
    self.registry = registry
def get_operation(self, path, method):
operation = {}
operation['operationId'] = self._get_operation_id(path, method)
operation['description'] = self.get_description(path, method)
operation['parameters'] = sorted(
[
*self._get_path_parameters(path, method),
*self._get_filter_parameters(path, method),
*self._get_pagination_parameters(path, method),
*self.get_extra_parameters(path, method),
],
key=lambda p: p.get('name')
)
parameters = []
parameters += self._get_path_parameters(path, method)
parameters += self._get_pagination_parameters(path, method)
parameters += self._get_filter_parameters(path, method)
operation['parameters'] = parameters
tags = self.get_tags(path, method)
if tags:
operation['tags'] = tags
request_body = self._get_request_body(path, method)
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
auth = self.get_auth(path, method)
if auth:
operation['security'] = auth
self.response_media_types = self.map_renderers(path, method)
operation['responses'] = self._get_response_bodies(path, method)
return operation
def get_extra_parameters(self, path, method):
    """
    Return additional parameter objects for the operation; none by default.

    Override this for custom behaviour.
    """
    return []
def get_description(self, path, method):
    """
    Build the operation description from the view and handler docstrings.

    Override this for custom behaviour.

    :param path: endpoint path (unused by the default implementation)
    :param method: HTTP method name; its lower-cased form selects the
        handler when the view has no usable ``action``
    :return: view docstring and handler docstring separated by a blank
        line; missing docstrings are left out, so the result may be ``''``
    """
    # ``action`` may be absent or None; both fall back to the method handler
    handler_name = getattr(self.view, 'action', None) or method.lower()
    handler = getattr(self.view, handler_name, None)

    view_doc = inspect.getdoc(self.view) or ''
    action_doc = (inspect.getdoc(handler) or '') if handler is not None else ''
    # join only the non-empty parts so a view without a docstring does not
    # leave a leading blank line in front of the handler docstring
    return '\n\n'.join(part for part in (view_doc, action_doc) if part)
def get_auth(self, path, method):
    """
    Build the security requirement list of the operation.

    Override this for custom behaviour.

    Every authentication class of the view contributes one requirement,
    resolved (and registered) through ``resolve_authentication``. An empty
    requirement ``{}`` is appended when the view's permission classes
    allow unauthenticated access for this method.

    :param path: endpoint path (unused by the default implementation)
    :param method: upper-case HTTP method name, e.g. ``'GET'``
    :return: list of security requirement dicts, possibly empty
    """
    auth = [
        self.resolve_authentication(method, ac)
        for ac in getattr(self.view, 'authentication_classes', ())
    ]

    perms = getattr(self.view, 'permission_classes', ())
    if permissions.AllowAny in perms:
        auth.append({})
    elif permissions.IsAuthenticatedOrReadOnly in perms and method in permissions.SAFE_METHODS:
        # IsAuthenticatedOrReadOnly waives authentication for safe methods
        # only; DELETE and every other write method still require it.
        auth.append({})
    return auth
def get_request_serializer(self, path, method):
    """
    Return the serializer describing the request body; by default the
    view's serializer as resolved by ``_get_serializer``.

    Override this for custom behaviour.
    """
    return self._get_serializer(path, method)
def get_response_serializers(self, path, method):
    """
    Return what describes the response(s); by default the view's serializer
    as resolved by ``_get_serializer``.

    Override this for custom behaviour. ``_get_response_bodies`` accepts a
    serializer instance, a PolymorphicResponse or a dict mapping response
    codes to either of those.
    """
    return self._get_serializer(path, method)
def get_tags(self, path, method):
    """
    Tag the operation with the first path segment left after the
    configured SCHEMA_PATH_PREFIX has been removed.

    Override this for custom behaviour.
    """
    unprefixed = re.sub(api_settings.SCHEMA_PATH_PREFIX, '', path, flags=re.IGNORECASE)
    first_segment, _, _ = unprefixed.partition('/')
    return [first_segment]
def _get_operation_id(self, path, method):
"""
Compute an operation ID from the model, serializer or view name.
"""
method_name = getattr(self.view, 'action', method.lower())
# remove path prefix
sub_path = re.sub(
pattern=api_settings.SCHEMA_PATH_PREFIX,
repl='',
string=path,
flags=re.IGNORECASE
)
# cleanup, normalize and tokenize remaining parts.
# replace dashes as they can be problematic later in code generation
sub_path = sub_path.replace('-', '_').rstrip('/').lstrip('/')
sub_path = sub_path.split('/') if sub_path else []
# remove path variables
sub_path = [p for p in sub_path if not p.startswith('{')]
if is_list_view(path, method, self.view):
action = 'list'
elif method_name not in self.method_mapping:
action = method_name
else:
action = self.method_mapping[method.lower()]
# Try to deduce the ID from the view's model
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
if model is not None:
name = model.__name__
# Try with the serializer class name
elif hasattr(self.view, 'get_serializer_class'):
name = self.view.get_serializer_class().__name__
if name.endswith('Serializer'):
name = name[:-10]
# Fallback to the view name
else:
name = self.view.__class__.__name__
if name.endswith('APIView'):
name = name[:-7]
elif name.endswith('View'):
name = name[:-4]
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
# comes at the end of the name
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
name = name[:-len(action)]
if action == 'list' and not name.endswith('s'): # listThings instead of listThing
name += 's'
return action + name
return '_'.join(sub_path + [action])
def _get_path_parameters(self, path, method):
"""
@ -160,6 +251,8 @@ class AutoSchema(ViewInspector):
for variable in uritemplate.variables(path):
description = ''
schema = TYPE_MAPPING[str]
if model is not None: # TODO: test this.
# Attempt to infer a field description if possible.
try:
@ -172,14 +265,16 @@ class AutoSchema(ViewInspector):
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
# TODO cover more cases
if isinstance(model_field, models.UUIDField):
schema = TYPE_MAPPING[UUID]
parameter = {
"name": variable,
"in": "path",
"required": True,
"description": description,
'schema': {
'type': 'string', # TODO: integer, pattern, ...
},
'schema': schema,
}
parameters.append(parameter)
@ -218,16 +313,15 @@ class AutoSchema(ViewInspector):
return paginator.get_schema_operation_parameters(view)
def _map_field(self, field):
def _map_field(self, method, field):
# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self._map_serializer(field.child)
'items': self.resolve_serializer(method, field.child)
}
if isinstance(field, serializers.Serializer):
data = self._map_serializer(field)
data = self.resolve_serializer(method, field, nested=True)
data['type'] = 'object'
return data
@ -261,7 +355,6 @@ class AutoSchema(ViewInspector):
'enum': list(field.choices),
}
# ListField.
if isinstance(field, serializers.ListField):
mapping = {
'type': 'array',
@ -355,6 +448,10 @@ class AutoSchema(ViewInspector):
'format': 'binary'
}
if isinstance(field, serializers.SerializerMethodField):
method = getattr(field.parent, field.method_name)
return self._map_type_hint(method)
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
@ -370,7 +467,7 @@ class AutoSchema(ViewInspector):
if field.min_value:
content['minimum'] = field.min_value
def _map_serializer(self, serializer):
def _map_serializer(self, method, serializer, nested=False):
# Assuming we have a valid serializer instance.
# TODO:
# - field is Nested or List serializer.
@ -386,7 +483,7 @@ class AutoSchema(ViewInspector):
if field.required:
required.append(field.field_name)
schema = self._map_field(field)
schema = self._map_field(method, field)
if field.read_only:
schema['readOnly'] = True
if field.write_only:
@ -404,15 +501,12 @@ class AutoSchema(ViewInspector):
result = {
'properties': properties
}
if required:
if required and method != 'PATCH' and not nested:
result['required'] = required
return result
def _map_field_validators(self, field, schema):
"""
map field validators
"""
for v in field.validators:
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
@ -446,6 +540,30 @@ class AutoSchema(ViewInspector):
schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum']
def _map_type_hint(self, method, hint=None):
    """
    Translate a type hint into an OpenAPI schema fragment.

    :param method: function whose ``return`` annotation is inspected when
        no ``hint`` is given; its name also appears in the fallback warning
    :param hint: type hint to translate; passed for the members of a Union
    :return: a fresh schema dict. Types known to TYPE_MAPPING are looked
        up, Unions become ``oneOf`` (flagged ``nullable`` when None is a
        member) and anything else falls back to a string schema after a
        warning.
    """
    if not hint:
        hint = typing.get_type_hints(method).get('return')

    if hint in TYPE_MAPPING:
        # hand out a copy: callers annotate the result (e.g. 'readOnly'),
        # which must not leak into the shared TYPE_MAPPING entries
        return dict(TYPE_MAPPING[hint])

    # plain classes carry no ``__origin__``; getattr keeps them on the
    # fallback path instead of raising AttributeError
    if getattr(hint, '__origin__', None) is typing.Union:
        members = hint.__args__
        sub_hints = [
            self._map_type_hint(method, sub_hint)
            for sub_hint in members if sub_hint is not type(None)  # noqa
        ]
        if type(None) not in members:
            return {'oneOf': sub_hints}
        if len(sub_hints) == 1:
            return {**sub_hints[0], 'nullable': True}
        return {'oneOf': [{**sub_hint, 'nullable': True} for sub_hint in sub_hints]}

    warnings.warn(
        'type hint for SerializerMethodField function "{}" is unknown. '
        'defaulting to string.'.format(method.__name__)
    )
    return {'type': 'string'}
def _get_paginator(self):
pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class:
@ -473,82 +591,195 @@ class AutoSchema(ViewInspector):
try:
return view.get_serializer()
except exceptions.APIException:
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
warnings.warn(
'{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'.format(view.__class__.__name__, method, path)
)
return None
def _get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'):
return {}
self.request_media_types = self.map_parsers(path, method)
request_media_types = self.map_parsers(path, method)
serializer = self._get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
return {}
content = self._map_serializer(serializer)
# No required fields for PATCH
if method == 'PATCH':
content.pop('required', None)
# No read_only fields for request.
for name, schema in content['properties'].copy().items():
if 'readOnly' in schema:
del content['properties'][name]
return {
'content': {
ct: {'schema': content}
for ct in self.request_media_types
}
}
def _get_responses(self, path, method):
# TODO: Handle multiple codes and pagination classes.
if method == 'DELETE':
return {
'204': {
'description': ''
}
}
self.response_media_types = self.map_renderers(path, method)
item_schema = {}
serializer = self._get_serializer(path, method)
if isinstance(serializer, serializers.Serializer):
item_schema = self._map_serializer(serializer)
# No write_only fields for response.
for name, schema in item_schema['properties'].copy().items():
if 'writeOnly' in schema:
del item_schema['properties'][name]
if 'required' in item_schema:
item_schema['required'] = [f for f in item_schema['required'] if f != name]
schema = self.resolve_serializer(method, serializer)
else:
warnings.warn(
'could not resolve request body for {} {}. defaulting to generic '
'free-form object. (maybe annotate a Serializer class?)'.format(method, path)
)
schema = {
'type': 'object',
'additionalProperties': {}, # https://github.com/swagger-api/swagger-codegen/issues/1318
'description': 'Unspecified request body',
}
# serializer has no fields so skip content enumeration
if not schema:
return {}
return {
'content': {mt: {'schema': schema} for mt in request_media_types}
}
def _get_response_bodies(self, path, method):
    """
    Build the ``responses`` object of an operation from whatever
    ``get_response_serializers`` returns.

    A single serializer or PolymorphicResponse yields one '200' entry
    ('204' without body for DELETE). A dict is taken as a mapping of
    response code to serializer. Anything else produces a warning and a
    generic free-form '200' response.
    """
    response_serializers = self.get_response_serializers(path, method)

    if isinstance(response_serializers, (serializers.Serializer, PolymorphicResponse)):
        if method == 'DELETE':
            return {'204': {'description': 'No response body'}}
        return {'200': self._get_response_for_code(path, method, response_serializers)}

    if isinstance(response_serializers, dict):
        # custom handling for overriding default return codes with @extend_schema
        responses = {}
        for code, serializer in response_serializers.items():
            responses[code] = self._get_response_for_code(path, method, serializer)
        return responses

    warnings.warn(
        'could not resolve response for {} {}. defaulting '
        'to generic free-form object.'.format(method, path)
    )
    fallback = {
        'type': 'object',
        'description': 'Unspecified response body',
    }
    return {'200': self._get_response_for_code(path, method, fallback)}
def _get_response_for_code(self, path, method, serializer_instance):
if not serializer_instance:
return {'description': 'No response body'}
elif isinstance(serializer_instance, serializers.Serializer):
schema = self.resolve_serializer(method, serializer_instance)
if not schema:
return {'description': 'No response body'}
elif isinstance(serializer_instance, PolymorphicResponse):
# custom handling for @extend_schema's injection of polymorphic responses
schemas = []
for serializer in serializer_instance.serializers:
assert isinstance(serializer, serializers.Serializer)
schema_option = self.resolve_serializer(method, serializer)
if schema_option:
schemas.append(schema_option)
schema = {
'oneOf': schemas,
'discriminator': {
'propertyName': serializer_instance.resource_type_field_name
}
}
elif isinstance(serializer_instance, dict):
# bypass processing and use given schema directly
schema = serializer_instance
else:
raise ValueError('Serializer type unsupported')
if is_list_view(path, method, self.view):
response_schema = {
schema = {
'type': 'array',
'items': item_schema,
'items': schema,
}
paginator = self._get_paginator()
if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema)
else:
response_schema = item_schema
schema = paginator.get_paginated_response_schema(schema)
return {
'200': {
'content': {
ct: {'schema': response_schema}
for ct in self.response_media_types
},
# description is a mandatory property,
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
# TODO: put something meaningful into it
'description': ""
'content': {
mt: {'schema': schema} for mt in self.response_media_types
},
# Description is required by spec, but descriptions for each response code don't really
# fit into our model. Description is therefore put into the higher level slots.
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
'description': ''
}
def _get_serializer_name(self, method, serializer, nested):
name = serializer.__class__.__name__
if name.endswith('Serializer'):
name = name[:-10]
if method == 'PATCH' and not nested:
name = 'Patched' + name
return name
def resolve_authentication(self, method, authentication):
    """
    Resolve an authentication class to a security requirement and make
    sure its scheme definition is present in the component registry.

    :param method: HTTP method name (unused by this implementation)
    :param authentication: authentication class as listed on the view
    :return: security requirement ``{<scheme name>: []}``
    :raises ValueError: if AUTHENTICATION_SCHEMES holds no scheme for the class
    """
    auth_scheme = AUTHENTICATION_SCHEMES.get(authentication)
    if not auth_scheme:
        # single lookup with an informative message; a preceding bare
        # ``raise ValueError()`` used to shadow this one
        raise ValueError(
            'no auth scheme registered for {}'.format(
                getattr(authentication, '__name__', authentication)
            )
        )

    # register the definition once; operations only reference it by name
    self.registry.security_schemes.setdefault(auth_scheme.name, auth_scheme.schema)
    return {auth_scheme.name: []}
def resolve_serializer(self, method, serializer, nested=False):
    """
    Register the serializer's schema as a named component (once) and
    return a ``$ref`` to it. A serializer that maps to no properties is
    not registered and resolves to an empty dict.
    """
    schemas = self.registry.schemas
    name = self._get_serializer_name(method, serializer, nested)
    reference = {'$ref': '#/components/schemas/{}'.format(name)}

    if name in schemas:
        return reference

    # add placeholder to prevent recursion loop
    schemas[name] = None
    mapped = self._map_serializer(method, serializer, nested)
    if mapped['properties']:
        schemas[name] = mapped
        return reference

    # empty serializer - usually a transactional serializer.
    # no need to put it explicitly in the spec
    del schemas[name]
    return {}
class PolymorphicAutoSchema(AutoSchema):
    """
    AutoSchema variant that resolves ``rest_polymorphic`` serializers into
    a ``oneOf`` over their per-model serializers. All other serializers
    are resolved by the parent class.
    """

    def resolve_serializer(self, method, serializer, nested=False):
        """Dispatch polymorphic serializers to the dedicated resolver."""
        try:
            from rest_polymorphic.serializers import PolymorphicSerializer
        except ImportError:
            warnings.warn('rest_polymorphic package required for PolymorphicAutoSchema')
            raise

        if not isinstance(serializer, PolymorphicSerializer):
            return super().resolve_serializer(method, serializer, nested)
        return self._resolve_polymorphic_serializer(method, serializer, nested)

    def _resolve_polymorphic_serializer(self, method, serializer, nested):
        """
        Register every serializer of the mapping as a component (extended
        by the resource type field) and return a discriminated ``oneOf``.
        """
        schemas = self.registry.schemas
        references = []

        for sub_serializer in serializer.model_serializer_mapping.values():
            name = self._get_serializer_name(method, sub_serializer, nested)

            if name not in schemas:
                # add placeholder to prevent recursion loop
                schemas[name] = None
                # append the type field to serializer fields
                mapped = self._map_serializer(method, sub_serializer, nested)
                mapped['properties'][serializer.resource_type_field_name] = {'type': 'string'}
                schemas[name] = mapped

            references.append({'$ref': '#/components/schemas/{}'.format(name)})

        return {
            'oneOf': references,
            'discriminator': {
                'propertyName': serializer.resource_type_field_name
            }
        }

View File

@ -0,0 +1,165 @@
import inspect
import warnings
from decimal import Decimal
from uuid import UUID
from datetime import datetime, date
from rest_framework import authentication
from rest_framework.settings import api_settings
# NOTE(review): looks like the list of accepted primitive type names; it is
# not referenced within the visible code - confirm where it is used.
VALID_TYPES = ['integer', 'number', 'string', 'boolean']

# Maps Python types to OpenAPI schema fragments. Both ``None`` and
# ``NoneType`` map to an empty schema.
# The values are shared dict objects: copy an entry before modifying it.
TYPE_MAPPING = {
    float: {'type': 'number', 'format': 'float'},
    bool: {'type': 'boolean'},
    str: {'type': 'string'},
    bytes: {'type': 'string', 'format': 'binary'},  # or byte?
    int: {'type': 'integer'},
    UUID: {'type': 'string', 'format': 'uuid'},
    Decimal: {'type': 'number', 'format': 'double'},
    datetime: {'type': 'string', 'format': 'date-time'},
    date: {'type': 'string', 'format': 'date'},
    None: {},
    type(None): {},
}
class OpenApiAuthenticationScheme:
    """
    Base declaration tying an authentication class to an OpenAPI security
    scheme. Subclasses fill in all three attributes.
    """
    # authentication class this scheme describes
    authentication_class = None
    # key under which the scheme is registered and referenced
    name = None
    # security scheme object emitted for ``name``
    schema = None
class SessionAuthenticationScheme(OpenApiAuthenticationScheme):
    """Security scheme for DRF's cookie-based SessionAuthentication."""
    authentication_class = authentication.SessionAuthentication
    name = 'cookieAuth'
    schema = {
        'type': 'apiKey',
        'in': 'cookie',
        # Django's default session cookie name (SESSION_COOKIE_NAME);
        # register an adapted scheme if the project overrides that setting
        'name': 'sessionid',
    }
class BasicAuthenticationScheme(OpenApiAuthenticationScheme):
    """Security scheme for DRF's BasicAuthentication (HTTP basic)."""
    authentication_class = authentication.BasicAuthentication
    name = 'basicAuth'
    schema = {
        'type': 'http',
        'scheme': 'basic',
    }
class TokenAuthenticationScheme(OpenApiAuthenticationScheme):
    """
    Security scheme for DRF's TokenAuthentication.

    DRF expects the header ``Authorization: Token <key>``. The OpenAPI
    ``http``/``bearer`` scheme would make clients send ``Bearer <key>``,
    which TokenAuthentication does not accept, so the scheme is declared
    as an API key carried in the ``Authorization`` header instead.
    """
    authentication_class = authentication.TokenAuthentication
    name = 'tokenAuth'
    schema = {
        'type': 'apiKey',
        'in': 'header',
        'name': 'Authorization',
        'description': 'Token-based authentication with required prefix "Token"',
    }
class PolymorphicResponse:
    """
    Bundle of alternative response serializers together with the name of
    the field used as discriminator between them.
    """

    def __init__(self, serializers, resource_type_field_name):
        self.resource_type_field_name = resource_type_field_name
        self.serializers = serializers
class OpenApiSchemaBase:
    """ reusable base class for objects that can be translated to a schema """

    def to_schema(self):
        """Return the schema representation; subclasses must implement this."""
        raise NotImplementedError('translation to schema required.')
class QueryParameter(OpenApiSchemaBase):
    """
    Declarative description of a query parameter.

    :param name: parameter name
    :param description: free-text description
    :param required: whether the parameter is mandatory
    :param type: Python type of the value; translated through TYPE_MAPPING
    """

    def __init__(self, name, description='', required=False, type=str):
        self.name = name
        self.description = description
        self.required = required
        self.type = type

    def to_schema(self):
        """
        Return the OpenAPI parameter object for this query parameter.

        A type missing from TYPE_MAPPING triggers a warning and is emitted
        as a string, so the generated document never carries an empty
        ``schema`` entry.
        """
        schema = TYPE_MAPPING.get(self.type)
        if schema is None:
            warnings.warn('{} not a mappable type. defaulting to string.'.format(self.type))
            schema = TYPE_MAPPING[str]
        return {
            'name': self.name,
            'in': 'query',
            'description': self.description,
            'required': self.required,
            # copy so later changes to the parameter cannot alter TYPE_MAPPING
            'schema': dict(schema),
        }
def extend_schema(
        operation=None,
        extra_parameters=None,
        responses=None,
        request=None,
        auth=None,
        description=None,
):
    """
    Decorator customising the schema generated for a view class or a
    single view method.

    Every argument overrides the matching hook of the default schema
    class (``api_settings.DEFAULT_SCHEMA_CLASS``); arguments left empty
    keep the default behaviour of that hook.

    :param operation: complete operation dict, returned as-is in place of
        the generated operation
    :param extra_parameters: additional parameters; OpenApiSchemaBase
        objects are converted with ``to_schema()``, anything else is
        passed through unchanged
    :param responses: replacement for the response serializers
    :param request: replacement for the request serializer
        NOTE(review): confirm the request-body generation actually consults
        ``get_request_serializer``.
    :param auth: replacement for the security requirement list
    :param description: replacement for the operation description
    :return: the decorator. Applied to a class it returns a subclass that
        carries the extended schema; applied to a callable it stores the
        schema in ``f.kwargs['schema']`` and returns the callable itself.
    """
    def decorator(f):
        class ExtendedSchema(api_settings.DEFAULT_SCHEMA_CLASS):
            def get_operation(self, path, method):
                if operation:
                    return operation
                return super().get_operation(path, method)

            def get_extra_parameters(self, path, method):
                if extra_parameters:
                    return [
                        p.to_schema() if isinstance(p, OpenApiSchemaBase) else p
                        for p in extra_parameters
                    ]
                return super().get_extra_parameters(path, method)

            def get_auth(self, path, method):
                if auth:
                    return auth
                return super().get_auth(path, method)

            def get_request_serializer(self, path, method):
                if request:
                    return request
                return super().get_request_serializer(path, method)

            def get_response_serializers(self, path, method):
                if responses:
                    return responses
                return super().get_response_serializers(path, method)

            def get_description(self, path, method):
                if description:
                    return description
                return super().get_description(path, method)

        if inspect.isclass(f):
            class ExtendedView(f):
                # a class ``__doc__`` is not inherited; copy it so the
                # docstring-based description of the view is preserved
                __doc__ = f.__doc__
                schema = ExtendedSchema()

            # keep the identity of the decorated view instead of exposing
            # the helper name 'ExtendedView'
            ExtendedView.__name__ = f.__name__
            ExtendedView.__qualname__ = f.__qualname__
            ExtendedView.__module__ = f.__module__
            return ExtendedView

        if callable(f):
            # custom actions have kwargs in their context, others don't. create it so our
            # create_view implementation can overwrite the default schema
            if not hasattr(f, 'kwargs'):
                f.kwargs = {}
            # this simulates what @action does: the schema is later picked up from kwargs
            f.kwargs['schema'] = ExtendedSchema()

        # anything that is not callable is handed back untouched
        return f

    return decorator

View File

@ -53,6 +53,12 @@ DEFAULTS = {
# Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
'SCHEMA_PATH_PREFIX': r'^\/api\/(?:v[0-9\.\_\-]+\/)?',
'SCHEMA_AUTHENTICATION_CLASSES': [
'rest_framework.schemas.openapi_utils.SessionAuthenticationScheme',
'rest_framework.schemas.openapi_utils.BasicAuthenticationScheme',
'rest_framework.schemas.openapi_utils.TokenAuthenticationScheme',
],
# Throttling
'DEFAULT_THROTTLE_RATES': {