mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-28 17:09:59 +03:00
refactor/extend/improve OpenApi3 spec generation
This commit is contained in:
parent
b135e0fa0a
commit
9486df8d04
|
@ -1,6 +1,10 @@
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from django.core.validators import (
|
from django.core.validators import (
|
||||||
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
|
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
|
||||||
|
@ -8,31 +12,69 @@ from django.core.validators import (
|
||||||
)
|
)
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
from django.utils.module_loading import import_string
|
||||||
|
|
||||||
from rest_framework import exceptions, renderers, serializers
|
from rest_framework import exceptions, renderers, serializers, permissions
|
||||||
from rest_framework.compat import uritemplate
|
from rest_framework.compat import uritemplate
|
||||||
from rest_framework.fields import _UnvalidatedField, empty
|
from rest_framework.fields import _UnvalidatedField, empty
|
||||||
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework.schemas.openapi_utils import TYPE_MAPPING, PolymorphicResponse
|
||||||
from .generators import BaseSchemaGenerator
|
from .generators import BaseSchemaGenerator
|
||||||
from .inspectors import ViewInspector
|
from .inspectors import ViewInspector
|
||||||
from .utils import get_pk_description, is_list_view
|
from .utils import get_pk_description, is_list_view
|
||||||
|
|
||||||
|
|
||||||
class SchemaGenerator(BaseSchemaGenerator):
|
AUTHENTICATION_SCHEMES = {
|
||||||
|
cls.authentication_class: cls
|
||||||
|
for cls in [import_string(cls) for cls in api_settings.SCHEMA_AUTHENTICATION_CLASSES]
|
||||||
|
}
|
||||||
|
|
||||||
def get_info(self):
|
|
||||||
# Title and version are required by openapi specification 3.x
|
class ComponentRegistry:
|
||||||
info = {
|
def __init__(self):
|
||||||
'title': self.title or '',
|
self.schemas = {}
|
||||||
'version': self.version or ''
|
self.security_schemes = {}
|
||||||
|
|
||||||
|
def get_components(self):
|
||||||
|
return {
|
||||||
|
'securitySchemes': self.security_schemes,
|
||||||
|
'schemas': self.schemas,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.description is not None:
|
|
||||||
info['description'] = self.description
|
|
||||||
|
|
||||||
return info
|
class SchemaGenerator(BaseSchemaGenerator):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.registry = ComponentRegistry()
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def get_paths(self, request=None):
|
def create_view(self, callback, method, request=None):
|
||||||
|
"""
|
||||||
|
customized create_view which is called when all routes are traversed. part of this
|
||||||
|
is instatiating views with default params. in case of custom routes (@action) the
|
||||||
|
custom AutoSchema is injected properly through 'initkwargs' on view. However, when
|
||||||
|
decorating plain views like retrieve, this initialization logic is not running.
|
||||||
|
Therefore forcefully set the schema if @extend_schema decorator was used.
|
||||||
|
"""
|
||||||
|
view = super().create_view(callback, method, request)
|
||||||
|
|
||||||
|
# circumvent import issues by locally importing
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
from rest_framework.viewsets import GenericViewSet, ViewSet
|
||||||
|
|
||||||
|
if isinstance(view, GenericViewSet) or isinstance(view, ViewSet):
|
||||||
|
action = getattr(view, view.action)
|
||||||
|
elif isinstance(view, APIView):
|
||||||
|
action = getattr(view, method.lower())
|
||||||
|
else:
|
||||||
|
raise RuntimeError('not supported subclass. Must inherit from APIView')
|
||||||
|
|
||||||
|
if hasattr(action, 'kwargs') and 'schema' in action.kwargs:
|
||||||
|
# might already be properly set in case of @action but overwrite for all cases
|
||||||
|
view.schema = action.kwargs['schema']
|
||||||
|
|
||||||
|
return view
|
||||||
|
|
||||||
|
def parse(self, request=None):
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||||
|
@ -44,7 +86,10 @@ class SchemaGenerator(BaseSchemaGenerator):
|
||||||
for path, method, view in view_endpoints:
|
for path, method, view in view_endpoints:
|
||||||
if not self.has_view_permissions(path, method, view):
|
if not self.has_view_permissions(path, method, view):
|
||||||
continue
|
continue
|
||||||
operation = view.schema.get_operation(path, method)
|
# keep reference to schema as every access yields a fresh object (descriptor )
|
||||||
|
schema = view.schema
|
||||||
|
schema.init(self.registry)
|
||||||
|
operation = schema.get_operation(path, method)
|
||||||
# Normalise path for any provided mount url.
|
# Normalise path for any provided mount url.
|
||||||
if path.startswith('/'):
|
if path.startswith('/'):
|
||||||
path = path[1:]
|
path = path[1:]
|
||||||
|
@ -61,20 +106,21 @@ class SchemaGenerator(BaseSchemaGenerator):
|
||||||
"""
|
"""
|
||||||
self._initialise_endpoints()
|
self._initialise_endpoints()
|
||||||
|
|
||||||
paths = self.get_paths(None if public else request)
|
|
||||||
if not paths:
|
|
||||||
return None
|
|
||||||
|
|
||||||
schema = {
|
schema = {
|
||||||
'openapi': '3.0.2',
|
'openapi': '3.0.2',
|
||||||
'info': self.get_info(),
|
'servers': [
|
||||||
'paths': paths,
|
{'url': self.url or 'http://127.0.0.1:8000'},
|
||||||
|
],
|
||||||
|
'info': {
|
||||||
|
'title': self.title or '',
|
||||||
|
'version': self.version or '0.0.0', # fallback to prevent invalid schema
|
||||||
|
'description': self.description or '',
|
||||||
|
},
|
||||||
|
'paths': self.parse(None if public else request),
|
||||||
|
'components': self.registry.get_components(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
# View Inspectors
|
|
||||||
|
|
||||||
|
|
||||||
class AutoSchema(ViewInspector):
|
class AutoSchema(ViewInspector):
|
||||||
|
|
||||||
|
@ -82,72 +128,117 @@ class AutoSchema(ViewInspector):
|
||||||
response_media_types = []
|
response_media_types = []
|
||||||
|
|
||||||
method_mapping = {
|
method_mapping = {
|
||||||
'get': 'Retrieve',
|
'get': 'retrieve',
|
||||||
'post': 'Create',
|
'post': 'create',
|
||||||
'put': 'Update',
|
'put': 'update',
|
||||||
'patch': 'PartialUpdate',
|
'patch': 'partial_update',
|
||||||
'delete': 'Destroy',
|
'delete': 'destroy',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def init(self, registry):
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
def get_operation(self, path, method):
|
def get_operation(self, path, method):
|
||||||
operation = {}
|
operation = {}
|
||||||
|
|
||||||
operation['operationId'] = self._get_operation_id(path, method)
|
operation['operationId'] = self._get_operation_id(path, method)
|
||||||
operation['description'] = self.get_description(path, method)
|
operation['description'] = self.get_description(path, method)
|
||||||
|
operation['parameters'] = sorted(
|
||||||
|
[
|
||||||
|
*self._get_path_parameters(path, method),
|
||||||
|
*self._get_filter_parameters(path, method),
|
||||||
|
*self._get_pagination_parameters(path, method),
|
||||||
|
*self.get_extra_parameters(path, method),
|
||||||
|
],
|
||||||
|
key=lambda p: p.get('name')
|
||||||
|
)
|
||||||
|
|
||||||
parameters = []
|
tags = self.get_tags(path, method)
|
||||||
parameters += self._get_path_parameters(path, method)
|
if tags:
|
||||||
parameters += self._get_pagination_parameters(path, method)
|
operation['tags'] = tags
|
||||||
parameters += self._get_filter_parameters(path, method)
|
|
||||||
operation['parameters'] = parameters
|
|
||||||
|
|
||||||
request_body = self._get_request_body(path, method)
|
request_body = self._get_request_body(path, method)
|
||||||
if request_body:
|
if request_body:
|
||||||
operation['requestBody'] = request_body
|
operation['requestBody'] = request_body
|
||||||
operation['responses'] = self._get_responses(path, method)
|
|
||||||
|
auth = self.get_auth(path, method)
|
||||||
|
if auth:
|
||||||
|
operation['security'] = auth
|
||||||
|
|
||||||
|
self.response_media_types = self.map_renderers(path, method)
|
||||||
|
|
||||||
|
operation['responses'] = self._get_response_bodies(path, method)
|
||||||
|
|
||||||
return operation
|
return operation
|
||||||
|
|
||||||
|
def get_extra_parameters(self, path, method):
|
||||||
|
""" override this for custom behaviour """
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_description(self, path, method):
|
||||||
|
""" override this for custom behaviour """
|
||||||
|
action_or_method = getattr(self.view, getattr(self.view, 'action', method.lower()), None)
|
||||||
|
view_doc = inspect.getdoc(self.view) or ''
|
||||||
|
action_doc = inspect.getdoc(action_or_method) or ''
|
||||||
|
return view_doc + '\n\n' + action_doc if action_doc else view_doc
|
||||||
|
|
||||||
|
def get_auth(self, path, method):
|
||||||
|
""" override this for custom behaviour """
|
||||||
|
auth = []
|
||||||
|
if hasattr(self.view, 'authentication_classes'):
|
||||||
|
auth = [
|
||||||
|
self.resolve_authentication(method, ac) for ac in self.view.authentication_classes
|
||||||
|
]
|
||||||
|
if hasattr(self.view, 'permission_classes'):
|
||||||
|
perms = self.view.permission_classes
|
||||||
|
if permissions.AllowAny in perms:
|
||||||
|
auth.append({})
|
||||||
|
elif permissions.IsAuthenticatedOrReadOnly in perms and method not in ('PUT', 'PATCH', 'POST'):
|
||||||
|
auth.append({})
|
||||||
|
return auth
|
||||||
|
|
||||||
|
def get_request_serializer(self, path, method):
|
||||||
|
""" override this for custom behaviour """
|
||||||
|
return self._get_serializer(path, method)
|
||||||
|
|
||||||
|
def get_response_serializers(self, path, method):
|
||||||
|
""" override this for custom behaviour """
|
||||||
|
return self._get_serializer(path, method)
|
||||||
|
|
||||||
|
def get_tags(self, path, method):
|
||||||
|
""" override this for custom behaviour """
|
||||||
|
path = re.sub(
|
||||||
|
pattern=api_settings.SCHEMA_PATH_PREFIX,
|
||||||
|
repl='',
|
||||||
|
string=path,
|
||||||
|
flags=re.IGNORECASE
|
||||||
|
).split('/')
|
||||||
|
return [path[0]]
|
||||||
|
|
||||||
def _get_operation_id(self, path, method):
|
def _get_operation_id(self, path, method):
|
||||||
"""
|
"""
|
||||||
Compute an operation ID from the model, serializer or view name.
|
Compute an operation ID from the model, serializer or view name.
|
||||||
"""
|
"""
|
||||||
method_name = getattr(self.view, 'action', method.lower())
|
# remove path prefix
|
||||||
|
sub_path = re.sub(
|
||||||
|
pattern=api_settings.SCHEMA_PATH_PREFIX,
|
||||||
|
repl='',
|
||||||
|
string=path,
|
||||||
|
flags=re.IGNORECASE
|
||||||
|
)
|
||||||
|
# cleanup, normalize and tokenize remaining parts.
|
||||||
|
# replace dashes as they can be problematic later in code generation
|
||||||
|
sub_path = sub_path.replace('-', '_').rstrip('/').lstrip('/')
|
||||||
|
sub_path = sub_path.split('/') if sub_path else []
|
||||||
|
# remove path variables
|
||||||
|
sub_path = [p for p in sub_path if not p.startswith('{')]
|
||||||
|
|
||||||
if is_list_view(path, method, self.view):
|
if is_list_view(path, method, self.view):
|
||||||
action = 'list'
|
action = 'list'
|
||||||
elif method_name not in self.method_mapping:
|
|
||||||
action = method_name
|
|
||||||
else:
|
else:
|
||||||
action = self.method_mapping[method.lower()]
|
action = self.method_mapping[method.lower()]
|
||||||
|
|
||||||
# Try to deduce the ID from the view's model
|
return '_'.join(sub_path + [action])
|
||||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
|
||||||
if model is not None:
|
|
||||||
name = model.__name__
|
|
||||||
|
|
||||||
# Try with the serializer class name
|
|
||||||
elif hasattr(self.view, 'get_serializer_class'):
|
|
||||||
name = self.view.get_serializer_class().__name__
|
|
||||||
if name.endswith('Serializer'):
|
|
||||||
name = name[:-10]
|
|
||||||
|
|
||||||
# Fallback to the view name
|
|
||||||
else:
|
|
||||||
name = self.view.__class__.__name__
|
|
||||||
if name.endswith('APIView'):
|
|
||||||
name = name[:-7]
|
|
||||||
elif name.endswith('View'):
|
|
||||||
name = name[:-4]
|
|
||||||
|
|
||||||
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
|
|
||||||
# comes at the end of the name
|
|
||||||
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
|
|
||||||
name = name[:-len(action)]
|
|
||||||
|
|
||||||
if action == 'list' and not name.endswith('s'): # listThings instead of listThing
|
|
||||||
name += 's'
|
|
||||||
|
|
||||||
return action + name
|
|
||||||
|
|
||||||
def _get_path_parameters(self, path, method):
|
def _get_path_parameters(self, path, method):
|
||||||
"""
|
"""
|
||||||
|
@ -160,6 +251,8 @@ class AutoSchema(ViewInspector):
|
||||||
|
|
||||||
for variable in uritemplate.variables(path):
|
for variable in uritemplate.variables(path):
|
||||||
description = ''
|
description = ''
|
||||||
|
schema = TYPE_MAPPING[str]
|
||||||
|
|
||||||
if model is not None: # TODO: test this.
|
if model is not None: # TODO: test this.
|
||||||
# Attempt to infer a field description if possible.
|
# Attempt to infer a field description if possible.
|
||||||
try:
|
try:
|
||||||
|
@ -172,14 +265,16 @@ class AutoSchema(ViewInspector):
|
||||||
elif model_field is not None and model_field.primary_key:
|
elif model_field is not None and model_field.primary_key:
|
||||||
description = get_pk_description(model, model_field)
|
description = get_pk_description(model, model_field)
|
||||||
|
|
||||||
|
# TODO cover more cases
|
||||||
|
if isinstance(model_field, models.UUIDField):
|
||||||
|
schema = TYPE_MAPPING[UUID]
|
||||||
|
|
||||||
parameter = {
|
parameter = {
|
||||||
"name": variable,
|
"name": variable,
|
||||||
"in": "path",
|
"in": "path",
|
||||||
"required": True,
|
"required": True,
|
||||||
"description": description,
|
"description": description,
|
||||||
'schema': {
|
'schema': schema,
|
||||||
'type': 'string', # TODO: integer, pattern, ...
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
parameters.append(parameter)
|
parameters.append(parameter)
|
||||||
|
|
||||||
|
@ -218,16 +313,15 @@ class AutoSchema(ViewInspector):
|
||||||
|
|
||||||
return paginator.get_schema_operation_parameters(view)
|
return paginator.get_schema_operation_parameters(view)
|
||||||
|
|
||||||
def _map_field(self, field):
|
def _map_field(self, method, field):
|
||||||
|
|
||||||
# Nested Serializers, `many` or not.
|
# Nested Serializers, `many` or not.
|
||||||
if isinstance(field, serializers.ListSerializer):
|
if isinstance(field, serializers.ListSerializer):
|
||||||
return {
|
return {
|
||||||
'type': 'array',
|
'type': 'array',
|
||||||
'items': self._map_serializer(field.child)
|
'items': self.resolve_serializer(method, field.child)
|
||||||
}
|
}
|
||||||
if isinstance(field, serializers.Serializer):
|
if isinstance(field, serializers.Serializer):
|
||||||
data = self._map_serializer(field)
|
data = self.resolve_serializer(method, field, nested=True)
|
||||||
data['type'] = 'object'
|
data['type'] = 'object'
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -261,7 +355,6 @@ class AutoSchema(ViewInspector):
|
||||||
'enum': list(field.choices),
|
'enum': list(field.choices),
|
||||||
}
|
}
|
||||||
|
|
||||||
# ListField.
|
|
||||||
if isinstance(field, serializers.ListField):
|
if isinstance(field, serializers.ListField):
|
||||||
mapping = {
|
mapping = {
|
||||||
'type': 'array',
|
'type': 'array',
|
||||||
|
@ -355,6 +448,10 @@ class AutoSchema(ViewInspector):
|
||||||
'format': 'binary'
|
'format': 'binary'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isinstance(field, serializers.SerializerMethodField):
|
||||||
|
method = getattr(field.parent, field.method_name)
|
||||||
|
return self._map_type_hint(method)
|
||||||
|
|
||||||
# Simplest cases, default to 'string' type:
|
# Simplest cases, default to 'string' type:
|
||||||
FIELD_CLASS_SCHEMA_TYPE = {
|
FIELD_CLASS_SCHEMA_TYPE = {
|
||||||
serializers.BooleanField: 'boolean',
|
serializers.BooleanField: 'boolean',
|
||||||
|
@ -370,7 +467,7 @@ class AutoSchema(ViewInspector):
|
||||||
if field.min_value:
|
if field.min_value:
|
||||||
content['minimum'] = field.min_value
|
content['minimum'] = field.min_value
|
||||||
|
|
||||||
def _map_serializer(self, serializer):
|
def _map_serializer(self, method, serializer, nested=False):
|
||||||
# Assuming we have a valid serializer instance.
|
# Assuming we have a valid serializer instance.
|
||||||
# TODO:
|
# TODO:
|
||||||
# - field is Nested or List serializer.
|
# - field is Nested or List serializer.
|
||||||
|
@ -386,7 +483,7 @@ class AutoSchema(ViewInspector):
|
||||||
if field.required:
|
if field.required:
|
||||||
required.append(field.field_name)
|
required.append(field.field_name)
|
||||||
|
|
||||||
schema = self._map_field(field)
|
schema = self._map_field(method, field)
|
||||||
if field.read_only:
|
if field.read_only:
|
||||||
schema['readOnly'] = True
|
schema['readOnly'] = True
|
||||||
if field.write_only:
|
if field.write_only:
|
||||||
|
@ -404,15 +501,12 @@ class AutoSchema(ViewInspector):
|
||||||
result = {
|
result = {
|
||||||
'properties': properties
|
'properties': properties
|
||||||
}
|
}
|
||||||
if required:
|
if required and method != 'PATCH' and not nested:
|
||||||
result['required'] = required
|
result['required'] = required
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _map_field_validators(self, field, schema):
|
def _map_field_validators(self, field, schema):
|
||||||
"""
|
|
||||||
map field validators
|
|
||||||
"""
|
|
||||||
for v in field.validators:
|
for v in field.validators:
|
||||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||||
|
@ -446,6 +540,30 @@ class AutoSchema(ViewInspector):
|
||||||
schema['maximum'] = int(digits * '9') + 1
|
schema['maximum'] = int(digits * '9') + 1
|
||||||
schema['minimum'] = -schema['maximum']
|
schema['minimum'] = -schema['maximum']
|
||||||
|
|
||||||
|
def _map_type_hint(self, method, hint=None):
|
||||||
|
if not hint:
|
||||||
|
hint = typing.get_type_hints(method).get('return')
|
||||||
|
|
||||||
|
if hint in TYPE_MAPPING:
|
||||||
|
return TYPE_MAPPING[hint]
|
||||||
|
elif hint.__origin__ is typing.Union:
|
||||||
|
sub_hints = [
|
||||||
|
self._map_type_hint(method, sub_hint)
|
||||||
|
for sub_hint in hint.__args__ if sub_hint is not type(None) # noqa
|
||||||
|
]
|
||||||
|
if type(None) in hint.__args__ and len(sub_hints) == 1:
|
||||||
|
return {**sub_hints[0], 'nullable': True}
|
||||||
|
elif type(None) in hint.__args__:
|
||||||
|
return {'oneOf': [{**sub_hint, 'nullable': True} for sub_hint in sub_hints]}
|
||||||
|
else:
|
||||||
|
return {'oneOf': sub_hints}
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
'type hint for SerializerMethodField function "{}" is unknown. '
|
||||||
|
'defaulting to string.'.format(method.__name__)
|
||||||
|
)
|
||||||
|
return {'type': 'string'}
|
||||||
|
|
||||||
def _get_paginator(self):
|
def _get_paginator(self):
|
||||||
pagination_class = getattr(self.view, 'pagination_class', None)
|
pagination_class = getattr(self.view, 'pagination_class', None)
|
||||||
if pagination_class:
|
if pagination_class:
|
||||||
|
@ -473,82 +591,195 @@ class AutoSchema(ViewInspector):
|
||||||
try:
|
try:
|
||||||
return view.get_serializer()
|
return view.get_serializer()
|
||||||
except exceptions.APIException:
|
except exceptions.APIException:
|
||||||
warnings.warn('{}.get_serializer() raised an exception during '
|
warnings.warn(
|
||||||
'schema generation. Serializer fields will not be '
|
'{}.get_serializer() raised an exception during '
|
||||||
'generated for {} {}.'
|
'schema generation. Serializer fields will not be '
|
||||||
.format(view.__class__.__name__, method, path))
|
'generated for {} {}.'.format(view.__class__.__name__, method, path)
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_request_body(self, path, method):
|
def _get_request_body(self, path, method):
|
||||||
if method not in ('PUT', 'PATCH', 'POST'):
|
if method not in ('PUT', 'PATCH', 'POST'):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
self.request_media_types = self.map_parsers(path, method)
|
request_media_types = self.map_parsers(path, method)
|
||||||
|
|
||||||
serializer = self._get_serializer(path, method)
|
serializer = self._get_serializer(path, method)
|
||||||
|
|
||||||
if not isinstance(serializer, serializers.Serializer):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
content = self._map_serializer(serializer)
|
|
||||||
# No required fields for PATCH
|
|
||||||
if method == 'PATCH':
|
|
||||||
content.pop('required', None)
|
|
||||||
# No read_only fields for request.
|
|
||||||
for name, schema in content['properties'].copy().items():
|
|
||||||
if 'readOnly' in schema:
|
|
||||||
del content['properties'][name]
|
|
||||||
|
|
||||||
return {
|
|
||||||
'content': {
|
|
||||||
ct: {'schema': content}
|
|
||||||
for ct in self.request_media_types
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_responses(self, path, method):
|
|
||||||
# TODO: Handle multiple codes and pagination classes.
|
|
||||||
if method == 'DELETE':
|
|
||||||
return {
|
|
||||||
'204': {
|
|
||||||
'description': ''
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.response_media_types = self.map_renderers(path, method)
|
|
||||||
|
|
||||||
item_schema = {}
|
|
||||||
serializer = self._get_serializer(path, method)
|
|
||||||
|
|
||||||
if isinstance(serializer, serializers.Serializer):
|
if isinstance(serializer, serializers.Serializer):
|
||||||
item_schema = self._map_serializer(serializer)
|
schema = self.resolve_serializer(method, serializer)
|
||||||
# No write_only fields for response.
|
else:
|
||||||
for name, schema in item_schema['properties'].copy().items():
|
warnings.warn(
|
||||||
if 'writeOnly' in schema:
|
'could not resolve request body for {} {}. defaulting to generic '
|
||||||
del item_schema['properties'][name]
|
'free-form object. (maybe annotate a Serializer class?)'.format(method, path)
|
||||||
if 'required' in item_schema:
|
)
|
||||||
item_schema['required'] = [f for f in item_schema['required'] if f != name]
|
schema = {
|
||||||
|
'type': 'object',
|
||||||
|
'additionalProperties': {}, # https://github.com/swagger-api/swagger-codegen/issues/1318
|
||||||
|
'description': 'Unspecified request body',
|
||||||
|
}
|
||||||
|
|
||||||
|
# serializer has no fields so skip content enumeration
|
||||||
|
if not schema:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
'content': {mt: {'schema': schema} for mt in request_media_types}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_response_bodies(self, path, method):
|
||||||
|
response_serializers = self.get_response_serializers(path, method)
|
||||||
|
|
||||||
|
if isinstance(response_serializers, serializers.Serializer) or isinstance(response_serializers, PolymorphicResponse):
|
||||||
|
if method == 'DELETE':
|
||||||
|
return {'204': {'description': 'No response body'}}
|
||||||
|
return {'200': self._get_response_for_code(path, method, response_serializers)}
|
||||||
|
elif isinstance(response_serializers, dict):
|
||||||
|
# custom handling for overriding default return codes with @extend_schema
|
||||||
|
return {
|
||||||
|
code: self._get_response_for_code(path, method, serializer)
|
||||||
|
for code, serializer in response_serializers.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
'could not resolve response for {} {}. defaulting '
|
||||||
|
'to generic free-form object.'.format(method, path)
|
||||||
|
)
|
||||||
|
schema = {
|
||||||
|
'type': 'object',
|
||||||
|
'description': 'Unspecified response body',
|
||||||
|
}
|
||||||
|
return {'200': self._get_response_for_code(path, method, schema)}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_response_for_code(self, path, method, serializer_instance):
|
||||||
|
if not serializer_instance:
|
||||||
|
return {'description': 'No response body'}
|
||||||
|
elif isinstance(serializer_instance, serializers.Serializer):
|
||||||
|
schema = self.resolve_serializer(method, serializer_instance)
|
||||||
|
if not schema:
|
||||||
|
return {'description': 'No response body'}
|
||||||
|
elif isinstance(serializer_instance, PolymorphicResponse):
|
||||||
|
# custom handling for @extend_schema's injection of polymorphic responses
|
||||||
|
schemas = []
|
||||||
|
|
||||||
|
for serializer in serializer_instance.serializers:
|
||||||
|
assert isinstance(serializer, serializers.Serializer)
|
||||||
|
schema_option = self.resolve_serializer(method, serializer)
|
||||||
|
if schema_option:
|
||||||
|
schemas.append(schema_option)
|
||||||
|
|
||||||
|
schema = {
|
||||||
|
'oneOf': schemas,
|
||||||
|
'discriminator': {
|
||||||
|
'propertyName': serializer_instance.resource_type_field_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
elif isinstance(serializer_instance, dict):
|
||||||
|
# bypass processing and use given schema directly
|
||||||
|
schema = serializer_instance
|
||||||
|
else:
|
||||||
|
raise ValueError('Serializer type unsupported')
|
||||||
|
|
||||||
if is_list_view(path, method, self.view):
|
if is_list_view(path, method, self.view):
|
||||||
response_schema = {
|
schema = {
|
||||||
'type': 'array',
|
'type': 'array',
|
||||||
'items': item_schema,
|
'items': schema,
|
||||||
}
|
}
|
||||||
paginator = self._get_paginator()
|
paginator = self._get_paginator()
|
||||||
if paginator:
|
if paginator:
|
||||||
response_schema = paginator.get_paginated_response_schema(response_schema)
|
schema = paginator.get_paginated_response_schema(schema)
|
||||||
else:
|
|
||||||
response_schema = item_schema
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'200': {
|
'content': {
|
||||||
'content': {
|
mt: {'schema': schema} for mt in self.response_media_types
|
||||||
ct: {'schema': response_schema}
|
},
|
||||||
for ct in self.response_media_types
|
# Description is required by spec, but descriptions for each response code don't really
|
||||||
},
|
# fit into our model. Description is therefore put into the higher level slots.
|
||||||
# description is a mandatory property,
|
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
|
||||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
|
'description': ''
|
||||||
# TODO: put something meaningful into it
|
}
|
||||||
'description': ""
|
|
||||||
|
def _get_serializer_name(self, method, serializer, nested):
|
||||||
|
name = serializer.__class__.__name__
|
||||||
|
|
||||||
|
if name.endswith('Serializer'):
|
||||||
|
name = name[:-10]
|
||||||
|
if method == 'PATCH' and not nested:
|
||||||
|
name = 'Patched' + name
|
||||||
|
|
||||||
|
return name
|
||||||
|
|
||||||
|
def resolve_authentication(self, method, authentication):
|
||||||
|
if authentication not in AUTHENTICATION_SCHEMES:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
auth_scheme = AUTHENTICATION_SCHEMES.get(authentication)
|
||||||
|
|
||||||
|
if not auth_scheme:
|
||||||
|
raise ValueError('no auth scheme registered for {}'.format(authentication.__name__))
|
||||||
|
|
||||||
|
if auth_scheme.name not in self.registry.security_schemes:
|
||||||
|
self.registry.security_schemes[auth_scheme.name] = auth_scheme.schema
|
||||||
|
|
||||||
|
return {auth_scheme.name: []}
|
||||||
|
|
||||||
|
def resolve_serializer(self, method, serializer, nested=False):
|
||||||
|
name = self._get_serializer_name(method, serializer, nested)
|
||||||
|
|
||||||
|
if name not in self.registry.schemas:
|
||||||
|
# add placeholder to prevent recursion loop
|
||||||
|
self.registry.schemas[name] = None
|
||||||
|
|
||||||
|
mapped = self._map_serializer(method, serializer, nested)
|
||||||
|
# empty serializer - usually a transactional serializer.
|
||||||
|
# no need to put it explicitly in the spec
|
||||||
|
if not mapped['properties']:
|
||||||
|
del self.registry.schemas[name]
|
||||||
|
return {}
|
||||||
|
else:
|
||||||
|
self.registry.schemas[name] = mapped
|
||||||
|
|
||||||
|
return {'$ref': '#/components/schemas/{}'.format(name)}
|
||||||
|
|
||||||
|
|
||||||
|
class PolymorphicAutoSchema(AutoSchema):
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
def resolve_serializer(self, method, serializer, nested=False):
|
||||||
|
try:
|
||||||
|
from rest_polymorphic.serializers import PolymorphicSerializer
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn('rest_polymorphic package required for PolymorphicAutoSchema')
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(serializer, PolymorphicSerializer):
|
||||||
|
return self._resolve_polymorphic_serializer(method, serializer, nested)
|
||||||
|
else:
|
||||||
|
return super().resolve_serializer(method, serializer, nested)
|
||||||
|
|
||||||
|
def _resolve_polymorphic_serializer(self, method, serializer, nested):
|
||||||
|
polymorphic_names = []
|
||||||
|
|
||||||
|
for poly_model, poly_serializer in serializer.model_serializer_mapping.items():
|
||||||
|
name = self._get_serializer_name(method, poly_serializer, nested)
|
||||||
|
|
||||||
|
if name not in self.registry.schemas:
|
||||||
|
# add placeholder to prevent recursion loop
|
||||||
|
self.registry.schemas[name] = None
|
||||||
|
# append the type field to serializer fields
|
||||||
|
mapped = self._map_serializer(method, poly_serializer, nested)
|
||||||
|
mapped['properties'][serializer.resource_type_field_name] = {'type': 'string'}
|
||||||
|
self.registry.schemas[name] = mapped
|
||||||
|
|
||||||
|
polymorphic_names.append(name)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'oneOf': [
|
||||||
|
{'$ref': '#/components/schemas/{}'.format(name)} for name in polymorphic_names
|
||||||
|
],
|
||||||
|
'discriminator': {
|
||||||
|
'propertyName': serializer.resource_type_field_name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
165
rest_framework/schemas/openapi_utils.py
Normal file
165
rest_framework/schemas/openapi_utils.py
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
import inspect
|
||||||
|
import warnings
|
||||||
|
from decimal import Decimal
|
||||||
|
from uuid import UUID
|
||||||
|
from datetime import datetime, date
|
||||||
|
|
||||||
|
from rest_framework import authentication
|
||||||
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
VALID_TYPES = ['integer', 'number', 'string', 'boolean']
|
||||||
|
|
||||||
|
TYPE_MAPPING = {
|
||||||
|
float: {'type': 'number', 'format': 'float'},
|
||||||
|
bool: {'type': 'boolean'},
|
||||||
|
str: {'type': 'string'},
|
||||||
|
bytes: {'type': 'string', 'format': 'binary'}, # or byte?
|
||||||
|
int: {'type': 'integer'},
|
||||||
|
UUID: {'type': 'string', 'format': 'uuid'},
|
||||||
|
Decimal: {'type': 'number', 'format': 'double'},
|
||||||
|
datetime: {'type': 'string', 'format': 'date-time'},
|
||||||
|
date: {'type': 'string', 'format': 'date'},
|
||||||
|
None: {},
|
||||||
|
type(None): {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenApiAuthenticationScheme:
|
||||||
|
authentication_class = None
|
||||||
|
name = None
|
||||||
|
schema = None
|
||||||
|
|
||||||
|
|
||||||
|
class SessionAuthenticationScheme(OpenApiAuthenticationScheme):
|
||||||
|
authentication_class = authentication.SessionAuthentication
|
||||||
|
name = 'cookieAuth'
|
||||||
|
schema = {
|
||||||
|
'type': 'apiKey',
|
||||||
|
'in': 'cookie',
|
||||||
|
'name': 'Session',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BasicAuthenticationScheme(OpenApiAuthenticationScheme):
|
||||||
|
authentication_class = authentication.BasicAuthentication
|
||||||
|
name = 'basicAuth'
|
||||||
|
schema = {
|
||||||
|
'type': 'http',
|
||||||
|
'scheme': 'basic',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TokenAuthenticationScheme(OpenApiAuthenticationScheme):
|
||||||
|
authentication_class = authentication.TokenAuthentication
|
||||||
|
name = 'tokenAuth'
|
||||||
|
schema = {
|
||||||
|
'type': 'http',
|
||||||
|
'scheme': 'bearer',
|
||||||
|
'bearerFormat': 'Token',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PolymorphicResponse:
|
||||||
|
def __init__(self, serializers, resource_type_field_name):
|
||||||
|
self.serializers = serializers
|
||||||
|
self.resource_type_field_name = resource_type_field_name
|
||||||
|
|
||||||
|
|
||||||
|
class OpenApiSchemaBase:
|
||||||
|
""" reusable base class for objects that can be translated to a schema """
|
||||||
|
def to_schema(self):
|
||||||
|
raise NotImplementedError('translation to schema required.')
|
||||||
|
|
||||||
|
|
||||||
|
class QueryParameter(OpenApiSchemaBase):
|
||||||
|
def __init__(self, name, description='', required=False, type=str):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.required = required
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
def to_schema(self):
|
||||||
|
if self.type not in TYPE_MAPPING:
|
||||||
|
warnings.warn('{} not a mappable type'.format(self.type))
|
||||||
|
return {
|
||||||
|
'name': self.name,
|
||||||
|
'in': 'query',
|
||||||
|
'description': self.description,
|
||||||
|
'required': self.required,
|
||||||
|
'schema': TYPE_MAPPING.get(self.type)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extend_schema(
|
||||||
|
operation=None,
|
||||||
|
extra_parameters=None,
|
||||||
|
responses=None,
|
||||||
|
request=None,
|
||||||
|
auth=None,
|
||||||
|
description=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
TODO some heavy explaining
|
||||||
|
|
||||||
|
:param operation:
|
||||||
|
:param extra_parameters:
|
||||||
|
:param responses:
|
||||||
|
:param request:
|
||||||
|
:param auth:
|
||||||
|
:param description:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(f):
|
||||||
|
class ExtendedSchema(api_settings.DEFAULT_SCHEMA_CLASS):
|
||||||
|
def get_operation(self, path, method):
|
||||||
|
if operation:
|
||||||
|
return operation
|
||||||
|
return super().get_operation(path, method)
|
||||||
|
|
||||||
|
def get_extra_parameters(self, path, method):
|
||||||
|
if extra_parameters:
|
||||||
|
return [
|
||||||
|
p.to_schema() if isinstance(p, OpenApiSchemaBase) else p for p in extra_parameters
|
||||||
|
]
|
||||||
|
return super().get_extra_parameters(path, method)
|
||||||
|
|
||||||
|
def get_auth(self, path, method):
|
||||||
|
if auth:
|
||||||
|
return auth
|
||||||
|
return super().get_auth(path, method)
|
||||||
|
|
||||||
|
def get_request_serializer(self, path, method):
|
||||||
|
if request:
|
||||||
|
return request
|
||||||
|
return super().get_request_serializer(path, method)
|
||||||
|
|
||||||
|
def get_response_serializers(self, path, method):
|
||||||
|
if responses:
|
||||||
|
return responses
|
||||||
|
return super().get_response_serializers(path, method)
|
||||||
|
|
||||||
|
def get_description(self, path, method):
|
||||||
|
if description:
|
||||||
|
return description
|
||||||
|
return super().get_description(path, method)
|
||||||
|
|
||||||
|
if inspect.isclass(f):
|
||||||
|
class ExtendedView(f):
|
||||||
|
schema = ExtendedSchema()
|
||||||
|
|
||||||
|
return ExtendedView
|
||||||
|
elif callable(f):
|
||||||
|
# custom actions have kwargs in their context, others don't. create it so our create_view
|
||||||
|
# implementation can overwrite the default schema
|
||||||
|
if not hasattr(f, 'kwargs'):
|
||||||
|
f.kwargs = {}
|
||||||
|
|
||||||
|
# this simulates what @action is actually doing. somewhere along the line in this process
|
||||||
|
# the schema is picked up from kwargs and used. it's involved my dear friends.
|
||||||
|
f.kwargs['schema'] = ExtendedSchema()
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
|
@ -53,6 +53,12 @@ DEFAULTS = {
|
||||||
|
|
||||||
# Schema
|
# Schema
|
||||||
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
|
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
|
||||||
|
'SCHEMA_PATH_PREFIX': r'^\/api\/(?:v[0-9\.\_\-]+\/)?',
|
||||||
|
'SCHEMA_AUTHENTICATION_CLASSES': [
|
||||||
|
'rest_framework.schemas.openapi_utils.SessionAuthenticationScheme',
|
||||||
|
'rest_framework.schemas.openapi_utils.BasicAuthenticationScheme',
|
||||||
|
'rest_framework.schemas.openapi_utils.TokenAuthenticationScheme',
|
||||||
|
],
|
||||||
|
|
||||||
# Throttling
|
# Throttling
|
||||||
'DEFAULT_THROTTLE_RATES': {
|
'DEFAULT_THROTTLE_RATES': {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user