import warnings
from collections import OrderedDict
from decimal import Decimal
from operator import attrgetter
from urllib.parse import urljoin

from django.core.validators import (
    DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
    MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
)
from django.db import models
from django.utils.encoding import force_str

from rest_framework import exceptions, renderers, serializers
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty

from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view


class SchemaGenerator(BaseSchemaGenerator):
    """
    Build a top-level OpenAPI 3.0.2 document from the endpoints discovered
    by the base generator.
    """

    def get_info(self):
        """
        Return the OpenAPI `info` object.

        `title` and `version` are always present (empty string when unset);
        `description` is only included when it is not None.
        """
        # Title and version are required by openapi specification 3.x
        info = {
            'title': self.title or '',
            'version': self.version or ''
        }

        if self.description is not None:
            info['description'] = self.description

        return info

    def get_schema(self, request=None, public=False):
        """
        Generate an OpenAPI schema.

        Endpoints for which `has_view_permissions` returns false are skipped.
        When `public` is true the request is not passed on to endpoint
        discovery.
        """
        self._initialise_endpoints()

        # Iterate endpoints generating per method path operations.
        # TODO: …and reference components.
        paths = {}
        _, view_endpoints = self._get_paths_and_endpoints(None if public else request)
        for path, method, view in view_endpoints:
            if not self.has_view_permissions(path, method, view):
                continue
            operation = view.schema.get_operation(path, method)
            # Normalise path for any provided mount url.
            if path.startswith('/'):
                path = path[1:]
            path = urljoin(self.url or '/', path)

            paths.setdefault(path, {})
            paths[path][method.lower()] = operation

        # Compile final schema.
        schema = {
            'openapi': '3.0.2',
            'info': self.get_info(),
            'paths': paths,
        }

        return schema


# View Inspectors


class AutoSchema(ViewInspector):
    """
    Per-view inspector that derives an OpenAPI operation object from the
    view's serializer, paginator, filter backends, parsers and renderers.
    """

    def __init__(self, tags=None):
        """
        :param tags: optional list or tuple of strings used verbatim as the
            operation tags. Raises ValueError if any element is not a `str`.
        """
        if tags and not all(isinstance(tag, str) for tag in tags):
            raise ValueError('tags must be a list or tuple of string.')
        self._tags = tags
        super().__init__()

    # Reassigned (never mutated in place) by `_get_request_body` and
    # `_get_responses`, so the shared class-level lists are safe defaults.
    request_media_types = []
    response_media_types = []

    # HTTP method -> operation-id prefix, used for views without an `action`.
    method_mapping = {
        'get': 'Retrieve',
        'post': 'Create',
        'put': 'Update',
        'patch': 'PartialUpdate',
        'delete': 'Destroy',
    }

    def get_operation(self, path, method):
        """
        Return the OpenAPI operation object for `method` on `path`.

        `requestBody` is only included when `_get_request_body` returns a
        non-empty mapping.
        """
        operation = {}

        operation['operationId'] = self._get_operation_id(path, method)
        operation['description'] = self.get_description(path, method)

        parameters = []
        parameters += self._get_path_parameters(path, method)
        parameters += self._get_pagination_parameters(path, method)
        parameters += self._get_filter_parameters(path, method)
        operation['parameters'] = parameters

        request_body = self._get_request_body(path, method)
        if request_body:
            operation['requestBody'] = request_body
        operation['responses'] = self._get_responses(path, method)
        operation['tags'] = self.get_tags(path, method)

        return operation

    def _get_operation_id(self, path, method):
        """
        Compute an operation ID from the model, serializer or view name.
        """
        method_name = getattr(self.view, 'action', method.lower())
        if is_list_view(path, method, self.view):
            action = 'list'
        elif method_name not in self.method_mapping:
            action = method_name
        else:
            action = self.method_mapping[method.lower()]

        # Try to deduce the ID from the view's model
        model = getattr(getattr(self.view, 'queryset', None), 'model', None)
        if model is not None:
            name = model.__name__
        else:
            # Instantiate the serializer once only; `_get_serializer` may be
            # costly and may emit a warning.
            serializer = self._get_serializer(path, method)
            if serializer is not None:
                # Try with the serializer class name
                name = serializer.__class__.__name__
                if name.endswith('Serializer'):
                    name = name[:-10]
            else:
                # Fallback to the view name
                name = self.view.__class__.__name__
                if name.endswith('APIView'):
                    name = name[:-7]
                elif name.endswith('View'):
                    name = name[:-4]

                # Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action
                # truly comes at the end of the name
                if name.endswith(action.title()):  # ListView, UpdateAPIView, ThingDelete ...
                    name = name[:-len(action)]

        if action == 'list' and not name.endswith('s'):  # listThings instead of listThing
            name += 's'

        return action + name

    def _get_path_parameters(self, path, method):
        """
        Return a list of parameters from templated path variables.
        """
        assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'

        model = getattr(getattr(self.view, 'queryset', None), 'model', None)
        parameters = []

        for variable in uritemplate.variables(path):
            description = ''
            if model is not None:  # TODO: test this.
                # Attempt to infer a field description if possible.
                # Deliberately best-effort: any lookup failure just means
                # the parameter gets an empty description.
                try:
                    model_field = model._meta.get_field(variable)
                except Exception:
                    model_field = None

                if model_field is not None and model_field.help_text:
                    description = force_str(model_field.help_text)
                elif model_field is not None and model_field.primary_key:
                    description = get_pk_description(model, model_field)

            parameter = {
                "name": variable,
                "in": "path",
                "required": True,
                "description": description,
                'schema': {
                    'type': 'string',  # TODO: integer, pattern, ...
                },
            }
            parameters.append(parameter)

        return parameters

    def _get_filter_parameters(self, path, method):
        """
        Return the concatenated schema parameters of every filter backend on
        the view, or an empty list when filtering does not apply.
        """
        if not self._allows_filters(path, method):
            return []
        parameters = []
        for filter_backend in self.view.filter_backends:
            parameters += filter_backend().get_schema_operation_parameters(self.view)
        return parameters

    def _allows_filters(self, path, method):
        """
        Determine whether to include filter Fields in schema.

        Default implementation looks for ModelViewSet or GenericAPIView
        actions/methods that cause filtering on the default implementation.
        """
        if getattr(self.view, 'filter_backends', None) is None:
            return False
        if hasattr(self.view, 'action'):
            return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
        return method.lower() in ["get", "put", "patch", "delete"]

    def _get_pagination_parameters(self, path, method):
        """
        Return the paginator's schema parameters for list views, else `[]`.
        """
        view = self.view

        if not is_list_view(path, method, view):
            return []

        paginator = self._get_paginator()
        if not paginator:
            return []

        return paginator.get_schema_operation_parameters(view)

    def _map_choicefield(self, field):
        """
        Return a schema mapping with `enum` for a choice field, plus `type`
        when every choice shares one JSON type.
        """
        choices = list(OrderedDict.fromkeys(field.choices))  # preserve order and remove duplicates
        # `bool` is a subclass of `int`, so it must be tested first.
        if all(isinstance(choice, bool) for choice in choices):
            schema_type = 'boolean'
        elif all(isinstance(choice, int) for choice in choices):
            schema_type = 'integer'
        elif all(isinstance(choice, (int, float, Decimal)) for choice in choices):  # `number` includes `integer`
            # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
            schema_type = 'number'
        elif all(isinstance(choice, str) for choice in choices):
            schema_type = 'string'
        else:
            schema_type = None

        mapping = {
            # The value of `enum` keyword MUST be an array and SHOULD be unique.
            # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20
            'enum': choices
        }

        # If We figured out `type` then and only then we should set it. It must be a string.
        # Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type
        # It is optional but it can not be null.
        # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
        if schema_type:
            mapping['type'] = schema_type
        return mapping

    def _map_field(self, field):
        """
        Return the OpenAPI schema mapping for a single serializer field.

        Checks run from most to least specific; anything unrecognised falls
        back to `{'type': 'string'}`.
        """
        # Nested Serializers, `many` or not.
        if isinstance(field, serializers.ListSerializer):
            return {
                'type': 'array',
                'items': self._map_serializer(field.child)
            }
        if isinstance(field, serializers.Serializer):
            data = self._map_serializer(field)
            data['type'] = 'object'
            return data

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            return {
                'type': 'array',
                'items': self._map_field(field.child_relation)
            }
        if isinstance(field, serializers.PrimaryKeyRelatedField):
            model = getattr(field.queryset, 'model', None)
            if model is not None:
                model_field = model._meta.pk
                if isinstance(model_field, models.AutoField):
                    return {'type': 'integer'}
            # Otherwise fall through to the remaining checks / default.

        # ChoiceFields (single and multiple).
        # Q:
        # - Is 'type' required?
        # - can we determine the TYPE of a choicefield?
        if isinstance(field, serializers.MultipleChoiceField):
            return {
                'type': 'array',
                'items': self._map_choicefield(field)
            }

        if isinstance(field, serializers.ChoiceField):
            return self._map_choicefield(field)

        # ListField.
        if isinstance(field, serializers.ListField):
            mapping = {
                'type': 'array',
                'items': {},
            }
            if not isinstance(field.child, _UnvalidatedField):
                mapping['items'] = self._map_field(field.child)
            return mapping

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return {
                'type': 'string',
                'format': 'date',
            }

        if isinstance(field, serializers.DateTimeField):
            return {
                'type': 'string',
                'format': 'date-time',
            }

        # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
        # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
        # see also: https://swagger.io/docs/specification/data-models/data-types/#string
        if isinstance(field, serializers.EmailField):
            return {
                'type': 'string',
                'format': 'email'
            }

        if isinstance(field, serializers.URLField):
            return {
                'type': 'string',
                'format': 'uri'
            }

        if isinstance(field, serializers.UUIDField):
            return {
                'type': 'string',
                'format': 'uuid'
            }

        if isinstance(field, serializers.IPAddressField):
            content = {
                'type': 'string',
            }
            if field.protocol != 'both':
                content['format'] = field.protocol
            return content

        # DecimalField has multipleOf based on decimal_places
        if isinstance(field, serializers.DecimalField):
            content = {
                'type': 'number'
            }
            if field.decimal_places:
                content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
            if field.max_whole_digits:
                # 10 ** n is the value one past n nines (e.g. 999 + 1).
                content['maximum'] = 10 ** field.max_whole_digits
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.FloatField):
            content = {
                'type': 'number'
            }
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.IntegerField):
            content = {
                'type': 'integer'
            }
            self._map_min_max(field, content)
            # Any bound outside the signed 32-bit range needs the int64 format;
            # this covers large negative minimums as well as large maximums.
            int32_min, int32_max = -2147483648, 2147483647
            if any(
                not int32_min <= int(content.get(bound, 0)) <= int32_max
                for bound in ('maximum', 'minimum')
            ):
                content['format'] = 'int64'
            return content

        if isinstance(field, serializers.FileField):
            return {
                'type': 'string',
                'format': 'binary'
            }

        # Simplest cases, default to 'string' type:
        FIELD_CLASS_SCHEMA_TYPE = {
            serializers.BooleanField: 'boolean',
            serializers.JSONField: 'object',
            serializers.DictField: 'object',
            serializers.HStoreField: 'object',
        }
        return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}

    def _map_min_max(self, field, content):
        """
        Copy the field's `max_value` / `min_value` into `content` in place.

        Bounds are compared against None rather than tested for truthiness
        so that an explicit bound of 0 is not dropped.
        """
        if field.max_value is not None:
            content['maximum'] = field.max_value
        if field.min_value is not None:
            content['minimum'] = field.min_value

    def _map_serializer(self, serializer):
        """
        Return an `object` schema describing every non-hidden field of the
        serializer; `required` is only included when non-empty.
        """
        # Assuming we have a valid serializer instance.
        # TODO:
        #   - field is Nested or List serializer.
        #   - Handle read_only/write_only for request/response differences.
        #       - could do this with readOnly/writeOnly and then filter dict.
        required = []
        properties = {}

        for field in serializer.fields.values():
            if isinstance(field, serializers.HiddenField):
                continue

            if field.required:
                required.append(field.field_name)

            schema = self._map_field(field)
            if field.read_only:
                schema['readOnly'] = True
            if field.write_only:
                schema['writeOnly'] = True
            if field.allow_null:
                schema['nullable'] = True
            # `empty` is a sentinel, so compare by identity; this also avoids
            # invoking a custom `__ne__` on arbitrary default values.
            if field.default is not None and field.default is not empty and not callable(field.default):
                schema['default'] = field.default
            if field.help_text:
                schema['description'] = str(field.help_text)
            self._map_field_validators(field, schema)

            properties[field.field_name] = schema

        result = {
            'type': 'object',
            'properties': properties
        }
        if required:
            result['required'] = required

        return result

    def _map_field_validators(self, field, schema):
        """
        map field validators

        Mutates `schema` in place, adding format / pattern / length / range
        keywords for each recognised validator on `field`.
        """
        for v in field.validators:
            # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
            # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
            if isinstance(v, EmailValidator):
                schema['format'] = 'email'
            if isinstance(v, URLValidator):
                schema['format'] = 'uri'
            if isinstance(v, RegexValidator):
                schema['pattern'] = v.regex.pattern
            elif isinstance(v, MaxLengthValidator):
                attr_name = 'maxLength'
                if isinstance(field, serializers.ListField):
                    attr_name = 'maxItems'
                schema[attr_name] = v.limit_value
            elif isinstance(v, MinLengthValidator):
                attr_name = 'minLength'
                if isinstance(field, serializers.ListField):
                    attr_name = 'minItems'
                schema[attr_name] = v.limit_value
            elif isinstance(v, MaxValueValidator):
                schema['maximum'] = v.limit_value
            elif isinstance(v, MinValueValidator):
                schema['minimum'] = v.limit_value
            elif isinstance(v, DecimalValidator):
                if v.decimal_places:
                    schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
                if v.max_digits:
                    digits = v.max_digits
                    if v.decimal_places is not None and v.decimal_places > 0:
                        digits -= v.decimal_places
                    # 10 ** digits equals int('9' * digits) + 1 for digits >= 1
                    # and, unlike the string form, does not raise when no
                    # whole digits are left (max_digits == decimal_places).
                    schema['maximum'] = 10 ** max(digits, 0)
                    schema['minimum'] = -schema['maximum']

    def _get_paginator(self):
        """
        Return an instance of the view's `pagination_class`, or None.
        """
        pagination_class = getattr(self.view, 'pagination_class', None)
        if pagination_class:
            return pagination_class()
        return None

    def map_parsers(self, path, method):
        """
        Return the media types of the view's parser classes.
        """
        return list(map(attrgetter('media_type'), self.view.parser_classes))

    def map_renderers(self, path, method):
        """
        Return the media types of the view's renderer classes, skipping the
        browsable API renderer.
        """
        media_types = []
        for renderer in self.view.renderer_classes:
            # BrowsableAPIRenderer not relevant to OpenAPI spec
            if renderer == renderers.BrowsableAPIRenderer:
                continue
            media_types.append(renderer.media_type)
        return media_types

    def _get_serializer(self, path, method):
        """
        Return `view.get_serializer()`, or None when the view has no such
        method or it raises an `APIException` (a warning is emitted then).
        """
        view = self.view

        if not hasattr(view, 'get_serializer'):
            return None

        try:
            return view.get_serializer()
        except exceptions.APIException:
            warnings.warn('{}.get_serializer() raised an exception during '
                          'schema generation. Serializer fields will not be '
                          'generated for {} {}.'
                          .format(view.__class__.__name__, method, path))
            return None

    def _get_request_body(self, path, method):
        """
        Return the `requestBody` object for PUT / PATCH / POST, else `{}`.

        Read-only fields are removed, and PATCH bodies carry no `required`
        list. Also stores the parser media types on
        `self.request_media_types`.
        """
        if method not in ('PUT', 'PATCH', 'POST'):
            return {}

        self.request_media_types = self.map_parsers(path, method)

        serializer = self._get_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            return {}

        content = self._map_serializer(serializer)
        # No required fields for PATCH
        if method == 'PATCH':
            content.pop('required', None)
        # No read_only fields for request.
        for name, schema in content['properties'].copy().items():
            if 'readOnly' in schema:
                del content['properties'][name]

        return {
            'content': {
                ct: {'schema': content}
                for ct in self.request_media_types
            }
        }

    def _get_responses(self, path, method):
        """
        Return the `responses` object: a bare 204 for DELETE, otherwise a
        200 whose schema is the serializer (wrapped in an array, and by the
        paginator if any, for list views) with write-only fields removed.
        Also stores the renderer media types on `self.response_media_types`.
        """
        # TODO: Handle multiple codes and pagination classes.
        if method == 'DELETE':
            return {
                '204': {
                    'description': ''
                }
            }

        self.response_media_types = self.map_renderers(path, method)

        item_schema = {}
        serializer = self._get_serializer(path, method)

        if isinstance(serializer, serializers.Serializer):
            item_schema = self._map_serializer(serializer)
            # No write_only fields for response.
            for name, schema in item_schema['properties'].copy().items():
                if 'writeOnly' in schema:
                    del item_schema['properties'][name]
                    if 'required' in item_schema:
                        item_schema['required'] = [f for f in item_schema['required'] if f != name]
            # The JSON Schema draft used by OpenAPI 3.0 requires `required`
            # to hold at least one element, so drop the key once emptied.
            if 'required' in item_schema and not item_schema['required']:
                del item_schema['required']

        if is_list_view(path, method, self.view):
            response_schema = {
                'type': 'array',
                'items': item_schema,
            }
            paginator = self._get_paginator()
            if paginator:
                response_schema = paginator.get_paginated_response_schema(response_schema)
        else:
            response_schema = item_schema

        return {
            '200': {
                'content': {
                    ct: {'schema': response_schema}
                    for ct in self.response_media_types
                },
                # description is a mandatory property,
                # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
                # TODO: put something meaningful into it
                'description': ""
            }
        }

    def get_tags(self, path, method):
        """
        Return the operation tags: the tags given to `__init__` when set,
        otherwise a one-element list derived from the first path segment.
        """
        # If user have specified tags, use them.
        if self._tags:
            return self._tags

        # First element of a specific path could be valid tag. This is a fallback solution.
        # PUT, PATCH, GET(Retrieve), DELETE:        /user_profile/{id}/       tags = [user-profile]
        # POST, GET(List):                          /user_profile/            tags = [user-profile]
        if path.startswith('/'):
            path = path[1:]

        return [path.split('/')[0].replace('_', '-')]