mirror of
synced 2025-03-09 14:05:47 +03:00
788 lines
29 KiB
788 lines
29 KiB
import re
import warnings
from collections import OrderedDict
from decimal import Decimal
from operator import attrgetter
from urllib.parse import urljoin

from django.core.validators import (
    DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
    MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
)
from django.db import models
from django.utils.encoding import force_str

from rest_framework import (
    RemovedInDRF314Warning, exceptions, renderers, serializers
)
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings

from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
class SchemaGenerator(BaseSchemaGenerator):
    """
    Schema generator that assembles an OpenAPI 3.0.2 document.

    Each endpoint's operation and components are obtained from the
    ``schema`` attribute of its view and merged into a single dictionary.
    """

    def get_info(self):
        """
        Return the OpenAPI ``info`` object.

        ``title`` and ``version`` fall back to an empty string when unset;
        ``description`` is only included when it is not ``None``.
        """
        # Title and version are required by openapi specification 3.x
        info = {
            'title': self.title or '',
            'version': self.version or ''
        }

        if self.description is not None:
            info['description'] = self.description

        return info

    def check_duplicate_operation_id(self, paths):
        """
        Warn about every ``operationId`` that is used by more than one operation.

        :param paths: mapping of ``route -> method -> operation`` as built by
            ``get_schema()``. Operations without an ``operationId`` are ignored.
        """
        ids = {}
        for route in paths:
            for method in paths[route]:
                if 'operationId' not in paths[route][method]:
                    continue
                operation_id = paths[route][method]['operationId']
                if operation_id in ids:
                    # route1/method1 is the previously recorded operation,
                    # route2/method2 the one currently being inspected.
                    warnings.warn(
                        'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n'
                        '\tRoute: {route1}, Method: {method1}\n'
                        '\tRoute: {route2}, Method: {method2}\n'
                        '\tAn operationId has to be unique accros your schema. Your schema may not work in other tools.'
                        .format(
                            route1=ids[operation_id]['route'],
                            method1=ids[operation_id]['method'],
                            route2=route,
                            method2=method,
                            operation_id=operation_id
                        )
                    )
                ids[operation_id] = {
                    'route': route,
                    'method': method
                }

    def get_schema(self, request=None, public=False):
        """
        Generate a OpenAPI schema.

        :param request: request handed to ``_get_paths_and_endpoints()``
            unless ``public`` is true.
        :param public: when true, ``None`` is passed instead of ``request``.
        :return: dictionary with the ``openapi``, ``info`` and ``paths`` keys,
            plus ``components`` when at least one component was collected.
        """
        # NOTE(review): presumably prepares the endpoint list consumed by
        # `_get_paths_and_endpoints()`; it is defined on the base class,
        # outside this file -- confirm.
        self._initialise_endpoints()
        components_schemas = {}

        # Iterate endpoints generating per method path operations.
        paths = {}
        _, view_endpoints = self._get_paths_and_endpoints(None if public else request)
        for path, method, view in view_endpoints:
            if not self.has_view_permissions(path, method, view):
                continue

            operation = view.schema.get_operation(path, method)
            components = view.schema.get_components(path, method)
            for k in components:
                if k not in components_schemas:
                    continue
                if components_schemas[k] == components[k]:
                    continue
                warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k))

            # Later definitions win; a differing redefinition was warned about above.
            components_schemas.update(components)

            # Normalise path for any provided mount url.
            if path.startswith('/'):
                path = path[1:]
            path = urljoin(self.url or '/', path)

            paths.setdefault(path, {})
            paths[path][method.lower()] = operation

        self.check_duplicate_operation_id(paths)

        # Compile final schema.
        schema = {
            'openapi': '3.0.2',
            'info': self.get_info(),
            'paths': paths,
        }

        if components_schemas:
            schema['components'] = {
                'schemas': components_schemas
            }

        return schema
# View Inspectors
class AutoSchema(ViewInspector):
    """
    View inspector that derives an OpenAPI operation and its components
    from the view it is attached to (serializer, paginator, filter
    backends, parsers and renderers).
    """

    def __init__(self, tags=None, operation_id_base=None, component_name=None):
        """
        :param tags: optional list or tuple of strings used as the operation tags.
        :param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name.
        :param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name.
        :raises ValueError: if ``tags`` contains a non-string item.
        """
        if tags and not all(isinstance(tag, str) for tag in tags):
            raise ValueError('tags must be a list or tuple of string.')
        self._tags = tags
        self.operation_id_base = operation_id_base
        self.component_name = component_name
        # Let the base inspector run its own initialisation.
        super().__init__()

    # Class-level defaults; replaced per instance by get_request_body() and
    # get_responses() respectively.
    request_media_types = []
    response_media_types = []

    # HTTP method (lower case) -> action prefix used in operation ids.
    method_mapping = {
        'get': 'retrieve',
        'post': 'create',
        'put': 'update',
        'patch': 'partialUpdate',
        'delete': 'destroy',
    }

    def get_operation(self, path, method):
        """
        Build the OpenAPI operation object for ``method`` on ``path``.

        ``requestBody`` is only included when ``get_request_body()`` returns
        a non-empty value.
        """
        operation = {}

        operation['operationId'] = self.get_operation_id(path, method)
        operation['description'] = self.get_description(path, method)

        parameters = []
        parameters += self.get_path_parameters(path, method)
        parameters += self.get_pagination_parameters(path, method)
        parameters += self.get_filter_parameters(path, method)
        operation['parameters'] = parameters

        request_body = self.get_request_body(path, method)
        if request_body:
            operation['requestBody'] = request_body
        operation['responses'] = self.get_responses(path, method)
        operation['tags'] = self.get_tags(path, method)

        return operation

    def get_component_name(self, serializer):
        """
        Compute the component's name from the serializer.
        Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
        """
        if self.component_name is not None:
            return self.component_name

        # use the serializer's class name as the component name.
        component_name = serializer.__class__.__name__
        # We remove the "serializer" string from the class name.
        pattern = re.compile("serializer", re.IGNORECASE)
        component_name = pattern.sub("", component_name)

        if component_name == "":
            raise Exception(
                '"{}" is an invalid class name for schema generation. '
                'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
                .format(serializer.__class__.__name__)
            )

        return component_name

    def get_components(self, path, method):
        """
        Return components with their properties from the serializer.

        An empty dict is returned for DELETE and when the view does not
        provide a ``Serializer`` instance.
        """
        if method.lower() == 'delete':
            return {}

        serializer = self.get_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            return {}

        component_name = self.get_component_name(serializer)

        content = self.map_serializer(serializer)
        return {component_name: content}

    def _to_camel_case(self, snake_str):
        """Convert ``snake_case`` to ``camelCase`` (the first component is left untouched)."""
        components = snake_str.split('_')
        # We capitalize the first letter of each component except the first one
        # with the 'title' method and join them together.
        return components[0] + ''.join(x.title() for x in components[1:])

    def get_operation_id_base(self, path, method, action):
        """
        Compute the base part for operation ID from the model, serializer or view name.

        Precedence: explicit ``operation_id_base``, then the queryset's model
        name, then the serializer class name, then the view class name.
        """
        model = getattr(getattr(self.view, 'queryset', None), 'model', None)

        if self.operation_id_base is not None:
            name = self.operation_id_base

        # Try to deduce the ID from the view's model
        elif model is not None:
            name = model.__name__

        else:
            # Fetch the serializer once: get_serializer() instantiates it and
            # warns on failure, so calling it twice would repeat both.
            serializer = self.get_serializer(path, method)

            # Try with the serializer class name
            if serializer is not None:
                name = serializer.__class__.__name__
                if name.endswith('Serializer'):
                    name = name[:-10]

            # Fallback to the view name
            else:
                name = self.view.__class__.__name__
                if name.endswith('APIView'):
                    name = name[:-7]
                elif name.endswith('View'):
                    name = name[:-4]

                # Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
                # comes at the end of the name
                if name.endswith(action.title()):  # ListView, UpdateAPIView, ThingDelete ...
                    name = name[:-len(action)]

        if action == 'list' and not name.endswith('s'):  # listThings instead of listThing
            name += 's'

        return name

    def get_operation_id(self, path, method):
        """
        Compute an operation ID from the view type and get_operation_id_base method.
        """
        method_name = getattr(self.view, 'action', method.lower())
        if is_list_view(path, method, self.view):
            action = 'list'
        elif method_name not in self.method_mapping:
            action = self._to_camel_case(method_name)
        else:
            action = self.method_mapping[method.lower()]

        name = self.get_operation_id_base(path, method, action)

        return action + name

    def get_path_parameters(self, path, method):
        """
        Return a list of parameters from templated path variables.
        """
        assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'

        model = getattr(getattr(self.view, 'queryset', None), 'model', None)
        parameters = []

        for variable in uritemplate.variables(path):
            description = ''
            if model is not None:  # TODO: test this.
                # Attempt to infer a field description if possible.
                # Best effort: any lookup failure just leaves the description empty.
                try:
                    model_field = model._meta.get_field(variable)
                except Exception:
                    model_field = None

                if model_field is not None and model_field.help_text:
                    description = force_str(model_field.help_text)
                elif model_field is not None and model_field.primary_key:
                    description = get_pk_description(model, model_field)

            parameter = {
                "name": variable,
                "in": "path",
                "required": True,
                "description": description,
                'schema': {
                    'type': 'string',  # TODO: integer, pattern, ...
                },
            }
            parameters.append(parameter)

        return parameters

    def get_filter_parameters(self, path, method):
        """Return the parameters contributed by the view's filter backends, if filtering applies."""
        if not self.allows_filters(path, method):
            return []
        parameters = []
        for filter_backend in self.view.filter_backends:
            parameters += filter_backend().get_schema_operation_parameters(self.view)
        return parameters

    def allows_filters(self, path, method):
        """
        Determine whether to include filter Fields in schema.

        Default implementation looks for ModelViewSet or GenericAPIView
        actions/methods that cause filtering on the default implementation.
        """
        if getattr(self.view, 'filter_backends', None) is None:
            return False
        if hasattr(self.view, 'action'):
            return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
        return method.lower() in ["get", "put", "patch", "delete"]

    def get_pagination_parameters(self, path, method):
        """Return the paginator's parameters for list views, otherwise an empty list."""
        view = self.view

        if not is_list_view(path, method, view):
            return []

        paginator = self.get_paginator()
        if not paginator:
            return []

        return paginator.get_schema_operation_parameters(view)

    def map_choicefield(self, field):
        """
        Map a choice field to a schema with ``enum`` and, when all choices
        share one JSON type, ``type``.
        """
        choices = list(OrderedDict.fromkeys(field.choices))  # preserve order and remove duplicates
        # bool is tested first because bool is a subclass of int.
        if all(isinstance(choice, bool) for choice in choices):
            schema_type = 'boolean'
        elif all(isinstance(choice, int) for choice in choices):
            schema_type = 'integer'
        elif all(isinstance(choice, (int, float, Decimal)) for choice in choices):  # `number` includes `integer`
            # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
            schema_type = 'number'
        elif all(isinstance(choice, str) for choice in choices):
            schema_type = 'string'
        else:
            schema_type = None

        mapping = {
            # The value of `enum` keyword MUST be an array and SHOULD be unique.
            # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20
            'enum': choices
        }

        # If We figured out `type` then and only then we should set it. It must be a string.
        # Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type
        # It is optional but it can not be null.
        # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
        if schema_type:
            mapping['type'] = schema_type
        return mapping

    def map_field(self, field):
        """
        Map a serializer field to its OpenAPI schema dictionary.

        The checks run from the most specific field classes to the generic
        fallback, which is ``{'type': 'string'}``.
        """
        # Nested Serializers, `many` or not.
        if isinstance(field, serializers.ListSerializer):
            return {
                'type': 'array',
                'items': self.map_serializer(field.child)
            }
        if isinstance(field, serializers.Serializer):
            data = self.map_serializer(field)
            data['type'] = 'object'
            return data

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            return {
                'type': 'array',
                'items': self.map_field(field.child_relation)
            }
        if isinstance(field, serializers.PrimaryKeyRelatedField):
            model = getattr(field.queryset, 'model', None)
            if model is not None:
                model_field = model._meta.pk
                if isinstance(model_field, models.AutoField):
                    return {'type': 'integer'}
            # Otherwise fall through to the generic handling below.

        # ChoiceFields (single and multiple).
        # Q:
        # - Is 'type' required?
        # - can we determine the TYPE of a choicefield?
        if isinstance(field, serializers.MultipleChoiceField):
            return {
                'type': 'array',
                'items': self.map_choicefield(field)
            }

        if isinstance(field, serializers.ChoiceField):
            return self.map_choicefield(field)

        # ListField.
        if isinstance(field, serializers.ListField):
            mapping = {
                'type': 'array',
                'items': {},
            }
            if not isinstance(field.child, _UnvalidatedField):
                mapping['items'] = self.map_field(field.child)
            return mapping

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return {
                'type': 'string',
                'format': 'date',
            }

        if isinstance(field, serializers.DateTimeField):
            return {
                'type': 'string',
                'format': 'date-time',
            }

        # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
        # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
        # see also: https://swagger.io/docs/specification/data-models/data-types/#string
        if isinstance(field, serializers.EmailField):
            return {
                'type': 'string',
                'format': 'email'
            }

        if isinstance(field, serializers.URLField):
            return {
                'type': 'string',
                'format': 'uri'
            }

        if isinstance(field, serializers.UUIDField):
            return {
                'type': 'string',
                'format': 'uuid'
            }

        if isinstance(field, serializers.IPAddressField):
            content = {
                'type': 'string',
            }
            if field.protocol != 'both':
                content['format'] = field.protocol
            return content

        if isinstance(field, serializers.DecimalField):
            if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
                content = {
                    'type': 'string',
                    'format': 'decimal',
                }
            else:
                content = {
                    'type': 'number'
                }

            if field.decimal_places:
                # e.g. decimal_places=2 -> float('.01')
                content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
            if field.max_whole_digits:
                # e.g. max_whole_digits=3 -> 999 + 1
                content['maximum'] = int(field.max_whole_digits * '9') + 1
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.FloatField):
            content = {
                'type': 'number',
            }
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.IntegerField):
            content = {
                'type': 'integer'
            }
            self._map_min_max(field, content)
            # The signed int32 range is [-2147483648, 2147483647]; a bound
            # outside of it (on either side) requires the int64 format.
            bounds = (int(content.get('maximum', 0)), int(content.get('minimum', 0)))
            if any(bound > 2147483647 or bound < -2147483648 for bound in bounds):
                content['format'] = 'int64'
            return content

        if isinstance(field, serializers.FileField):
            return {
                'type': 'string',
                'format': 'binary'
            }

        # Simplest cases, default to 'string' type:
        FIELD_CLASS_SCHEMA_TYPE = {
            serializers.BooleanField: 'boolean',
            serializers.JSONField: 'object',
            serializers.DictField: 'object',
            serializers.HStoreField: 'object',
        }
        return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}

    def _map_min_max(self, field, content):
        """
        Copy ``field.max_value`` / ``field.min_value`` into ``content`` as
        ``maximum`` / ``minimum`` (mutates ``content`` in place).
        """
        # Compare against None rather than truthiness so that a bound of 0
        # is not silently dropped.
        if field.max_value is not None:
            content['maximum'] = field.max_value
        if field.min_value is not None:
            content['minimum'] = field.min_value

    def map_serializer(self, serializer):
        """
        Map a serializer to an ``object`` schema with ``properties`` and,
        when any field is required, ``required``. Hidden fields are skipped.
        """
        # Assuming we have a valid serializer instance.
        required = []
        properties = {}

        for field in serializer.fields.values():
            if isinstance(field, serializers.HiddenField):
                continue

            if field.required:
                required.append(field.field_name)

            schema = self.map_field(field)
            if field.read_only:
                schema['readOnly'] = True
            if field.write_only:
                schema['writeOnly'] = True
            if field.allow_null:
                schema['nullable'] = True
            if field.default is not None and field.default != empty and not callable(field.default):
                schema['default'] = field.default
            if field.help_text:
                schema['description'] = str(field.help_text)
            self.map_field_validators(field, schema)

            properties[field.field_name] = schema

        result = {
            'type': 'object',
            'properties': properties
        }
        if required:
            result['required'] = required

        return result

    def map_field_validators(self, field, schema):
        """
        map field validators

        Translate the validators attached to ``field`` into schema keywords,
        mutating ``schema`` in place.
        """
        for v in field.validators:
            # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
            # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
            if isinstance(v, EmailValidator):
                schema['format'] = 'email'
            if isinstance(v, URLValidator):
                schema['format'] = 'uri'
            if isinstance(v, RegexValidator):
                schema['pattern'] = v.regex.pattern
            elif isinstance(v, MaxLengthValidator):
                attr_name = 'maxLength'
                if isinstance(field, serializers.ListField):
                    attr_name = 'maxItems'
                schema[attr_name] = v.limit_value
            elif isinstance(v, MinLengthValidator):
                attr_name = 'minLength'
                if isinstance(field, serializers.ListField):
                    attr_name = 'minItems'
                schema[attr_name] = v.limit_value
            elif isinstance(v, MaxValueValidator):
                schema['maximum'] = v.limit_value
            elif isinstance(v, MinValueValidator):
                schema['minimum'] = v.limit_value
            elif (isinstance(v, DecimalValidator)
                    and not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING)):
                if v.decimal_places:
                    schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
                if v.max_digits:
                    digits = v.max_digits
                    if v.decimal_places is not None and v.decimal_places > 0:
                        digits -= v.decimal_places
                    # 10 ** digits equals int(digits * '9') + 1 for digits >= 1
                    # and, unlike int(''), also works when no whole digits
                    # remain (max_digits == decimal_places).
                    schema['maximum'] = 10 ** max(digits, 0)
                    schema['minimum'] = -schema['maximum']

    def get_paginator(self):
        """Return an instance of the view's ``pagination_class``, or ``None`` if unset."""
        pagination_class = getattr(self.view, 'pagination_class', None)
        if pagination_class:
            return pagination_class()
        return None

    def map_parsers(self, path, method):
        """Return the media types of the view's parser classes."""
        return list(map(attrgetter('media_type'), self.view.parser_classes))

    def map_renderers(self, path, method):
        """Return the media types of the view's renderer classes, excluding the browsable API."""
        media_types = []
        for renderer in self.view.renderer_classes:
            # BrowsableAPIRenderer not relevant to OpenAPI spec
            if renderer == renderers.BrowsableAPIRenderer:
                continue
            media_types.append(renderer.media_type)
        return media_types

    def get_serializer(self, path, method):
        """
        Return ``view.get_serializer()``.

        Returns ``None`` when the view has no ``get_serializer`` attribute or
        when the call raises ``APIException`` (a warning is emitted then).
        """
        view = self.view

        if not hasattr(view, 'get_serializer'):
            return None

        try:
            return view.get_serializer()
        except exceptions.APIException:
            warnings.warn('{}.get_serializer() raised an exception during '
                          'schema generation. Serializer fields will not be '
                          'generated for {} {}.'
                          .format(view.__class__.__name__, method, path))
            return None

    def _get_reference(self, serializer):
        """Return a ``$ref`` object pointing at the serializer's component schema."""
        return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}

    def get_request_body(self, path, method):
        """
        Return the ``requestBody`` object for PUT, PATCH and POST; ``{}`` otherwise.

        Side effect: stores the parser media types on ``self.request_media_types``.
        """
        if method not in ('PUT', 'PATCH', 'POST'):
            return {}

        self.request_media_types = self.map_parsers(path, method)

        serializer = self.get_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self._get_reference(serializer)

        return {
            'content': {
                ct: {'schema': item_schema}
                for ct in self.request_media_types
            }
        }

    def get_responses(self, path, method):
        """
        Return the ``responses`` object: ``204`` for DELETE, ``201`` for POST,
        ``200`` otherwise. List views get an array schema, wrapped by the
        paginator when one is configured.

        Side effect: stores the renderer media types on ``self.response_media_types``.
        """
        if method == 'DELETE':
            return {
                '204': {
                    'description': ''
                }
            }

        self.response_media_types = self.map_renderers(path, method)

        serializer = self.get_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self._get_reference(serializer)

        if is_list_view(path, method, self.view):
            response_schema = {
                'type': 'array',
                'items': item_schema,
            }
            paginator = self.get_paginator()
            if paginator:
                response_schema = paginator.get_paginated_response_schema(response_schema)
        else:
            response_schema = item_schema
        status_code = '201' if method == 'POST' else '200'
        return {
            status_code: {
                'content': {
                    ct: {'schema': response_schema}
                    for ct in self.response_media_types
                },
                # description is a mandatory property,
                # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
                # TODO: put something meaningful into it
                'description': ""
            }
        }

    def get_tags(self, path, method):
        """Return the user-supplied tags, or a single tag derived from the first path segment."""
        # If user have specified tags, use them.
        if self._tags:
            return self._tags

        # First element of a specific path could be valid tag. This is a fallback solution.
        # PUT, PATCH, GET(Retrieve), DELETE:        /user_profile/{id}/       tags = [user-profile]
        # POST, GET(List):                          /user_profile/            tags = [user-profile]
        if path.startswith('/'):
            path = path[1:]

        return [path.split('/')[0].replace('_', '-')]

    # Deprecated aliases: each one warns and delegates to the renamed public method.

    def _get_path_parameters(self, path, method):
        """Deprecated alias of ``get_path_parameters()``."""
        warnings.warn(
            "Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.get_path_parameters(path, method)

    def _get_filter_parameters(self, path, method):
        """Deprecated alias of ``get_filter_parameters()``."""
        warnings.warn(
            "Method `_get_filter_parameters()` has been renamed to `get_filter_parameters()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.get_filter_parameters(path, method)

    def _get_responses(self, path, method):
        """Deprecated alias of ``get_responses()``."""
        warnings.warn(
            "Method `_get_responses()` has been renamed to `get_responses()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.get_responses(path, method)

    def _get_request_body(self, path, method):
        """Deprecated alias of ``get_request_body()``."""
        warnings.warn(
            "Method `_get_request_body()` has been renamed to `get_request_body()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.get_request_body(path, method)

    def _get_serializer(self, path, method):
        """Deprecated alias of ``get_serializer()``."""
        warnings.warn(
            "Method `_get_serializer()` has been renamed to `get_serializer()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.get_serializer(path, method)

    def _get_paginator(self):
        """Deprecated alias of ``get_paginator()``."""
        warnings.warn(
            "Method `_get_paginator()` has been renamed to `get_paginator()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.get_paginator()

    def _map_field_validators(self, field, schema):
        """Deprecated alias of ``map_field_validators()``."""
        warnings.warn(
            "Method `_map_field_validators()` has been renamed to `map_field_validators()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.map_field_validators(field, schema)

    def _map_serializer(self, serializer):
        """Deprecated alias of ``map_serializer()``."""
        warnings.warn(
            "Method `_map_serializer()` has been renamed to `map_serializer()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.map_serializer(serializer)

    def _map_field(self, field):
        """Deprecated alias of ``map_field()``."""
        warnings.warn(
            "Method `_map_field()` has been renamed to `map_field()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.map_field(field)

    def _map_choicefield(self, field):
        """Deprecated alias of ``map_choicefield()``."""
        warnings.warn(
            "Method `_map_choicefield()` has been renamed to `map_choicefield()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.map_choicefield(field)

    def _get_pagination_parameters(self, path, method):
        """Deprecated alias of ``get_pagination_parameters()``."""
        warnings.warn(
            "Method `_get_pagination_parameters()` has been renamed to `get_pagination_parameters()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.get_pagination_parameters(path, method)

    def _allows_filters(self, path, method):
        """Deprecated alias of ``allows_filters()``."""
        warnings.warn(
            "Method `_allows_filters()` has been renamed to `allows_filters()`. "
            "The old name will be removed in DRF v3.14.",
            RemovedInDRF314Warning, stacklevel=2
        )
        return self.allows_filters(path, method)