mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-28 17:09:59 +03:00
refactor/extend/improve OpenApi3 spec generation
This commit is contained in:
parent
b135e0fa0a
commit
9486df8d04
|
@ -1,6 +1,10 @@
|
|||
import inspect
|
||||
import re
|
||||
import typing
|
||||
import warnings
|
||||
from operator import attrgetter
|
||||
from urllib.parse import urljoin
|
||||
from uuid import UUID
|
||||
|
||||
from django.core.validators import (
|
||||
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
|
||||
|
@ -8,31 +12,69 @@ from django.core.validators import (
|
|||
)
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
from rest_framework import exceptions, renderers, serializers
|
||||
from rest_framework import exceptions, renderers, serializers, permissions
|
||||
from rest_framework.compat import uritemplate
|
||||
from rest_framework.fields import _UnvalidatedField, empty
|
||||
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.schemas.openapi_utils import TYPE_MAPPING, PolymorphicResponse
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
|
||||
def get_info(self):
|
||||
# Title and version are required by openapi specification 3.x
|
||||
info = {
|
||||
'title': self.title or '',
|
||||
'version': self.version or ''
|
||||
AUTHENTICATION_SCHEMES = {
|
||||
cls.authentication_class: cls
|
||||
for cls in [import_string(cls) for cls in api_settings.SCHEMA_AUTHENTICATION_CLASSES]
|
||||
}
|
||||
|
||||
if self.description is not None:
|
||||
info['description'] = self.description
|
||||
|
||||
return info
|
||||
class ComponentRegistry:
|
||||
def __init__(self):
|
||||
self.schemas = {}
|
||||
self.security_schemes = {}
|
||||
|
||||
def get_paths(self, request=None):
|
||||
def get_components(self):
|
||||
return {
|
||||
'securitySchemes': self.security_schemes,
|
||||
'schemas': self.schemas,
|
||||
}
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.registry = ComponentRegistry()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
customized create_view which is called when all routes are traversed. part of this
|
||||
is instatiating views with default params. in case of custom routes (@action) the
|
||||
custom AutoSchema is injected properly through 'initkwargs' on view. However, when
|
||||
decorating plain views like retrieve, this initialization logic is not running.
|
||||
Therefore forcefully set the schema if @extend_schema decorator was used.
|
||||
"""
|
||||
view = super().create_view(callback, method, request)
|
||||
|
||||
# circumvent import issues by locally importing
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import GenericViewSet, ViewSet
|
||||
|
||||
if isinstance(view, GenericViewSet) or isinstance(view, ViewSet):
|
||||
action = getattr(view, view.action)
|
||||
elif isinstance(view, APIView):
|
||||
action = getattr(view, method.lower())
|
||||
else:
|
||||
raise RuntimeError('not supported subclass. Must inherit from APIView')
|
||||
|
||||
if hasattr(action, 'kwargs') and 'schema' in action.kwargs:
|
||||
# might already be properly set in case of @action but overwrite for all cases
|
||||
view.schema = action.kwargs['schema']
|
||||
|
||||
return view
|
||||
|
||||
def parse(self, request=None):
|
||||
result = {}
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
@ -44,7 +86,10 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
operation = view.schema.get_operation(path, method)
|
||||
# keep reference to schema as every access yields a fresh object (descriptor )
|
||||
schema = view.schema
|
||||
schema.init(self.registry)
|
||||
operation = schema.get_operation(path, method)
|
||||
# Normalise path for any provided mount url.
|
||||
if path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
@ -61,20 +106,21 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
paths = self.get_paths(None if public else request)
|
||||
if not paths:
|
||||
return None
|
||||
|
||||
schema = {
|
||||
'openapi': '3.0.2',
|
||||
'info': self.get_info(),
|
||||
'paths': paths,
|
||||
'servers': [
|
||||
{'url': self.url or 'http://127.0.0.1:8000'},
|
||||
],
|
||||
'info': {
|
||||
'title': self.title or '',
|
||||
'version': self.version or '0.0.0', # fallback to prevent invalid schema
|
||||
'description': self.description or '',
|
||||
},
|
||||
'paths': self.parse(None if public else request),
|
||||
'components': self.registry.get_components(),
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
# View Inspectors
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
|
||||
|
@ -82,72 +128,117 @@ class AutoSchema(ViewInspector):
|
|||
response_media_types = []
|
||||
|
||||
method_mapping = {
|
||||
'get': 'Retrieve',
|
||||
'post': 'Create',
|
||||
'put': 'Update',
|
||||
'patch': 'PartialUpdate',
|
||||
'delete': 'Destroy',
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
def init(self, registry):
|
||||
self.registry = registry
|
||||
|
||||
def get_operation(self, path, method):
|
||||
operation = {}
|
||||
|
||||
operation['operationId'] = self._get_operation_id(path, method)
|
||||
operation['description'] = self.get_description(path, method)
|
||||
operation['parameters'] = sorted(
|
||||
[
|
||||
*self._get_path_parameters(path, method),
|
||||
*self._get_filter_parameters(path, method),
|
||||
*self._get_pagination_parameters(path, method),
|
||||
*self.get_extra_parameters(path, method),
|
||||
],
|
||||
key=lambda p: p.get('name')
|
||||
)
|
||||
|
||||
parameters = []
|
||||
parameters += self._get_path_parameters(path, method)
|
||||
parameters += self._get_pagination_parameters(path, method)
|
||||
parameters += self._get_filter_parameters(path, method)
|
||||
operation['parameters'] = parameters
|
||||
tags = self.get_tags(path, method)
|
||||
if tags:
|
||||
operation['tags'] = tags
|
||||
|
||||
request_body = self._get_request_body(path, method)
|
||||
if request_body:
|
||||
operation['requestBody'] = request_body
|
||||
operation['responses'] = self._get_responses(path, method)
|
||||
|
||||
auth = self.get_auth(path, method)
|
||||
if auth:
|
||||
operation['security'] = auth
|
||||
|
||||
self.response_media_types = self.map_renderers(path, method)
|
||||
|
||||
operation['responses'] = self._get_response_bodies(path, method)
|
||||
|
||||
return operation
|
||||
|
||||
def get_extra_parameters(self, path, method):
|
||||
""" override this for custom behaviour """
|
||||
return []
|
||||
|
||||
def get_description(self, path, method):
|
||||
""" override this for custom behaviour """
|
||||
action_or_method = getattr(self.view, getattr(self.view, 'action', method.lower()), None)
|
||||
view_doc = inspect.getdoc(self.view) or ''
|
||||
action_doc = inspect.getdoc(action_or_method) or ''
|
||||
return view_doc + '\n\n' + action_doc if action_doc else view_doc
|
||||
|
||||
def get_auth(self, path, method):
|
||||
""" override this for custom behaviour """
|
||||
auth = []
|
||||
if hasattr(self.view, 'authentication_classes'):
|
||||
auth = [
|
||||
self.resolve_authentication(method, ac) for ac in self.view.authentication_classes
|
||||
]
|
||||
if hasattr(self.view, 'permission_classes'):
|
||||
perms = self.view.permission_classes
|
||||
if permissions.AllowAny in perms:
|
||||
auth.append({})
|
||||
elif permissions.IsAuthenticatedOrReadOnly in perms and method not in ('PUT', 'PATCH', 'POST'):
|
||||
auth.append({})
|
||||
return auth
|
||||
|
||||
def get_request_serializer(self, path, method):
|
||||
""" override this for custom behaviour """
|
||||
return self._get_serializer(path, method)
|
||||
|
||||
def get_response_serializers(self, path, method):
|
||||
""" override this for custom behaviour """
|
||||
return self._get_serializer(path, method)
|
||||
|
||||
def get_tags(self, path, method):
|
||||
""" override this for custom behaviour """
|
||||
path = re.sub(
|
||||
pattern=api_settings.SCHEMA_PATH_PREFIX,
|
||||
repl='',
|
||||
string=path,
|
||||
flags=re.IGNORECASE
|
||||
).split('/')
|
||||
return [path[0]]
|
||||
|
||||
def _get_operation_id(self, path, method):
|
||||
"""
|
||||
Compute an operation ID from the model, serializer or view name.
|
||||
"""
|
||||
method_name = getattr(self.view, 'action', method.lower())
|
||||
# remove path prefix
|
||||
sub_path = re.sub(
|
||||
pattern=api_settings.SCHEMA_PATH_PREFIX,
|
||||
repl='',
|
||||
string=path,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
# cleanup, normalize and tokenize remaining parts.
|
||||
# replace dashes as they can be problematic later in code generation
|
||||
sub_path = sub_path.replace('-', '_').rstrip('/').lstrip('/')
|
||||
sub_path = sub_path.split('/') if sub_path else []
|
||||
# remove path variables
|
||||
sub_path = [p for p in sub_path if not p.startswith('{')]
|
||||
|
||||
if is_list_view(path, method, self.view):
|
||||
action = 'list'
|
||||
elif method_name not in self.method_mapping:
|
||||
action = method_name
|
||||
else:
|
||||
action = self.method_mapping[method.lower()]
|
||||
|
||||
# Try to deduce the ID from the view's model
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
if model is not None:
|
||||
name = model.__name__
|
||||
|
||||
# Try with the serializer class name
|
||||
elif hasattr(self.view, 'get_serializer_class'):
|
||||
name = self.view.get_serializer_class().__name__
|
||||
if name.endswith('Serializer'):
|
||||
name = name[:-10]
|
||||
|
||||
# Fallback to the view name
|
||||
else:
|
||||
name = self.view.__class__.__name__
|
||||
if name.endswith('APIView'):
|
||||
name = name[:-7]
|
||||
elif name.endswith('View'):
|
||||
name = name[:-4]
|
||||
|
||||
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
|
||||
# comes at the end of the name
|
||||
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
|
||||
name = name[:-len(action)]
|
||||
|
||||
if action == 'list' and not name.endswith('s'): # listThings instead of listThing
|
||||
name += 's'
|
||||
|
||||
return action + name
|
||||
return '_'.join(sub_path + [action])
|
||||
|
||||
def _get_path_parameters(self, path, method):
|
||||
"""
|
||||
|
@ -160,6 +251,8 @@ class AutoSchema(ViewInspector):
|
|||
|
||||
for variable in uritemplate.variables(path):
|
||||
description = ''
|
||||
schema = TYPE_MAPPING[str]
|
||||
|
||||
if model is not None: # TODO: test this.
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
|
@ -172,14 +265,16 @@ class AutoSchema(ViewInspector):
|
|||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
# TODO cover more cases
|
||||
if isinstance(model_field, models.UUIDField):
|
||||
schema = TYPE_MAPPING[UUID]
|
||||
|
||||
parameter = {
|
||||
"name": variable,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"description": description,
|
||||
'schema': {
|
||||
'type': 'string', # TODO: integer, pattern, ...
|
||||
},
|
||||
'schema': schema,
|
||||
}
|
||||
parameters.append(parameter)
|
||||
|
||||
|
@ -218,16 +313,15 @@ class AutoSchema(ViewInspector):
|
|||
|
||||
return paginator.get_schema_operation_parameters(view)
|
||||
|
||||
def _map_field(self, field):
|
||||
|
||||
def _map_field(self, method, field):
|
||||
# Nested Serializers, `many` or not.
|
||||
if isinstance(field, serializers.ListSerializer):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self._map_serializer(field.child)
|
||||
'items': self.resolve_serializer(method, field.child)
|
||||
}
|
||||
if isinstance(field, serializers.Serializer):
|
||||
data = self._map_serializer(field)
|
||||
data = self.resolve_serializer(method, field, nested=True)
|
||||
data['type'] = 'object'
|
||||
return data
|
||||
|
||||
|
@ -261,7 +355,6 @@ class AutoSchema(ViewInspector):
|
|||
'enum': list(field.choices),
|
||||
}
|
||||
|
||||
# ListField.
|
||||
if isinstance(field, serializers.ListField):
|
||||
mapping = {
|
||||
'type': 'array',
|
||||
|
@ -355,6 +448,10 @@ class AutoSchema(ViewInspector):
|
|||
'format': 'binary'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.SerializerMethodField):
|
||||
method = getattr(field.parent, field.method_name)
|
||||
return self._map_type_hint(method)
|
||||
|
||||
# Simplest cases, default to 'string' type:
|
||||
FIELD_CLASS_SCHEMA_TYPE = {
|
||||
serializers.BooleanField: 'boolean',
|
||||
|
@ -370,7 +467,7 @@ class AutoSchema(ViewInspector):
|
|||
if field.min_value:
|
||||
content['minimum'] = field.min_value
|
||||
|
||||
def _map_serializer(self, serializer):
|
||||
def _map_serializer(self, method, serializer, nested=False):
|
||||
# Assuming we have a valid serializer instance.
|
||||
# TODO:
|
||||
# - field is Nested or List serializer.
|
||||
|
@ -386,7 +483,7 @@ class AutoSchema(ViewInspector):
|
|||
if field.required:
|
||||
required.append(field.field_name)
|
||||
|
||||
schema = self._map_field(field)
|
||||
schema = self._map_field(method, field)
|
||||
if field.read_only:
|
||||
schema['readOnly'] = True
|
||||
if field.write_only:
|
||||
|
@ -404,15 +501,12 @@ class AutoSchema(ViewInspector):
|
|||
result = {
|
||||
'properties': properties
|
||||
}
|
||||
if required:
|
||||
if required and method != 'PATCH' and not nested:
|
||||
result['required'] = required
|
||||
|
||||
return result
|
||||
|
||||
def _map_field_validators(self, field, schema):
|
||||
"""
|
||||
map field validators
|
||||
"""
|
||||
for v in field.validators:
|
||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||
|
@ -446,6 +540,30 @@ class AutoSchema(ViewInspector):
|
|||
schema['maximum'] = int(digits * '9') + 1
|
||||
schema['minimum'] = -schema['maximum']
|
||||
|
||||
def _map_type_hint(self, method, hint=None):
|
||||
if not hint:
|
||||
hint = typing.get_type_hints(method).get('return')
|
||||
|
||||
if hint in TYPE_MAPPING:
|
||||
return TYPE_MAPPING[hint]
|
||||
elif hint.__origin__ is typing.Union:
|
||||
sub_hints = [
|
||||
self._map_type_hint(method, sub_hint)
|
||||
for sub_hint in hint.__args__ if sub_hint is not type(None) # noqa
|
||||
]
|
||||
if type(None) in hint.__args__ and len(sub_hints) == 1:
|
||||
return {**sub_hints[0], 'nullable': True}
|
||||
elif type(None) in hint.__args__:
|
||||
return {'oneOf': [{**sub_hint, 'nullable': True} for sub_hint in sub_hints]}
|
||||
else:
|
||||
return {'oneOf': sub_hints}
|
||||
else:
|
||||
warnings.warn(
|
||||
'type hint for SerializerMethodField function "{}" is unknown. '
|
||||
'defaulting to string.'.format(method.__name__)
|
||||
)
|
||||
return {'type': 'string'}
|
||||
|
||||
def _get_paginator(self):
|
||||
pagination_class = getattr(self.view, 'pagination_class', None)
|
||||
if pagination_class:
|
||||
|
@ -473,82 +591,195 @@ class AutoSchema(ViewInspector):
|
|||
try:
|
||||
return view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
warnings.warn(
|
||||
'{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
'generated for {} {}.'.format(view.__class__.__name__, method, path)
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_request_body(self, path, method):
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return {}
|
||||
|
||||
self.request_media_types = self.map_parsers(path, method)
|
||||
request_media_types = self.map_parsers(path, method)
|
||||
|
||||
serializer = self._get_serializer(path, method)
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return {}
|
||||
|
||||
content = self._map_serializer(serializer)
|
||||
# No required fields for PATCH
|
||||
if method == 'PATCH':
|
||||
content.pop('required', None)
|
||||
# No read_only fields for request.
|
||||
for name, schema in content['properties'].copy().items():
|
||||
if 'readOnly' in schema:
|
||||
del content['properties'][name]
|
||||
|
||||
return {
|
||||
'content': {
|
||||
ct: {'schema': content}
|
||||
for ct in self.request_media_types
|
||||
}
|
||||
}
|
||||
|
||||
def _get_responses(self, path, method):
|
||||
# TODO: Handle multiple codes and pagination classes.
|
||||
if method == 'DELETE':
|
||||
return {
|
||||
'204': {
|
||||
'description': ''
|
||||
}
|
||||
}
|
||||
|
||||
self.response_media_types = self.map_renderers(path, method)
|
||||
|
||||
item_schema = {}
|
||||
serializer = self._get_serializer(path, method)
|
||||
|
||||
if isinstance(serializer, serializers.Serializer):
|
||||
item_schema = self._map_serializer(serializer)
|
||||
# No write_only fields for response.
|
||||
for name, schema in item_schema['properties'].copy().items():
|
||||
if 'writeOnly' in schema:
|
||||
del item_schema['properties'][name]
|
||||
if 'required' in item_schema:
|
||||
item_schema['required'] = [f for f in item_schema['required'] if f != name]
|
||||
schema = self.resolve_serializer(method, serializer)
|
||||
else:
|
||||
warnings.warn(
|
||||
'could not resolve request body for {} {}. defaulting to generic '
|
||||
'free-form object. (maybe annotate a Serializer class?)'.format(method, path)
|
||||
)
|
||||
schema = {
|
||||
'type': 'object',
|
||||
'additionalProperties': {}, # https://github.com/swagger-api/swagger-codegen/issues/1318
|
||||
'description': 'Unspecified request body',
|
||||
}
|
||||
|
||||
# serializer has no fields so skip content enumeration
|
||||
if not schema:
|
||||
return {}
|
||||
|
||||
return {
|
||||
'content': {mt: {'schema': schema} for mt in request_media_types}
|
||||
}
|
||||
|
||||
def _get_response_bodies(self, path, method):
|
||||
response_serializers = self.get_response_serializers(path, method)
|
||||
|
||||
if isinstance(response_serializers, serializers.Serializer) or isinstance(response_serializers, PolymorphicResponse):
|
||||
if method == 'DELETE':
|
||||
return {'204': {'description': 'No response body'}}
|
||||
return {'200': self._get_response_for_code(path, method, response_serializers)}
|
||||
elif isinstance(response_serializers, dict):
|
||||
# custom handling for overriding default return codes with @extend_schema
|
||||
return {
|
||||
code: self._get_response_for_code(path, method, serializer)
|
||||
for code, serializer in response_serializers.items()
|
||||
}
|
||||
else:
|
||||
warnings.warn(
|
||||
'could not resolve response for {} {}. defaulting '
|
||||
'to generic free-form object.'.format(method, path)
|
||||
)
|
||||
schema = {
|
||||
'type': 'object',
|
||||
'description': 'Unspecified response body',
|
||||
}
|
||||
return {'200': self._get_response_for_code(path, method, schema)}
|
||||
|
||||
|
||||
def _get_response_for_code(self, path, method, serializer_instance):
|
||||
if not serializer_instance:
|
||||
return {'description': 'No response body'}
|
||||
elif isinstance(serializer_instance, serializers.Serializer):
|
||||
schema = self.resolve_serializer(method, serializer_instance)
|
||||
if not schema:
|
||||
return {'description': 'No response body'}
|
||||
elif isinstance(serializer_instance, PolymorphicResponse):
|
||||
# custom handling for @extend_schema's injection of polymorphic responses
|
||||
schemas = []
|
||||
|
||||
for serializer in serializer_instance.serializers:
|
||||
assert isinstance(serializer, serializers.Serializer)
|
||||
schema_option = self.resolve_serializer(method, serializer)
|
||||
if schema_option:
|
||||
schemas.append(schema_option)
|
||||
|
||||
schema = {
|
||||
'oneOf': schemas,
|
||||
'discriminator': {
|
||||
'propertyName': serializer_instance.resource_type_field_name
|
||||
}
|
||||
}
|
||||
elif isinstance(serializer_instance, dict):
|
||||
# bypass processing and use given schema directly
|
||||
schema = serializer_instance
|
||||
else:
|
||||
raise ValueError('Serializer type unsupported')
|
||||
|
||||
if is_list_view(path, method, self.view):
|
||||
response_schema = {
|
||||
schema = {
|
||||
'type': 'array',
|
||||
'items': item_schema,
|
||||
'items': schema,
|
||||
}
|
||||
paginator = self._get_paginator()
|
||||
if paginator:
|
||||
response_schema = paginator.get_paginated_response_schema(response_schema)
|
||||
else:
|
||||
response_schema = item_schema
|
||||
schema = paginator.get_paginated_response_schema(schema)
|
||||
|
||||
return {
|
||||
'200': {
|
||||
'content': {
|
||||
ct: {'schema': response_schema}
|
||||
for ct in self.response_media_types
|
||||
mt: {'schema': schema} for mt in self.response_media_types
|
||||
},
|
||||
# description is a mandatory property,
|
||||
# Description is required by spec, but descriptions for each response code don't really
|
||||
# fit into our model. Description is therefore put into the higher level slots.
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
|
||||
# TODO: put something meaningful into it
|
||||
'description': ""
|
||||
'description': ''
|
||||
}
|
||||
|
||||
def _get_serializer_name(self, method, serializer, nested):
|
||||
name = serializer.__class__.__name__
|
||||
|
||||
if name.endswith('Serializer'):
|
||||
name = name[:-10]
|
||||
if method == 'PATCH' and not nested:
|
||||
name = 'Patched' + name
|
||||
|
||||
return name
|
||||
|
||||
def resolve_authentication(self, method, authentication):
|
||||
if authentication not in AUTHENTICATION_SCHEMES:
|
||||
raise ValueError()
|
||||
|
||||
auth_scheme = AUTHENTICATION_SCHEMES.get(authentication)
|
||||
|
||||
if not auth_scheme:
|
||||
raise ValueError('no auth scheme registered for {}'.format(authentication.__name__))
|
||||
|
||||
if auth_scheme.name not in self.registry.security_schemes:
|
||||
self.registry.security_schemes[auth_scheme.name] = auth_scheme.schema
|
||||
|
||||
return {auth_scheme.name: []}
|
||||
|
||||
def resolve_serializer(self, method, serializer, nested=False):
|
||||
name = self._get_serializer_name(method, serializer, nested)
|
||||
|
||||
if name not in self.registry.schemas:
|
||||
# add placeholder to prevent recursion loop
|
||||
self.registry.schemas[name] = None
|
||||
|
||||
mapped = self._map_serializer(method, serializer, nested)
|
||||
# empty serializer - usually a transactional serializer.
|
||||
# no need to put it explicitly in the spec
|
||||
if not mapped['properties']:
|
||||
del self.registry.schemas[name]
|
||||
return {}
|
||||
else:
|
||||
self.registry.schemas[name] = mapped
|
||||
|
||||
return {'$ref': '#/components/schemas/{}'.format(name)}
|
||||
|
||||
|
||||
class PolymorphicAutoSchema(AutoSchema):
|
||||
"""
|
||||
|
||||
"""
|
||||
def resolve_serializer(self, method, serializer, nested=False):
|
||||
try:
|
||||
from rest_polymorphic.serializers import PolymorphicSerializer
|
||||
except ImportError:
|
||||
warnings.warn('rest_polymorphic package required for PolymorphicAutoSchema')
|
||||
raise
|
||||
|
||||
if isinstance(serializer, PolymorphicSerializer):
|
||||
return self._resolve_polymorphic_serializer(method, serializer, nested)
|
||||
else:
|
||||
return super().resolve_serializer(method, serializer, nested)
|
||||
|
||||
def _resolve_polymorphic_serializer(self, method, serializer, nested):
|
||||
polymorphic_names = []
|
||||
|
||||
for poly_model, poly_serializer in serializer.model_serializer_mapping.items():
|
||||
name = self._get_serializer_name(method, poly_serializer, nested)
|
||||
|
||||
if name not in self.registry.schemas:
|
||||
# add placeholder to prevent recursion loop
|
||||
self.registry.schemas[name] = None
|
||||
# append the type field to serializer fields
|
||||
mapped = self._map_serializer(method, poly_serializer, nested)
|
||||
mapped['properties'][serializer.resource_type_field_name] = {'type': 'string'}
|
||||
self.registry.schemas[name] = mapped
|
||||
|
||||
polymorphic_names.append(name)
|
||||
|
||||
return {
|
||||
'oneOf': [
|
||||
{'$ref': '#/components/schemas/{}'.format(name)} for name in polymorphic_names
|
||||
],
|
||||
'discriminator': {
|
||||
'propertyName': serializer.resource_type_field_name
|
||||
}
|
||||
}
|
||||
|
|
165
rest_framework/schemas/openapi_utils.py
Normal file
165
rest_framework/schemas/openapi_utils.py
Normal file
|
@ -0,0 +1,165 @@
|
|||
import inspect
|
||||
import warnings
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
from datetime import datetime, date
|
||||
|
||||
from rest_framework import authentication
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
VALID_TYPES = ['integer', 'number', 'string', 'boolean']
|
||||
|
||||
TYPE_MAPPING = {
|
||||
float: {'type': 'number', 'format': 'float'},
|
||||
bool: {'type': 'boolean'},
|
||||
str: {'type': 'string'},
|
||||
bytes: {'type': 'string', 'format': 'binary'}, # or byte?
|
||||
int: {'type': 'integer'},
|
||||
UUID: {'type': 'string', 'format': 'uuid'},
|
||||
Decimal: {'type': 'number', 'format': 'double'},
|
||||
datetime: {'type': 'string', 'format': 'date-time'},
|
||||
date: {'type': 'string', 'format': 'date'},
|
||||
None: {},
|
||||
type(None): {},
|
||||
}
|
||||
|
||||
|
||||
class OpenApiAuthenticationScheme:
|
||||
authentication_class = None
|
||||
name = None
|
||||
schema = None
|
||||
|
||||
|
||||
class SessionAuthenticationScheme(OpenApiAuthenticationScheme):
|
||||
authentication_class = authentication.SessionAuthentication
|
||||
name = 'cookieAuth'
|
||||
schema = {
|
||||
'type': 'apiKey',
|
||||
'in': 'cookie',
|
||||
'name': 'Session',
|
||||
}
|
||||
|
||||
|
||||
class BasicAuthenticationScheme(OpenApiAuthenticationScheme):
|
||||
authentication_class = authentication.BasicAuthentication
|
||||
name = 'basicAuth'
|
||||
schema = {
|
||||
'type': 'http',
|
||||
'scheme': 'basic',
|
||||
}
|
||||
|
||||
|
||||
class TokenAuthenticationScheme(OpenApiAuthenticationScheme):
|
||||
authentication_class = authentication.TokenAuthentication
|
||||
name = 'tokenAuth'
|
||||
schema = {
|
||||
'type': 'http',
|
||||
'scheme': 'bearer',
|
||||
'bearerFormat': 'Token',
|
||||
}
|
||||
|
||||
|
||||
class PolymorphicResponse:
|
||||
def __init__(self, serializers, resource_type_field_name):
|
||||
self.serializers = serializers
|
||||
self.resource_type_field_name = resource_type_field_name
|
||||
|
||||
|
||||
class OpenApiSchemaBase:
|
||||
""" reusable base class for objects that can be translated to a schema """
|
||||
def to_schema(self):
|
||||
raise NotImplementedError('translation to schema required.')
|
||||
|
||||
|
||||
class QueryParameter(OpenApiSchemaBase):
|
||||
def __init__(self, name, description='', required=False, type=str):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.required = required
|
||||
self.type = type
|
||||
|
||||
def to_schema(self):
|
||||
if self.type not in TYPE_MAPPING:
|
||||
warnings.warn('{} not a mappable type'.format(self.type))
|
||||
return {
|
||||
'name': self.name,
|
||||
'in': 'query',
|
||||
'description': self.description,
|
||||
'required': self.required,
|
||||
'schema': TYPE_MAPPING.get(self.type)
|
||||
}
|
||||
|
||||
|
||||
def extend_schema(
|
||||
operation=None,
|
||||
extra_parameters=None,
|
||||
responses=None,
|
||||
request=None,
|
||||
auth=None,
|
||||
description=None,
|
||||
):
|
||||
"""
|
||||
TODO some heavy explaining
|
||||
|
||||
:param operation:
|
||||
:param extra_parameters:
|
||||
:param responses:
|
||||
:param request:
|
||||
:param auth:
|
||||
:param description:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
class ExtendedSchema(api_settings.DEFAULT_SCHEMA_CLASS):
|
||||
def get_operation(self, path, method):
|
||||
if operation:
|
||||
return operation
|
||||
return super().get_operation(path, method)
|
||||
|
||||
def get_extra_parameters(self, path, method):
|
||||
if extra_parameters:
|
||||
return [
|
||||
p.to_schema() if isinstance(p, OpenApiSchemaBase) else p for p in extra_parameters
|
||||
]
|
||||
return super().get_extra_parameters(path, method)
|
||||
|
||||
def get_auth(self, path, method):
|
||||
if auth:
|
||||
return auth
|
||||
return super().get_auth(path, method)
|
||||
|
||||
def get_request_serializer(self, path, method):
|
||||
if request:
|
||||
return request
|
||||
return super().get_request_serializer(path, method)
|
||||
|
||||
def get_response_serializers(self, path, method):
|
||||
if responses:
|
||||
return responses
|
||||
return super().get_response_serializers(path, method)
|
||||
|
||||
def get_description(self, path, method):
|
||||
if description:
|
||||
return description
|
||||
return super().get_description(path, method)
|
||||
|
||||
if inspect.isclass(f):
|
||||
class ExtendedView(f):
|
||||
schema = ExtendedSchema()
|
||||
|
||||
return ExtendedView
|
||||
elif callable(f):
|
||||
# custom actions have kwargs in their context, others don't. create it so our create_view
|
||||
# implementation can overwrite the default schema
|
||||
if not hasattr(f, 'kwargs'):
|
||||
f.kwargs = {}
|
||||
|
||||
# this simulates what @action is actually doing. somewhere along the line in this process
|
||||
# the schema is picked up from kwargs and used. it's involved my dear friends.
|
||||
f.kwargs['schema'] = ExtendedSchema()
|
||||
return f
|
||||
else:
|
||||
return f
|
||||
|
||||
return decorator
|
|
@ -53,6 +53,12 @@ DEFAULTS = {
|
|||
|
||||
# Schema
|
||||
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
|
||||
'SCHEMA_PATH_PREFIX': r'^\/api\/(?:v[0-9\.\_\-]+\/)?',
|
||||
'SCHEMA_AUTHENTICATION_CLASSES': [
|
||||
'rest_framework.schemas.openapi_utils.SessionAuthenticationScheme',
|
||||
'rest_framework.schemas.openapi_utils.BasicAuthenticationScheme',
|
||||
'rest_framework.schemas.openapi_utils.TokenAuthenticationScheme',
|
||||
],
|
||||
|
||||
# Throttling
|
||||
'DEFAULT_THROTTLE_RATES': {
|
||||
|
|
Loading…
Reference in New Issue
Block a user