refactor/extend/improve OpenApi3 spec generation

This commit is contained in:
Thorsten Franzel 2019-12-11 15:01:29 +01:00
parent b135e0fa0a
commit 9486df8d04
3 changed files with 539 additions and 137 deletions

View File

@ -1,6 +1,10 @@
import inspect
import re
import typing
import warnings import warnings
from operator import attrgetter from operator import attrgetter
from urllib.parse import urljoin from urllib.parse import urljoin
from uuid import UUID
from django.core.validators import ( from django.core.validators import (
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator, DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
@ -8,31 +12,69 @@ from django.core.validators import (
) )
from django.db import models from django.db import models
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.module_loading import import_string
from rest_framework import exceptions, renderers, serializers from rest_framework import exceptions, renderers, serializers, permissions
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings
from rest_framework.schemas.openapi_utils import TYPE_MAPPING, PolymorphicResponse
from .generators import BaseSchemaGenerator from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
class SchemaGenerator(BaseSchemaGenerator): AUTHENTICATION_SCHEMES = {
cls.authentication_class: cls
for cls in [import_string(cls) for cls in api_settings.SCHEMA_AUTHENTICATION_CLASSES]
}
def get_info(self):
# Title and version are required by openapi specification 3.x class ComponentRegistry:
info = { def __init__(self):
'title': self.title or '', self.schemas = {}
'version': self.version or '' self.security_schemes = {}
def get_components(self):
return {
'securitySchemes': self.security_schemes,
'schemas': self.schemas,
} }
if self.description is not None:
info['description'] = self.description
return info class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, *args, **kwargs):
self.registry = ComponentRegistry()
super().__init__(*args, **kwargs)
def get_paths(self, request=None): def create_view(self, callback, method, request=None):
"""
customized create_view which is called when all routes are traversed. part of this
is instatiating views with default params. in case of custom routes (@action) the
custom AutoSchema is injected properly through 'initkwargs' on view. However, when
decorating plain views like retrieve, this initialization logic is not running.
Therefore forcefully set the schema if @extend_schema decorator was used.
"""
view = super().create_view(callback, method, request)
# circumvent import issues by locally importing
from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet, ViewSet
if isinstance(view, GenericViewSet) or isinstance(view, ViewSet):
action = getattr(view, view.action)
elif isinstance(view, APIView):
action = getattr(view, method.lower())
else:
raise RuntimeError('not supported subclass. Must inherit from APIView')
if hasattr(action, 'kwargs') and 'schema' in action.kwargs:
# might already be properly set in case of @action but overwrite for all cases
view.schema = action.kwargs['schema']
return view
def parse(self, request=None):
result = {} result = {}
paths, view_endpoints = self._get_paths_and_endpoints(request) paths, view_endpoints = self._get_paths_and_endpoints(request)
@ -44,7 +86,10 @@ class SchemaGenerator(BaseSchemaGenerator):
for path, method, view in view_endpoints: for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view): if not self.has_view_permissions(path, method, view):
continue continue
operation = view.schema.get_operation(path, method) # keep reference to schema as every access yields a fresh object (descriptor )
schema = view.schema
schema.init(self.registry)
operation = schema.get_operation(path, method)
# Normalise path for any provided mount url. # Normalise path for any provided mount url.
if path.startswith('/'): if path.startswith('/'):
path = path[1:] path = path[1:]
@ -61,20 +106,21 @@ class SchemaGenerator(BaseSchemaGenerator):
""" """
self._initialise_endpoints() self._initialise_endpoints()
paths = self.get_paths(None if public else request)
if not paths:
return None
schema = { schema = {
'openapi': '3.0.2', 'openapi': '3.0.2',
'info': self.get_info(), 'servers': [
'paths': paths, {'url': self.url or 'http://127.0.0.1:8000'},
],
'info': {
'title': self.title or '',
'version': self.version or '0.0.0', # fallback to prevent invalid schema
'description': self.description or '',
},
'paths': self.parse(None if public else request),
'components': self.registry.get_components(),
} }
return schema return schema
# View Inspectors
class AutoSchema(ViewInspector): class AutoSchema(ViewInspector):
@ -82,72 +128,117 @@ class AutoSchema(ViewInspector):
response_media_types = [] response_media_types = []
method_mapping = { method_mapping = {
'get': 'Retrieve', 'get': 'retrieve',
'post': 'Create', 'post': 'create',
'put': 'Update', 'put': 'update',
'patch': 'PartialUpdate', 'patch': 'partial_update',
'delete': 'Destroy', 'delete': 'destroy',
} }
def init(self, registry):
self.registry = registry
def get_operation(self, path, method): def get_operation(self, path, method):
operation = {} operation = {}
operation['operationId'] = self._get_operation_id(path, method) operation['operationId'] = self._get_operation_id(path, method)
operation['description'] = self.get_description(path, method) operation['description'] = self.get_description(path, method)
operation['parameters'] = sorted(
[
*self._get_path_parameters(path, method),
*self._get_filter_parameters(path, method),
*self._get_pagination_parameters(path, method),
*self.get_extra_parameters(path, method),
],
key=lambda p: p.get('name')
)
parameters = [] tags = self.get_tags(path, method)
parameters += self._get_path_parameters(path, method) if tags:
parameters += self._get_pagination_parameters(path, method) operation['tags'] = tags
parameters += self._get_filter_parameters(path, method)
operation['parameters'] = parameters
request_body = self._get_request_body(path, method) request_body = self._get_request_body(path, method)
if request_body: if request_body:
operation['requestBody'] = request_body operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
auth = self.get_auth(path, method)
if auth:
operation['security'] = auth
self.response_media_types = self.map_renderers(path, method)
operation['responses'] = self._get_response_bodies(path, method)
return operation return operation
def get_extra_parameters(self, path, method):
""" override this for custom behaviour """
return []
def get_description(self, path, method):
""" override this for custom behaviour """
action_or_method = getattr(self.view, getattr(self.view, 'action', method.lower()), None)
view_doc = inspect.getdoc(self.view) or ''
action_doc = inspect.getdoc(action_or_method) or ''
return view_doc + '\n\n' + action_doc if action_doc else view_doc
def get_auth(self, path, method):
""" override this for custom behaviour """
auth = []
if hasattr(self.view, 'authentication_classes'):
auth = [
self.resolve_authentication(method, ac) for ac in self.view.authentication_classes
]
if hasattr(self.view, 'permission_classes'):
perms = self.view.permission_classes
if permissions.AllowAny in perms:
auth.append({})
elif permissions.IsAuthenticatedOrReadOnly in perms and method not in ('PUT', 'PATCH', 'POST'):
auth.append({})
return auth
def get_request_serializer(self, path, method):
""" override this for custom behaviour """
return self._get_serializer(path, method)
def get_response_serializers(self, path, method):
""" override this for custom behaviour """
return self._get_serializer(path, method)
def get_tags(self, path, method):
""" override this for custom behaviour """
path = re.sub(
pattern=api_settings.SCHEMA_PATH_PREFIX,
repl='',
string=path,
flags=re.IGNORECASE
).split('/')
return [path[0]]
def _get_operation_id(self, path, method): def _get_operation_id(self, path, method):
""" """
Compute an operation ID from the model, serializer or view name. Compute an operation ID from the model, serializer or view name.
""" """
method_name = getattr(self.view, 'action', method.lower()) # remove path prefix
sub_path = re.sub(
pattern=api_settings.SCHEMA_PATH_PREFIX,
repl='',
string=path,
flags=re.IGNORECASE
)
# cleanup, normalize and tokenize remaining parts.
# replace dashes as they can be problematic later in code generation
sub_path = sub_path.replace('-', '_').rstrip('/').lstrip('/')
sub_path = sub_path.split('/') if sub_path else []
# remove path variables
sub_path = [p for p in sub_path if not p.startswith('{')]
if is_list_view(path, method, self.view): if is_list_view(path, method, self.view):
action = 'list' action = 'list'
elif method_name not in self.method_mapping:
action = method_name
else: else:
action = self.method_mapping[method.lower()] action = self.method_mapping[method.lower()]
# Try to deduce the ID from the view's model return '_'.join(sub_path + [action])
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
if model is not None:
name = model.__name__
# Try with the serializer class name
elif hasattr(self.view, 'get_serializer_class'):
name = self.view.get_serializer_class().__name__
if name.endswith('Serializer'):
name = name[:-10]
# Fallback to the view name
else:
name = self.view.__class__.__name__
if name.endswith('APIView'):
name = name[:-7]
elif name.endswith('View'):
name = name[:-4]
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
# comes at the end of the name
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
name = name[:-len(action)]
if action == 'list' and not name.endswith('s'): # listThings instead of listThing
name += 's'
return action + name
def _get_path_parameters(self, path, method): def _get_path_parameters(self, path, method):
""" """
@ -160,6 +251,8 @@ class AutoSchema(ViewInspector):
for variable in uritemplate.variables(path): for variable in uritemplate.variables(path):
description = '' description = ''
schema = TYPE_MAPPING[str]
if model is not None: # TODO: test this. if model is not None: # TODO: test this.
# Attempt to infer a field description if possible. # Attempt to infer a field description if possible.
try: try:
@ -172,14 +265,16 @@ class AutoSchema(ViewInspector):
elif model_field is not None and model_field.primary_key: elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field) description = get_pk_description(model, model_field)
# TODO cover more cases
if isinstance(model_field, models.UUIDField):
schema = TYPE_MAPPING[UUID]
parameter = { parameter = {
"name": variable, "name": variable,
"in": "path", "in": "path",
"required": True, "required": True,
"description": description, "description": description,
'schema': { 'schema': schema,
'type': 'string', # TODO: integer, pattern, ...
},
} }
parameters.append(parameter) parameters.append(parameter)
@ -218,16 +313,15 @@ class AutoSchema(ViewInspector):
return paginator.get_schema_operation_parameters(view) return paginator.get_schema_operation_parameters(view)
def _map_field(self, field): def _map_field(self, method, field):
# Nested Serializers, `many` or not. # Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer): if isinstance(field, serializers.ListSerializer):
return { return {
'type': 'array', 'type': 'array',
'items': self._map_serializer(field.child) 'items': self.resolve_serializer(method, field.child)
} }
if isinstance(field, serializers.Serializer): if isinstance(field, serializers.Serializer):
data = self._map_serializer(field) data = self.resolve_serializer(method, field, nested=True)
data['type'] = 'object' data['type'] = 'object'
return data return data
@ -261,7 +355,6 @@ class AutoSchema(ViewInspector):
'enum': list(field.choices), 'enum': list(field.choices),
} }
# ListField.
if isinstance(field, serializers.ListField): if isinstance(field, serializers.ListField):
mapping = { mapping = {
'type': 'array', 'type': 'array',
@ -355,6 +448,10 @@ class AutoSchema(ViewInspector):
'format': 'binary' 'format': 'binary'
} }
if isinstance(field, serializers.SerializerMethodField):
method = getattr(field.parent, field.method_name)
return self._map_type_hint(method)
# Simplest cases, default to 'string' type: # Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = { FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean', serializers.BooleanField: 'boolean',
@ -370,7 +467,7 @@ class AutoSchema(ViewInspector):
if field.min_value: if field.min_value:
content['minimum'] = field.min_value content['minimum'] = field.min_value
def _map_serializer(self, serializer): def _map_serializer(self, method, serializer, nested=False):
# Assuming we have a valid serializer instance. # Assuming we have a valid serializer instance.
# TODO: # TODO:
# - field is Nested or List serializer. # - field is Nested or List serializer.
@ -386,7 +483,7 @@ class AutoSchema(ViewInspector):
if field.required: if field.required:
required.append(field.field_name) required.append(field.field_name)
schema = self._map_field(field) schema = self._map_field(method, field)
if field.read_only: if field.read_only:
schema['readOnly'] = True schema['readOnly'] = True
if field.write_only: if field.write_only:
@ -404,15 +501,12 @@ class AutoSchema(ViewInspector):
result = { result = {
'properties': properties 'properties': properties
} }
if required: if required and method != 'PATCH' and not nested:
result['required'] = required result['required'] = required
return result return result
def _map_field_validators(self, field, schema): def _map_field_validators(self, field, schema):
"""
map field validators
"""
for v in field.validators: for v in field.validators:
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
@ -446,6 +540,30 @@ class AutoSchema(ViewInspector):
schema['maximum'] = int(digits * '9') + 1 schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum'] schema['minimum'] = -schema['maximum']
def _map_type_hint(self, method, hint=None):
if not hint:
hint = typing.get_type_hints(method).get('return')
if hint in TYPE_MAPPING:
return TYPE_MAPPING[hint]
elif hint.__origin__ is typing.Union:
sub_hints = [
self._map_type_hint(method, sub_hint)
for sub_hint in hint.__args__ if sub_hint is not type(None) # noqa
]
if type(None) in hint.__args__ and len(sub_hints) == 1:
return {**sub_hints[0], 'nullable': True}
elif type(None) in hint.__args__:
return {'oneOf': [{**sub_hint, 'nullable': True} for sub_hint in sub_hints]}
else:
return {'oneOf': sub_hints}
else:
warnings.warn(
'type hint for SerializerMethodField function "{}" is unknown. '
'defaulting to string.'.format(method.__name__)
)
return {'type': 'string'}
def _get_paginator(self): def _get_paginator(self):
pagination_class = getattr(self.view, 'pagination_class', None) pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class: if pagination_class:
@ -473,82 +591,195 @@ class AutoSchema(ViewInspector):
try: try:
return view.get_serializer() return view.get_serializer()
except exceptions.APIException: except exceptions.APIException:
warnings.warn('{}.get_serializer() raised an exception during ' warnings.warn(
'{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be ' 'schema generation. Serializer fields will not be '
'generated for {} {}.' 'generated for {} {}.'.format(view.__class__.__name__, method, path)
.format(view.__class__.__name__, method, path)) )
return None return None
def _get_request_body(self, path, method): def _get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'): if method not in ('PUT', 'PATCH', 'POST'):
return {} return {}
self.request_media_types = self.map_parsers(path, method) request_media_types = self.map_parsers(path, method)
serializer = self._get_serializer(path, method) serializer = self._get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
return {}
content = self._map_serializer(serializer)
# No required fields for PATCH
if method == 'PATCH':
content.pop('required', None)
# No read_only fields for request.
for name, schema in content['properties'].copy().items():
if 'readOnly' in schema:
del content['properties'][name]
return {
'content': {
ct: {'schema': content}
for ct in self.request_media_types
}
}
def _get_responses(self, path, method):
# TODO: Handle multiple codes and pagination classes.
if method == 'DELETE':
return {
'204': {
'description': ''
}
}
self.response_media_types = self.map_renderers(path, method)
item_schema = {}
serializer = self._get_serializer(path, method)
if isinstance(serializer, serializers.Serializer): if isinstance(serializer, serializers.Serializer):
item_schema = self._map_serializer(serializer) schema = self.resolve_serializer(method, serializer)
# No write_only fields for response. else:
for name, schema in item_schema['properties'].copy().items(): warnings.warn(
if 'writeOnly' in schema: 'could not resolve request body for {} {}. defaulting to generic '
del item_schema['properties'][name] 'free-form object. (maybe annotate a Serializer class?)'.format(method, path)
if 'required' in item_schema: )
item_schema['required'] = [f for f in item_schema['required'] if f != name] schema = {
'type': 'object',
'additionalProperties': {}, # https://github.com/swagger-api/swagger-codegen/issues/1318
'description': 'Unspecified request body',
}
# serializer has no fields so skip content enumeration
if not schema:
return {}
return {
'content': {mt: {'schema': schema} for mt in request_media_types}
}
def _get_response_bodies(self, path, method):
response_serializers = self.get_response_serializers(path, method)
if isinstance(response_serializers, serializers.Serializer) or isinstance(response_serializers, PolymorphicResponse):
if method == 'DELETE':
return {'204': {'description': 'No response body'}}
return {'200': self._get_response_for_code(path, method, response_serializers)}
elif isinstance(response_serializers, dict):
# custom handling for overriding default return codes with @extend_schema
return {
code: self._get_response_for_code(path, method, serializer)
for code, serializer in response_serializers.items()
}
else:
warnings.warn(
'could not resolve response for {} {}. defaulting '
'to generic free-form object.'.format(method, path)
)
schema = {
'type': 'object',
'description': 'Unspecified response body',
}
return {'200': self._get_response_for_code(path, method, schema)}
def _get_response_for_code(self, path, method, serializer_instance):
if not serializer_instance:
return {'description': 'No response body'}
elif isinstance(serializer_instance, serializers.Serializer):
schema = self.resolve_serializer(method, serializer_instance)
if not schema:
return {'description': 'No response body'}
elif isinstance(serializer_instance, PolymorphicResponse):
# custom handling for @extend_schema's injection of polymorphic responses
schemas = []
for serializer in serializer_instance.serializers:
assert isinstance(serializer, serializers.Serializer)
schema_option = self.resolve_serializer(method, serializer)
if schema_option:
schemas.append(schema_option)
schema = {
'oneOf': schemas,
'discriminator': {
'propertyName': serializer_instance.resource_type_field_name
}
}
elif isinstance(serializer_instance, dict):
# bypass processing and use given schema directly
schema = serializer_instance
else:
raise ValueError('Serializer type unsupported')
if is_list_view(path, method, self.view): if is_list_view(path, method, self.view):
response_schema = { schema = {
'type': 'array', 'type': 'array',
'items': item_schema, 'items': schema,
} }
paginator = self._get_paginator() paginator = self._get_paginator()
if paginator: if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema) schema = paginator.get_paginated_response_schema(schema)
else:
response_schema = item_schema
return { return {
'200': {
'content': { 'content': {
ct: {'schema': response_schema} mt: {'schema': schema} for mt in self.response_media_types
for ct in self.response_media_types
}, },
# description is a mandatory property, # Description is required by spec, but descriptions for each response code don't really
# fit into our model. Description is therefore put into the higher level slots.
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
# TODO: put something meaningful into it 'description': ''
'description': "" }
def _get_serializer_name(self, method, serializer, nested):
name = serializer.__class__.__name__
if name.endswith('Serializer'):
name = name[:-10]
if method == 'PATCH' and not nested:
name = 'Patched' + name
return name
def resolve_authentication(self, method, authentication):
if authentication not in AUTHENTICATION_SCHEMES:
raise ValueError()
auth_scheme = AUTHENTICATION_SCHEMES.get(authentication)
if not auth_scheme:
raise ValueError('no auth scheme registered for {}'.format(authentication.__name__))
if auth_scheme.name not in self.registry.security_schemes:
self.registry.security_schemes[auth_scheme.name] = auth_scheme.schema
return {auth_scheme.name: []}
def resolve_serializer(self, method, serializer, nested=False):
name = self._get_serializer_name(method, serializer, nested)
if name not in self.registry.schemas:
# add placeholder to prevent recursion loop
self.registry.schemas[name] = None
mapped = self._map_serializer(method, serializer, nested)
# empty serializer - usually a transactional serializer.
# no need to put it explicitly in the spec
if not mapped['properties']:
del self.registry.schemas[name]
return {}
else:
self.registry.schemas[name] = mapped
return {'$ref': '#/components/schemas/{}'.format(name)}
class PolymorphicAutoSchema(AutoSchema):
"""
"""
def resolve_serializer(self, method, serializer, nested=False):
try:
from rest_polymorphic.serializers import PolymorphicSerializer
except ImportError:
warnings.warn('rest_polymorphic package required for PolymorphicAutoSchema')
raise
if isinstance(serializer, PolymorphicSerializer):
return self._resolve_polymorphic_serializer(method, serializer, nested)
else:
return super().resolve_serializer(method, serializer, nested)
def _resolve_polymorphic_serializer(self, method, serializer, nested):
polymorphic_names = []
for poly_model, poly_serializer in serializer.model_serializer_mapping.items():
name = self._get_serializer_name(method, poly_serializer, nested)
if name not in self.registry.schemas:
# add placeholder to prevent recursion loop
self.registry.schemas[name] = None
# append the type field to serializer fields
mapped = self._map_serializer(method, poly_serializer, nested)
mapped['properties'][serializer.resource_type_field_name] = {'type': 'string'}
self.registry.schemas[name] = mapped
polymorphic_names.append(name)
return {
'oneOf': [
{'$ref': '#/components/schemas/{}'.format(name)} for name in polymorphic_names
],
'discriminator': {
'propertyName': serializer.resource_type_field_name
} }
} }

View File

@ -0,0 +1,165 @@
import inspect
import warnings
from decimal import Decimal
from uuid import UUID
from datetime import datetime, date
from rest_framework import authentication
from rest_framework.settings import api_settings
VALID_TYPES = ['integer', 'number', 'string', 'boolean']
TYPE_MAPPING = {
float: {'type': 'number', 'format': 'float'},
bool: {'type': 'boolean'},
str: {'type': 'string'},
bytes: {'type': 'string', 'format': 'binary'}, # or byte?
int: {'type': 'integer'},
UUID: {'type': 'string', 'format': 'uuid'},
Decimal: {'type': 'number', 'format': 'double'},
datetime: {'type': 'string', 'format': 'date-time'},
date: {'type': 'string', 'format': 'date'},
None: {},
type(None): {},
}
class OpenApiAuthenticationScheme:
authentication_class = None
name = None
schema = None
class SessionAuthenticationScheme(OpenApiAuthenticationScheme):
authentication_class = authentication.SessionAuthentication
name = 'cookieAuth'
schema = {
'type': 'apiKey',
'in': 'cookie',
'name': 'Session',
}
class BasicAuthenticationScheme(OpenApiAuthenticationScheme):
authentication_class = authentication.BasicAuthentication
name = 'basicAuth'
schema = {
'type': 'http',
'scheme': 'basic',
}
class TokenAuthenticationScheme(OpenApiAuthenticationScheme):
authentication_class = authentication.TokenAuthentication
name = 'tokenAuth'
schema = {
'type': 'http',
'scheme': 'bearer',
'bearerFormat': 'Token',
}
class PolymorphicResponse:
def __init__(self, serializers, resource_type_field_name):
self.serializers = serializers
self.resource_type_field_name = resource_type_field_name
class OpenApiSchemaBase:
""" reusable base class for objects that can be translated to a schema """
def to_schema(self):
raise NotImplementedError('translation to schema required.')
class QueryParameter(OpenApiSchemaBase):
def __init__(self, name, description='', required=False, type=str):
self.name = name
self.description = description
self.required = required
self.type = type
def to_schema(self):
if self.type not in TYPE_MAPPING:
warnings.warn('{} not a mappable type'.format(self.type))
return {
'name': self.name,
'in': 'query',
'description': self.description,
'required': self.required,
'schema': TYPE_MAPPING.get(self.type)
}
def extend_schema(
operation=None,
extra_parameters=None,
responses=None,
request=None,
auth=None,
description=None,
):
"""
TODO some heavy explaining
:param operation:
:param extra_parameters:
:param responses:
:param request:
:param auth:
:param description:
:return:
"""
def decorator(f):
class ExtendedSchema(api_settings.DEFAULT_SCHEMA_CLASS):
def get_operation(self, path, method):
if operation:
return operation
return super().get_operation(path, method)
def get_extra_parameters(self, path, method):
if extra_parameters:
return [
p.to_schema() if isinstance(p, OpenApiSchemaBase) else p for p in extra_parameters
]
return super().get_extra_parameters(path, method)
def get_auth(self, path, method):
if auth:
return auth
return super().get_auth(path, method)
def get_request_serializer(self, path, method):
if request:
return request
return super().get_request_serializer(path, method)
def get_response_serializers(self, path, method):
if responses:
return responses
return super().get_response_serializers(path, method)
def get_description(self, path, method):
if description:
return description
return super().get_description(path, method)
if inspect.isclass(f):
class ExtendedView(f):
schema = ExtendedSchema()
return ExtendedView
elif callable(f):
# custom actions have kwargs in their context, others don't. create it so our create_view
# implementation can overwrite the default schema
if not hasattr(f, 'kwargs'):
f.kwargs = {}
# this simulates what @action is actually doing. somewhere along the line in this process
# the schema is picked up from kwargs and used. it's involved my dear friends.
f.kwargs['schema'] = ExtendedSchema()
return f
else:
return f
return decorator

View File

@ -53,6 +53,12 @@ DEFAULTS = {
# Schema # Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema', 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
'SCHEMA_PATH_PREFIX': r'^\/api\/(?:v[0-9\.\_\-]+\/)?',
'SCHEMA_AUTHENTICATION_CLASSES': [
'rest_framework.schemas.openapi_utils.SessionAuthenticationScheme',
'rest_framework.schemas.openapi_utils.BasicAuthenticationScheme',
'rest_framework.schemas.openapi_utils.TokenAuthenticationScheme',
],
# Throttling # Throttling
'DEFAULT_THROTTLE_RATES': { 'DEFAULT_THROTTLE_RATES': {