mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 01:47:59 +03:00 
			
		
		
		
	Chore: drop coreapi.py, openapi.py from codebase as REST Framework does not support Schema generation
				
					
				
			This commit is contained in:
		
							parent
							
								
									c4de49e861
								
							
						
					
					
						commit
						14956c864a
					
				| 
						 | 
					@ -1,181 +0,0 @@
 | 
				
			||||||
from collections import Counter
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.db import models
 | 
					 | 
				
			||||||
from django.utils.encoding import force_str
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from rest_framework import serializers
 | 
					 | 
				
			||||||
from rest_framework.compat import coreschema
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def common_path(paths):
 | 
					 | 
				
			||||||
    split_paths = [path.strip('/').split('/') for path in paths]
 | 
					 | 
				
			||||||
    s1 = min(split_paths)
 | 
					 | 
				
			||||||
    s2 = max(split_paths)
 | 
					 | 
				
			||||||
    common = s1
 | 
					 | 
				
			||||||
    for i, c in enumerate(s1):
 | 
					 | 
				
			||||||
        if c != s2[i]:
 | 
					 | 
				
			||||||
            common = s1[:i]
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
    return '/' + '/'.join(common)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def is_custom_action(action):
 | 
					 | 
				
			||||||
    return action not in {
 | 
					 | 
				
			||||||
        'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def distribute_links(obj):
 | 
					 | 
				
			||||||
    for key, value in obj.items():
 | 
					 | 
				
			||||||
        distribute_links(value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for preferred_key, link in obj.links:
 | 
					 | 
				
			||||||
        key = obj.get_available_key(preferred_key)
 | 
					 | 
				
			||||||
        obj[key] = link
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
INSERT_INTO_COLLISION_FMT = """
 | 
					 | 
				
			||||||
Schema Naming Collision.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
coreapi.Link for URL path {value_url} cannot be inserted into schema.
 | 
					 | 
				
			||||||
Position conflicts with coreapi.Link for URL path {target_url}.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Attempted to insert link with keys: {keys}.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
 | 
					 | 
				
			||||||
to customise schema structure.
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class LinkNode(dict):
 | 
					 | 
				
			||||||
    def __init__(self):
 | 
					 | 
				
			||||||
        self.links = []
 | 
					 | 
				
			||||||
        self.methods_counter = Counter()
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_available_key(self, preferred_key):
 | 
					 | 
				
			||||||
        if preferred_key not in self:
 | 
					 | 
				
			||||||
            return preferred_key
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            current_val = self.methods_counter[preferred_key]
 | 
					 | 
				
			||||||
            self.methods_counter[preferred_key] += 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            key = '{}_{}'.format(preferred_key, current_val)
 | 
					 | 
				
			||||||
            if key not in self:
 | 
					 | 
				
			||||||
                return key
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def insert_into(target, keys, value):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Nested dictionary insertion.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    >>> example = {}
 | 
					 | 
				
			||||||
    >>> insert_into(example, ['a', 'b', 'c'], 123)
 | 
					 | 
				
			||||||
    >>> example
 | 
					 | 
				
			||||||
    LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    for key in keys[:-1]:
 | 
					 | 
				
			||||||
        if key not in target:
 | 
					 | 
				
			||||||
            target[key] = LinkNode()
 | 
					 | 
				
			||||||
        target = target[key]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        target.links.append((keys[-1], value))
 | 
					 | 
				
			||||||
    except TypeError:
 | 
					 | 
				
			||||||
        msg = INSERT_INTO_COLLISION_FMT.format(
 | 
					 | 
				
			||||||
            value_url=value.url,
 | 
					 | 
				
			||||||
            target_url=target.url,
 | 
					 | 
				
			||||||
            keys=keys
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        raise ValueError(msg)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# View Inspectors #
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def field_to_schema(field):
 | 
					 | 
				
			||||||
    title = force_str(field.label) if field.label else ''
 | 
					 | 
				
			||||||
    description = force_str(field.help_text) if field.help_text else ''
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
 | 
					 | 
				
			||||||
        child_schema = field_to_schema(field.child)
 | 
					 | 
				
			||||||
        return coreschema.Array(
 | 
					 | 
				
			||||||
            items=child_schema,
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.DictField):
 | 
					 | 
				
			||||||
        return coreschema.Object(
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.Serializer):
 | 
					 | 
				
			||||||
        return coreschema.Object(
 | 
					 | 
				
			||||||
            properties={
 | 
					 | 
				
			||||||
                key: field_to_schema(value)
 | 
					 | 
				
			||||||
                for key, value
 | 
					 | 
				
			||||||
                in field.fields.items()
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.ManyRelatedField):
 | 
					 | 
				
			||||||
        related_field_schema = field_to_schema(field.child_relation)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return coreschema.Array(
 | 
					 | 
				
			||||||
            items=related_field_schema,
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.PrimaryKeyRelatedField):
 | 
					 | 
				
			||||||
        schema_cls = coreschema.String
 | 
					 | 
				
			||||||
        model = getattr(field.queryset, 'model', None)
 | 
					 | 
				
			||||||
        if model is not None:
 | 
					 | 
				
			||||||
            model_field = model._meta.pk
 | 
					 | 
				
			||||||
            if isinstance(model_field, models.AutoField):
 | 
					 | 
				
			||||||
                schema_cls = coreschema.Integer
 | 
					 | 
				
			||||||
        return schema_cls(title=title, description=description)
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.RelatedField):
 | 
					 | 
				
			||||||
        return coreschema.String(title=title, description=description)
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.MultipleChoiceField):
 | 
					 | 
				
			||||||
        return coreschema.Array(
 | 
					 | 
				
			||||||
            items=coreschema.Enum(enum=list(field.choices)),
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.ChoiceField):
 | 
					 | 
				
			||||||
        return coreschema.Enum(
 | 
					 | 
				
			||||||
            enum=list(field.choices),
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.BooleanField):
 | 
					 | 
				
			||||||
        return coreschema.Boolean(title=title, description=description)
 | 
					 | 
				
			||||||
    elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
 | 
					 | 
				
			||||||
        return coreschema.Number(title=title, description=description)
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.IntegerField):
 | 
					 | 
				
			||||||
        return coreschema.Integer(title=title, description=description)
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.DateField):
 | 
					 | 
				
			||||||
        return coreschema.String(
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description,
 | 
					 | 
				
			||||||
            format='date'
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.DateTimeField):
 | 
					 | 
				
			||||||
        return coreschema.String(
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description,
 | 
					 | 
				
			||||||
            format='date-time'
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    elif isinstance(field, serializers.JSONField):
 | 
					 | 
				
			||||||
        return coreschema.Object(title=title, description=description)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if field.style.get('base_template') == 'textarea.html':
 | 
					 | 
				
			||||||
        return coreschema.String(
 | 
					 | 
				
			||||||
            title=title,
 | 
					 | 
				
			||||||
            description=description,
 | 
					 | 
				
			||||||
            format='textarea'
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return coreschema.String(title=title, description=description)
 | 
					 | 
				
			||||||
| 
						 | 
					@ -1,721 +0,0 @@
 | 
				
			||||||
import re
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
from decimal import Decimal
 | 
					 | 
				
			||||||
from operator import attrgetter
 | 
					 | 
				
			||||||
from urllib.parse import urljoin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.core.validators import (
 | 
					 | 
				
			||||||
    DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
 | 
					 | 
				
			||||||
    MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from django.db import models
 | 
					 | 
				
			||||||
from django.utils.encoding import force_str
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from rest_framework import exceptions, renderers, serializers
 | 
					 | 
				
			||||||
from rest_framework.compat import inflection, uritemplate
 | 
					 | 
				
			||||||
from rest_framework.fields import _UnvalidatedField, empty
 | 
					 | 
				
			||||||
from rest_framework.settings import api_settings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .generators import BaseSchemaGenerator
 | 
					 | 
				
			||||||
from .inspectors import ViewInspector
 | 
					 | 
				
			||||||
from .utils import get_pk_description, is_list_view
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SchemaGenerator(BaseSchemaGenerator):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_info(self):
 | 
					 | 
				
			||||||
        # Title and version are required by openapi specification 3.x
 | 
					 | 
				
			||||||
        info = {
 | 
					 | 
				
			||||||
            'title': self.title or '',
 | 
					 | 
				
			||||||
            'version': self.version or ''
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.description is not None:
 | 
					 | 
				
			||||||
            info['description'] = self.description
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return info
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def check_duplicate_operation_id(self, paths):
 | 
					 | 
				
			||||||
        ids = {}
 | 
					 | 
				
			||||||
        for route in paths:
 | 
					 | 
				
			||||||
            for method in paths[route]:
 | 
					 | 
				
			||||||
                if 'operationId' not in paths[route][method]:
 | 
					 | 
				
			||||||
                    continue
 | 
					 | 
				
			||||||
                operation_id = paths[route][method]['operationId']
 | 
					 | 
				
			||||||
                if operation_id in ids:
 | 
					 | 
				
			||||||
                    warnings.warn(
 | 
					 | 
				
			||||||
                        'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n'
 | 
					 | 
				
			||||||
                        '\tRoute: {route1}, Method: {method1}\n'
 | 
					 | 
				
			||||||
                        '\tRoute: {route2}, Method: {method2}\n'
 | 
					 | 
				
			||||||
                        '\tAn operationId has to be unique across your schema. Your schema may not work in other tools.'
 | 
					 | 
				
			||||||
                        .format(
 | 
					 | 
				
			||||||
                            route1=ids[operation_id]['route'],
 | 
					 | 
				
			||||||
                            method1=ids[operation_id]['method'],
 | 
					 | 
				
			||||||
                            route2=route,
 | 
					 | 
				
			||||||
                            method2=method,
 | 
					 | 
				
			||||||
                            operation_id=operation_id
 | 
					 | 
				
			||||||
                        )
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                ids[operation_id] = {
 | 
					 | 
				
			||||||
                    'route': route,
 | 
					 | 
				
			||||||
                    'method': method
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_schema(self, request=None, public=False):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Generate a OpenAPI schema.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self._initialise_endpoints()
 | 
					 | 
				
			||||||
        components_schemas = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Iterate endpoints generating per method path operations.
 | 
					 | 
				
			||||||
        paths = {}
 | 
					 | 
				
			||||||
        _, view_endpoints = self._get_paths_and_endpoints(None if public else request)
 | 
					 | 
				
			||||||
        for path, method, view in view_endpoints:
 | 
					 | 
				
			||||||
            if not self.has_view_permissions(path, method, view):
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            operation = view.schema.get_operation(path, method)
 | 
					 | 
				
			||||||
            components = view.schema.get_components(path, method)
 | 
					 | 
				
			||||||
            for k in components.keys():
 | 
					 | 
				
			||||||
                if k not in components_schemas:
 | 
					 | 
				
			||||||
                    continue
 | 
					 | 
				
			||||||
                if components_schemas[k] == components[k]:
 | 
					 | 
				
			||||||
                    continue
 | 
					 | 
				
			||||||
                warnings.warn('Schema component "{}" has been overridden with a different value.'.format(k))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            components_schemas.update(components)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Normalise path for any provided mount url.
 | 
					 | 
				
			||||||
            if path.startswith('/'):
 | 
					 | 
				
			||||||
                path = path[1:]
 | 
					 | 
				
			||||||
            path = urljoin(self.url or '/', path)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            paths.setdefault(path, {})
 | 
					 | 
				
			||||||
            paths[path][method.lower()] = operation
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.check_duplicate_operation_id(paths)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Compile final schema.
 | 
					 | 
				
			||||||
        schema = {
 | 
					 | 
				
			||||||
            'openapi': '3.0.2',
 | 
					 | 
				
			||||||
            'info': self.get_info(),
 | 
					 | 
				
			||||||
            'paths': paths,
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if len(components_schemas) > 0:
 | 
					 | 
				
			||||||
            schema['components'] = {
 | 
					 | 
				
			||||||
                'schemas': components_schemas
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return schema
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# View Inspectors
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AutoSchema(ViewInspector):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, tags=None, operation_id_base=None, component_name=None):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        :param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name.
 | 
					 | 
				
			||||||
        :param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if tags and not all(isinstance(tag, str) for tag in tags):
 | 
					 | 
				
			||||||
            raise ValueError('tags must be a list or tuple of string.')
 | 
					 | 
				
			||||||
        self._tags = tags
 | 
					 | 
				
			||||||
        self.operation_id_base = operation_id_base
 | 
					 | 
				
			||||||
        self.component_name = component_name
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    request_media_types = []
 | 
					 | 
				
			||||||
    response_media_types = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    method_mapping = {
 | 
					 | 
				
			||||||
        'get': 'retrieve',
 | 
					 | 
				
			||||||
        'post': 'create',
 | 
					 | 
				
			||||||
        'put': 'update',
 | 
					 | 
				
			||||||
        'patch': 'partialUpdate',
 | 
					 | 
				
			||||||
        'delete': 'destroy',
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_operation(self, path, method):
 | 
					 | 
				
			||||||
        operation = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        operation['operationId'] = self.get_operation_id(path, method)
 | 
					 | 
				
			||||||
        operation['description'] = self.get_description(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        parameters = []
 | 
					 | 
				
			||||||
        parameters += self.get_path_parameters(path, method)
 | 
					 | 
				
			||||||
        parameters += self.get_pagination_parameters(path, method)
 | 
					 | 
				
			||||||
        parameters += self.get_filter_parameters(path, method)
 | 
					 | 
				
			||||||
        operation['parameters'] = parameters
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request_body = self.get_request_body(path, method)
 | 
					 | 
				
			||||||
        if request_body:
 | 
					 | 
				
			||||||
            operation['requestBody'] = request_body
 | 
					 | 
				
			||||||
        operation['responses'] = self.get_responses(path, method)
 | 
					 | 
				
			||||||
        operation['tags'] = self.get_tags(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return operation
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_component_name(self, serializer):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Compute the component's name from the serializer.
 | 
					 | 
				
			||||||
        Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if self.component_name is not None:
 | 
					 | 
				
			||||||
            return self.component_name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # use the serializer's class name as the component name.
 | 
					 | 
				
			||||||
        component_name = serializer.__class__.__name__
 | 
					 | 
				
			||||||
        # We remove the "serializer" string from the class name.
 | 
					 | 
				
			||||||
        pattern = re.compile("serializer", re.IGNORECASE)
 | 
					 | 
				
			||||||
        component_name = pattern.sub("", component_name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if component_name == "":
 | 
					 | 
				
			||||||
            raise Exception(
 | 
					 | 
				
			||||||
                '"{}" is an invalid class name for schema generation. '
 | 
					 | 
				
			||||||
                'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
 | 
					 | 
				
			||||||
                .format(serializer.__class__.__name__)
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return component_name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_components(self, path, method):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Return components with their properties from the serializer.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if method.lower() == 'delete':
 | 
					 | 
				
			||||||
            return {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request_serializer = self.get_request_serializer(path, method)
 | 
					 | 
				
			||||||
        response_serializer = self.get_response_serializer(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        components = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(request_serializer, serializers.Serializer):
 | 
					 | 
				
			||||||
            component_name = self.get_component_name(request_serializer)
 | 
					 | 
				
			||||||
            content = self.map_serializer(request_serializer)
 | 
					 | 
				
			||||||
            components.setdefault(component_name, content)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(response_serializer, serializers.Serializer):
 | 
					 | 
				
			||||||
            component_name = self.get_component_name(response_serializer)
 | 
					 | 
				
			||||||
            content = self.map_serializer(response_serializer)
 | 
					 | 
				
			||||||
            components.setdefault(component_name, content)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return components
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _to_camel_case(self, snake_str):
 | 
					 | 
				
			||||||
        components = snake_str.split('_')
 | 
					 | 
				
			||||||
        # We capitalize the first letter of each component except the first one
 | 
					 | 
				
			||||||
        # with the 'title' method and join them together.
 | 
					 | 
				
			||||||
        return components[0] + ''.join(x.title() for x in components[1:])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_operation_id_base(self, path, method, action):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Compute the base part for operation ID from the model, serializer or view name.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        model = getattr(getattr(self.view, 'queryset', None), 'model', None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self.operation_id_base is not None:
 | 
					 | 
				
			||||||
            name = self.operation_id_base
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Try to deduce the ID from the view's model
 | 
					 | 
				
			||||||
        elif model is not None:
 | 
					 | 
				
			||||||
            name = model.__name__
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Try with the serializer class name
 | 
					 | 
				
			||||||
        elif self.get_serializer(path, method) is not None:
 | 
					 | 
				
			||||||
            name = self.get_serializer(path, method).__class__.__name__
 | 
					 | 
				
			||||||
            if name.endswith('Serializer'):
 | 
					 | 
				
			||||||
                name = name[:-10]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Fallback to the view name
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            name = self.view.__class__.__name__
 | 
					 | 
				
			||||||
            if name.endswith('APIView'):
 | 
					 | 
				
			||||||
                name = name[:-7]
 | 
					 | 
				
			||||||
            elif name.endswith('View'):
 | 
					 | 
				
			||||||
                name = name[:-4]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
 | 
					 | 
				
			||||||
            # comes at the end of the name
 | 
					 | 
				
			||||||
            if name.endswith(action.title()):  # ListView, UpdateAPIView, ThingDelete ...
 | 
					 | 
				
			||||||
                name = name[:-len(action)]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if action == 'list':
 | 
					 | 
				
			||||||
            assert inflection, '`inflection` must be installed for OpenAPI schema support.'
 | 
					 | 
				
			||||||
            name = inflection.pluralize(name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_operation_id(self, path, method):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Compute an operation ID from the view type and get_operation_id_base method.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        method_name = getattr(self.view, 'action', method.lower())
 | 
					 | 
				
			||||||
        if is_list_view(path, method, self.view):
 | 
					 | 
				
			||||||
            action = 'list'
 | 
					 | 
				
			||||||
        elif method_name not in self.method_mapping:
 | 
					 | 
				
			||||||
            action = self._to_camel_case(method_name)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            action = self.method_mapping[method.lower()]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        name = self.get_operation_id_base(path, method, action)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return action + name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_path_parameters(self, path, method):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Return a list of parameters from templated path variables.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        model = getattr(getattr(self.view, 'queryset', None), 'model', None)
 | 
					 | 
				
			||||||
        parameters = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for variable in uritemplate.variables(path):
 | 
					 | 
				
			||||||
            description = ''
 | 
					 | 
				
			||||||
            if model is not None:  # TODO: test this.
 | 
					 | 
				
			||||||
                # Attempt to infer a field description if possible.
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    model_field = model._meta.get_field(variable)
 | 
					 | 
				
			||||||
                except Exception:
 | 
					 | 
				
			||||||
                    model_field = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if model_field is not None and model_field.help_text:
 | 
					 | 
				
			||||||
                    description = force_str(model_field.help_text)
 | 
					 | 
				
			||||||
                elif model_field is not None and model_field.primary_key:
 | 
					 | 
				
			||||||
                    description = get_pk_description(model, model_field)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            parameter = {
 | 
					 | 
				
			||||||
                "name": variable,
 | 
					 | 
				
			||||||
                "in": "path",
 | 
					 | 
				
			||||||
                "required": True,
 | 
					 | 
				
			||||||
                "description": description,
 | 
					 | 
				
			||||||
                'schema': {
 | 
					 | 
				
			||||||
                    'type': 'string',  # TODO: integer, pattern, ...
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            parameters.append(parameter)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return parameters
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_filter_parameters(self, path, method):
 | 
					 | 
				
			||||||
        if not self.allows_filters(path, method):
 | 
					 | 
				
			||||||
            return []
 | 
					 | 
				
			||||||
        parameters = []
 | 
					 | 
				
			||||||
        for filter_backend in self.view.filter_backends:
 | 
					 | 
				
			||||||
            parameters += filter_backend().get_schema_operation_parameters(self.view)
 | 
					 | 
				
			||||||
        return parameters
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def allows_filters(self, path, method):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Determine whether to include filter Fields in schema.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Default implementation looks for ModelViewSet or GenericAPIView
 | 
					 | 
				
			||||||
        actions/methods that cause filtering on the default implementation.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if getattr(self.view, 'filter_backends', None) is None:
 | 
					 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
        if hasattr(self.view, 'action'):
 | 
					 | 
				
			||||||
            return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
 | 
					 | 
				
			||||||
        return method.lower() in ["get", "put", "patch", "delete"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_pagination_parameters(self, path, method):
 | 
					 | 
				
			||||||
        view = self.view
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not is_list_view(path, method, view):
 | 
					 | 
				
			||||||
            return []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        paginator = self.get_paginator()
 | 
					 | 
				
			||||||
        if not paginator:
 | 
					 | 
				
			||||||
            return []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return paginator.get_schema_operation_parameters(view)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def map_choicefield(self, field):
 | 
					 | 
				
			||||||
        choices = list(dict.fromkeys(field.choices))  # preserve order and remove duplicates
 | 
					 | 
				
			||||||
        if all(isinstance(choice, bool) for choice in choices):
 | 
					 | 
				
			||||||
            type = 'boolean'
 | 
					 | 
				
			||||||
        elif all(isinstance(choice, int) for choice in choices):
 | 
					 | 
				
			||||||
            type = 'integer'
 | 
					 | 
				
			||||||
        elif all(isinstance(choice, (int, float, Decimal)) for choice in choices):  # `number` includes `integer`
 | 
					 | 
				
			||||||
            # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
 | 
					 | 
				
			||||||
            type = 'number'
 | 
					 | 
				
			||||||
        elif all(isinstance(choice, str) for choice in choices):
 | 
					 | 
				
			||||||
            type = 'string'
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            type = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        mapping = {
 | 
					 | 
				
			||||||
            # The value of `enum` keyword MUST be an array and SHOULD be unique.
 | 
					 | 
				
			||||||
            # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20
 | 
					 | 
				
			||||||
            'enum': choices
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # If We figured out `type` then and only then we should set it. It must be a string.
 | 
					 | 
				
			||||||
        # Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type
 | 
					 | 
				
			||||||
        # It is optional but it can not be null.
 | 
					 | 
				
			||||||
        # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
 | 
					 | 
				
			||||||
        if type:
 | 
					 | 
				
			||||||
            mapping['type'] = type
 | 
					 | 
				
			||||||
        return mapping
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def map_field(self, field):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Nested Serializers, `many` or not.
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.ListSerializer):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'array',
 | 
					 | 
				
			||||||
                'items': self.map_serializer(field.child)
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.Serializer):
 | 
					 | 
				
			||||||
            data = self.map_serializer(field)
 | 
					 | 
				
			||||||
            data['type'] = 'object'
 | 
					 | 
				
			||||||
            return data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Related fields.
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.ManyRelatedField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'array',
 | 
					 | 
				
			||||||
                'items': self.map_field(field.child_relation)
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.PrimaryKeyRelatedField):
 | 
					 | 
				
			||||||
            if getattr(field, "pk_field", False):
 | 
					 | 
				
			||||||
                return self.map_field(field=field.pk_field)
 | 
					 | 
				
			||||||
            model = getattr(field.queryset, 'model', None)
 | 
					 | 
				
			||||||
            if model is not None:
 | 
					 | 
				
			||||||
                model_field = model._meta.pk
 | 
					 | 
				
			||||||
                if isinstance(model_field, models.AutoField):
 | 
					 | 
				
			||||||
                    return {'type': 'integer'}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # ChoiceFields (single and multiple).
 | 
					 | 
				
			||||||
        # Q:
 | 
					 | 
				
			||||||
        # - Is 'type' required?
 | 
					 | 
				
			||||||
        # - can we determine the TYPE of a choicefield?
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.MultipleChoiceField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'array',
 | 
					 | 
				
			||||||
                'items': self.map_choicefield(field)
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.ChoiceField):
 | 
					 | 
				
			||||||
            return self.map_choicefield(field)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # ListField.
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.ListField):
 | 
					 | 
				
			||||||
            mapping = {
 | 
					 | 
				
			||||||
                'type': 'array',
 | 
					 | 
				
			||||||
                'items': {},
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            if not isinstance(field.child, _UnvalidatedField):
 | 
					 | 
				
			||||||
                mapping['items'] = self.map_field(field.child)
 | 
					 | 
				
			||||||
            return mapping
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # DateField and DateTimeField type is string
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.DateField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'string',
 | 
					 | 
				
			||||||
                'format': 'date',
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.DateTimeField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'string',
 | 
					 | 
				
			||||||
                'format': 'date-time',
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
 | 
					 | 
				
			||||||
        # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
 | 
					 | 
				
			||||||
        # see also: https://swagger.io/docs/specification/data-models/data-types/#string
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.EmailField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'string',
 | 
					 | 
				
			||||||
                'format': 'email'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.URLField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'string',
 | 
					 | 
				
			||||||
                'format': 'uri'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.UUIDField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'string',
 | 
					 | 
				
			||||||
                'format': 'uuid'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.IPAddressField):
 | 
					 | 
				
			||||||
            content = {
 | 
					 | 
				
			||||||
                'type': 'string',
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            if field.protocol != 'both':
 | 
					 | 
				
			||||||
                content['format'] = field.protocol
 | 
					 | 
				
			||||||
            return content
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.DecimalField):
 | 
					 | 
				
			||||||
            if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
 | 
					 | 
				
			||||||
                content = {
 | 
					 | 
				
			||||||
                    'type': 'string',
 | 
					 | 
				
			||||||
                    'format': 'decimal',
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                content = {
 | 
					 | 
				
			||||||
                    'type': 'number'
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if field.decimal_places:
 | 
					 | 
				
			||||||
                content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
 | 
					 | 
				
			||||||
            if field.max_whole_digits:
 | 
					 | 
				
			||||||
                content['maximum'] = int(field.max_whole_digits * '9') + 1
 | 
					 | 
				
			||||||
                content['minimum'] = -content['maximum']
 | 
					 | 
				
			||||||
            self._map_min_max(field, content)
 | 
					 | 
				
			||||||
            return content
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.FloatField):
 | 
					 | 
				
			||||||
            content = {
 | 
					 | 
				
			||||||
                'type': 'number',
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            self._map_min_max(field, content)
 | 
					 | 
				
			||||||
            return content
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.IntegerField):
 | 
					 | 
				
			||||||
            content = {
 | 
					 | 
				
			||||||
                'type': 'integer'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            self._map_min_max(field, content)
 | 
					 | 
				
			||||||
            # 2147483647 is max for int32_size, so we use int64 for format
 | 
					 | 
				
			||||||
            if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
 | 
					 | 
				
			||||||
                content['format'] = 'int64'
 | 
					 | 
				
			||||||
            return content
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.FileField):
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                'type': 'string',
 | 
					 | 
				
			||||||
                'format': 'binary'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Simplest cases, default to 'string' type:
 | 
					 | 
				
			||||||
        FIELD_CLASS_SCHEMA_TYPE = {
 | 
					 | 
				
			||||||
            serializers.BooleanField: 'boolean',
 | 
					 | 
				
			||||||
            serializers.JSONField: 'object',
 | 
					 | 
				
			||||||
            serializers.DictField: 'object',
 | 
					 | 
				
			||||||
            serializers.HStoreField: 'object',
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _map_min_max(self, field, content):
 | 
					 | 
				
			||||||
        if field.max_value:
 | 
					 | 
				
			||||||
            content['maximum'] = field.max_value
 | 
					 | 
				
			||||||
        if field.min_value:
 | 
					 | 
				
			||||||
            content['minimum'] = field.min_value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def map_serializer(self, serializer):
 | 
					 | 
				
			||||||
        # Assuming we have a valid serializer instance.
 | 
					 | 
				
			||||||
        required = []
 | 
					 | 
				
			||||||
        properties = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for field in serializer.fields.values():
 | 
					 | 
				
			||||||
            if isinstance(field, serializers.HiddenField):
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if field.required and not serializer.partial:
 | 
					 | 
				
			||||||
                required.append(self.get_field_name(field))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            schema = self.map_field(field)
 | 
					 | 
				
			||||||
            if field.read_only:
 | 
					 | 
				
			||||||
                schema['readOnly'] = True
 | 
					 | 
				
			||||||
            if field.write_only:
 | 
					 | 
				
			||||||
                schema['writeOnly'] = True
 | 
					 | 
				
			||||||
            if field.allow_null:
 | 
					 | 
				
			||||||
                schema['nullable'] = True
 | 
					 | 
				
			||||||
            if field.default is not None and field.default != empty and not callable(field.default):
 | 
					 | 
				
			||||||
                schema['default'] = field.default
 | 
					 | 
				
			||||||
            if field.help_text:
 | 
					 | 
				
			||||||
                schema['description'] = str(field.help_text)
 | 
					 | 
				
			||||||
            self.map_field_validators(field, schema)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            properties[self.get_field_name(field)] = schema
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        result = {
 | 
					 | 
				
			||||||
            'type': 'object',
 | 
					 | 
				
			||||||
            'properties': properties
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        if required:
 | 
					 | 
				
			||||||
            result['required'] = required
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def map_field_validators(self, field, schema):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        map field validators
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        for v in field.validators:
 | 
					 | 
				
			||||||
            # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
 | 
					 | 
				
			||||||
            # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
 | 
					 | 
				
			||||||
            if isinstance(v, EmailValidator):
 | 
					 | 
				
			||||||
                schema['format'] = 'email'
 | 
					 | 
				
			||||||
            if isinstance(v, URLValidator):
 | 
					 | 
				
			||||||
                schema['format'] = 'uri'
 | 
					 | 
				
			||||||
            if isinstance(v, RegexValidator):
 | 
					 | 
				
			||||||
                # In Python, the token \Z does what \z does in other engines.
 | 
					 | 
				
			||||||
                # https://stackoverflow.com/questions/53283160
 | 
					 | 
				
			||||||
                schema['pattern'] = v.regex.pattern.replace('\\Z', '\\z')
 | 
					 | 
				
			||||||
            elif isinstance(v, MaxLengthValidator):
 | 
					 | 
				
			||||||
                attr_name = 'maxLength'
 | 
					 | 
				
			||||||
                if isinstance(field, serializers.ListField):
 | 
					 | 
				
			||||||
                    attr_name = 'maxItems'
 | 
					 | 
				
			||||||
                schema[attr_name] = v.limit_value
 | 
					 | 
				
			||||||
            elif isinstance(v, MinLengthValidator):
 | 
					 | 
				
			||||||
                attr_name = 'minLength'
 | 
					 | 
				
			||||||
                if isinstance(field, serializers.ListField):
 | 
					 | 
				
			||||||
                    attr_name = 'minItems'
 | 
					 | 
				
			||||||
                schema[attr_name] = v.limit_value
 | 
					 | 
				
			||||||
            elif isinstance(v, MaxValueValidator):
 | 
					 | 
				
			||||||
                schema['maximum'] = v.limit_value
 | 
					 | 
				
			||||||
            elif isinstance(v, MinValueValidator):
 | 
					 | 
				
			||||||
                schema['minimum'] = v.limit_value
 | 
					 | 
				
			||||||
            elif isinstance(v, DecimalValidator) and \
 | 
					 | 
				
			||||||
                    not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
 | 
					 | 
				
			||||||
                if v.decimal_places:
 | 
					 | 
				
			||||||
                    schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
 | 
					 | 
				
			||||||
                if v.max_digits:
 | 
					 | 
				
			||||||
                    digits = v.max_digits
 | 
					 | 
				
			||||||
                    if v.decimal_places is not None and v.decimal_places > 0:
 | 
					 | 
				
			||||||
                        digits -= v.decimal_places
 | 
					 | 
				
			||||||
                    schema['maximum'] = int(digits * '9') + 1
 | 
					 | 
				
			||||||
                    schema['minimum'] = -schema['maximum']
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_field_name(self, field):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Override this method if you want to change schema field name.
 | 
					 | 
				
			||||||
        For example, convert snake_case field name to camelCase.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return field.field_name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_paginator(self):
 | 
					 | 
				
			||||||
        pagination_class = getattr(self.view, 'pagination_class', None)
 | 
					 | 
				
			||||||
        if pagination_class:
 | 
					 | 
				
			||||||
            return pagination_class()
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def map_parsers(self, path, method):
 | 
					 | 
				
			||||||
        return list(map(attrgetter('media_type'), self.view.parser_classes))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def map_renderers(self, path, method):
 | 
					 | 
				
			||||||
        media_types = []
 | 
					 | 
				
			||||||
        for renderer in self.view.renderer_classes:
 | 
					 | 
				
			||||||
            # BrowsableAPIRenderer not relevant to OpenAPI spec
 | 
					 | 
				
			||||||
            if issubclass(renderer, renderers.BrowsableAPIRenderer):
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
            media_types.append(renderer.media_type)
 | 
					 | 
				
			||||||
        return media_types
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_serializer(self, path, method):
 | 
					 | 
				
			||||||
        view = self.view
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not hasattr(view, 'get_serializer'):
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            return view.get_serializer()
 | 
					 | 
				
			||||||
        except exceptions.APIException:
 | 
					 | 
				
			||||||
            warnings.warn('{}.get_serializer() raised an exception during '
 | 
					 | 
				
			||||||
                          'schema generation. Serializer fields will not be '
 | 
					 | 
				
			||||||
                          'generated for {} {}.'
 | 
					 | 
				
			||||||
                          .format(view.__class__.__name__, method, path))
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_request_serializer(self, path, method):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Override this method if your view uses a different serializer for
 | 
					 | 
				
			||||||
        handling request body.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.get_serializer(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_response_serializer(self, path, method):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Override this method if your view uses a different serializer for
 | 
					 | 
				
			||||||
        populating response data.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self.get_serializer(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_reference(self, serializer):
 | 
					 | 
				
			||||||
        return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_request_body(self, path, method):
 | 
					 | 
				
			||||||
        if method not in ('PUT', 'PATCH', 'POST'):
 | 
					 | 
				
			||||||
            return {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.request_media_types = self.map_parsers(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        serializer = self.get_request_serializer(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not isinstance(serializer, serializers.Serializer):
 | 
					 | 
				
			||||||
            item_schema = {}
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            item_schema = self.get_reference(serializer)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return {
 | 
					 | 
				
			||||||
            'content': {
 | 
					 | 
				
			||||||
                ct: {'schema': item_schema}
 | 
					 | 
				
			||||||
                for ct in self.request_media_types
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_responses(self, path, method):
 | 
					 | 
				
			||||||
        if method == 'DELETE':
 | 
					 | 
				
			||||||
            return {
 | 
					 | 
				
			||||||
                '204': {
 | 
					 | 
				
			||||||
                    'description': ''
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.response_media_types = self.map_renderers(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        serializer = self.get_response_serializer(path, method)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not isinstance(serializer, serializers.Serializer):
 | 
					 | 
				
			||||||
            item_schema = {}
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            item_schema = self.get_reference(serializer)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if is_list_view(path, method, self.view):
 | 
					 | 
				
			||||||
            response_schema = {
 | 
					 | 
				
			||||||
                'type': 'array',
 | 
					 | 
				
			||||||
                'items': item_schema,
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            paginator = self.get_paginator()
 | 
					 | 
				
			||||||
            if paginator:
 | 
					 | 
				
			||||||
                response_schema = paginator.get_paginated_response_schema(response_schema)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            response_schema = item_schema
 | 
					 | 
				
			||||||
        status_code = '201' if method == 'POST' else '200'
 | 
					 | 
				
			||||||
        return {
 | 
					 | 
				
			||||||
            status_code: {
 | 
					 | 
				
			||||||
                'content': {
 | 
					 | 
				
			||||||
                    ct: {'schema': response_schema}
 | 
					 | 
				
			||||||
                    for ct in self.response_media_types
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                # description is a mandatory property,
 | 
					 | 
				
			||||||
                # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
 | 
					 | 
				
			||||||
                # TODO: put something meaningful into it
 | 
					 | 
				
			||||||
                'description': ""
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_tags(self, path, method):
 | 
					 | 
				
			||||||
        # If user have specified tags, use them.
 | 
					 | 
				
			||||||
        if self._tags:
 | 
					 | 
				
			||||||
            return self._tags
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # First element of a specific path could be valid tag. This is a fallback solution.
 | 
					 | 
				
			||||||
        # PUT, PATCH, GET(Retrieve), DELETE:        /user_profile/{id}/       tags = [user-profile]
 | 
					 | 
				
			||||||
        # POST, GET(List):                          /user_profile/            tags = [user-profile]
 | 
					 | 
				
			||||||
        if path.startswith('/'):
 | 
					 | 
				
			||||||
            path = path[1:]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return [path.split('/')[0].replace('_', '-')]
 | 
					 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user