mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-06 13:23:18 +03:00
Merge 636cd9727a
into 3038494705
This commit is contained in:
commit
e2346da9bc
|
@ -1,6 +1,4 @@
|
|||
# Optional packages which may be used with REST framework.
|
||||
coreapi==2.3.1
|
||||
coreschema==0.0.4
|
||||
django-filter
|
||||
django-guardian>=2.4.0,<2.5
|
||||
inflection==0.5.1
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from rest_framework import parsers, renderers
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas import ManualSchema
|
||||
from rest_framework.schemas import coreapi as coreapi_schema
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
|
@ -15,31 +12,6 @@ class ObtainAuthToken(APIView):
|
|||
renderer_classes = (renderers.JSONRenderer,)
|
||||
serializer_class = AuthTokenSerializer
|
||||
|
||||
if coreapi_schema.is_enabled():
|
||||
schema = ManualSchema(
|
||||
fields=[
|
||||
coreapi.Field(
|
||||
name="username",
|
||||
required=True,
|
||||
location='form',
|
||||
schema=coreschema.String(
|
||||
title="Username",
|
||||
description="Valid username for authentication",
|
||||
),
|
||||
),
|
||||
coreapi.Field(
|
||||
name="password",
|
||||
required=True,
|
||||
location='form',
|
||||
schema=coreschema.String(
|
||||
title="Password",
|
||||
description="Valid password for authentication",
|
||||
),
|
||||
),
|
||||
],
|
||||
encoding="application/json",
|
||||
)
|
||||
|
||||
def get_serializer_context(self):
|
||||
return {
|
||||
'request': self.request,
|
||||
|
|
|
@ -23,32 +23,20 @@ except ImportError:
|
|||
postgres_fields = None
|
||||
|
||||
|
||||
# coreapi is required for CoreAPI schema generation
|
||||
try:
|
||||
import coreapi
|
||||
except ImportError:
|
||||
coreapi = None
|
||||
|
||||
# uritemplate is required for OpenAPI and CoreAPI schema generation
|
||||
# uritemplate is required for OpenAPI schema generation
|
||||
try:
|
||||
import uritemplate
|
||||
except ImportError:
|
||||
uritemplate = None
|
||||
|
||||
|
||||
# coreschema is optional
|
||||
try:
|
||||
import coreschema
|
||||
except ImportError:
|
||||
coreschema = None
|
||||
|
||||
|
||||
# pyyaml is optional
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
|
||||
# inflection is optional
|
||||
try:
|
||||
import inflection
|
||||
|
|
|
@ -1,88 +0,0 @@
|
|||
from django.urls import include, path
|
||||
|
||||
from rest_framework.renderers import (
|
||||
CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer
|
||||
)
|
||||
from rest_framework.schemas import SchemaGenerator, get_schema_view
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
def get_docs_view(
|
||||
title=None, description=None, schema_url=None, urlconf=None,
|
||||
public=True, patterns=None, generator_class=SchemaGenerator,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
|
||||
renderer_classes=None):
|
||||
|
||||
if renderer_classes is None:
|
||||
renderer_classes = [DocumentationRenderer, CoreJSONRenderer]
|
||||
|
||||
return get_schema_view(
|
||||
title=title,
|
||||
url=schema_url,
|
||||
urlconf=urlconf,
|
||||
description=description,
|
||||
renderer_classes=renderer_classes,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
|
||||
|
||||
def get_schemajs_view(
|
||||
title=None, description=None, schema_url=None, urlconf=None,
|
||||
public=True, patterns=None, generator_class=SchemaGenerator,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
|
||||
renderer_classes = [SchemaJSRenderer]
|
||||
|
||||
return get_schema_view(
|
||||
title=title,
|
||||
url=schema_url,
|
||||
urlconf=urlconf,
|
||||
description=description,
|
||||
renderer_classes=renderer_classes,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
|
||||
|
||||
def include_docs_urls(
|
||||
title=None, description=None, schema_url=None, urlconf=None,
|
||||
public=True, patterns=None, generator_class=SchemaGenerator,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
|
||||
renderer_classes=None):
|
||||
docs_view = get_docs_view(
|
||||
title=title,
|
||||
description=description,
|
||||
schema_url=schema_url,
|
||||
urlconf=urlconf,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
renderer_classes=renderer_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
schema_js_view = get_schemajs_view(
|
||||
title=title,
|
||||
description=description,
|
||||
schema_url=schema_url,
|
||||
urlconf=urlconf,
|
||||
public=public,
|
||||
patterns=patterns,
|
||||
generator_class=generator_class,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
||||
urls = [
|
||||
path('', docs_view, name='docs-index'),
|
||||
path('schema.js', schema_js_view, name='schema-js')
|
||||
]
|
||||
return include((urls, 'api-docs'), namespace='api-docs')
|
|
@ -3,7 +3,6 @@ Provides generic filtering backends that can be used to filter the results
|
|||
returned by list views.
|
||||
"""
|
||||
import operator
|
||||
import warnings
|
||||
from functools import reduce
|
||||
|
||||
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
||||
|
@ -14,8 +13,6 @@ from django.utils.encoding import force_str
|
|||
from django.utils.text import smart_split, unescape_string_literal
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import RemovedInDRF317Warning
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
@ -48,13 +45,6 @@ class BaseFilterBackend:
|
|||
"""
|
||||
raise NotImplementedError(".filter_queryset() must be overridden.")
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
@ -186,23 +176,6 @@ class SearchFilter(BaseFilterBackend):
|
|||
template = loader.get_template(self.template)
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.search_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title=force_str(self.search_title),
|
||||
description=force_str(self.search_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
|
@ -348,23 +321,6 @@ class OrderingFilter(BaseFilterBackend):
|
|||
context = self.get_template_context(request, queryset, view)
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.ordering_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title=force_str(self.ordering_title),
|
||||
description=force_str(self.ordering_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
|
|
|
@ -1,71 +0,0 @@
|
|||
from django.core.management.base import BaseCommand
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
from rest_framework import renderers
|
||||
from rest_framework.schemas import coreapi
|
||||
from rest_framework.schemas.openapi import SchemaGenerator
|
||||
|
||||
OPENAPI_MODE = 'openapi'
|
||||
COREAPI_MODE = 'coreapi'
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Generates configured API schema for project."
|
||||
|
||||
def get_mode(self):
|
||||
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--title', dest="title", default='', type=str)
|
||||
parser.add_argument('--url', dest="url", default=None, type=str)
|
||||
parser.add_argument('--description', dest="description", default=None, type=str)
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
|
||||
else:
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
|
||||
parser.add_argument('--urlconf', dest="urlconf", default=None, type=str)
|
||||
parser.add_argument('--generator_class', dest="generator_class", default=None, type=str)
|
||||
parser.add_argument('--file', dest="file", default=None, type=str)
|
||||
parser.add_argument('--api_version', dest="api_version", default='', type=str)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
if options['generator_class']:
|
||||
generator_class = import_string(options['generator_class'])
|
||||
else:
|
||||
generator_class = self.get_generator_class()
|
||||
generator = generator_class(
|
||||
url=options['url'],
|
||||
title=options['title'],
|
||||
description=options['description'],
|
||||
urlconf=options['urlconf'],
|
||||
version=options['api_version'],
|
||||
)
|
||||
schema = generator.get_schema(request=None, public=True)
|
||||
renderer = self.get_renderer(options['format'])
|
||||
output = renderer.render(schema, renderer_context={})
|
||||
|
||||
if options['file']:
|
||||
with open(options['file'], 'wb') as f:
|
||||
f.write(output)
|
||||
else:
|
||||
self.stdout.write(output.decode())
|
||||
|
||||
def get_renderer(self, format):
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
renderer_cls = {
|
||||
'corejson': renderers.CoreJSONRenderer,
|
||||
'openapi': renderers.CoreAPIOpenAPIRenderer,
|
||||
'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer,
|
||||
}[format]
|
||||
return renderer_cls()
|
||||
|
||||
renderer_cls = {
|
||||
'openapi': renderers.OpenAPIRenderer,
|
||||
'openapi-json': renderers.JSONOpenAPIRenderer,
|
||||
}[format]
|
||||
return renderer_cls()
|
||||
|
||||
def get_generator_class(self):
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
return coreapi.SchemaGenerator
|
||||
return SchemaGenerator
|
|
@ -4,7 +4,6 @@ be used for paginated responses.
|
|||
"""
|
||||
|
||||
import contextlib
|
||||
import warnings
|
||||
from base64 import b64decode, b64encode
|
||||
from collections import namedtuple
|
||||
from urllib import parse
|
||||
|
@ -15,8 +14,6 @@ from django.template import loader
|
|||
from django.utils.encoding import force_str
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import RemovedInDRF317Warning
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
|
@ -151,12 +148,6 @@ class BasePagination:
|
|||
def get_results(self, data):
|
||||
return data['results']
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
@ -313,36 +304,6 @@ class PageNumberPagination(BasePagination):
|
|||
context = self.get_html_context()
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
fields = [
|
||||
coreapi.Field(
|
||||
name=self.page_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Page',
|
||||
description=force_str(self.page_query_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
fields.append(
|
||||
coreapi.Field(
|
||||
name=self.page_size_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Page size',
|
||||
description=force_str(self.page_size_query_description)
|
||||
)
|
||||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
|
@ -530,32 +491,6 @@ class LimitOffsetPagination(BasePagination):
|
|||
except (AttributeError, TypeError):
|
||||
return len(queryset)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.limit_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Limit',
|
||||
description=force_str(self.limit_query_description)
|
||||
)
|
||||
),
|
||||
coreapi.Field(
|
||||
name=self.offset_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Offset',
|
||||
description=force_str(self.offset_query_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
|
@ -933,36 +868,6 @@ class CursorPagination(BasePagination):
|
|||
context = self.get_html_context()
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
fields = [
|
||||
coreapi.Field(
|
||||
name=self.cursor_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title='Cursor',
|
||||
description=force_str(self.cursor_query_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
fields.append(
|
||||
coreapi.Field(
|
||||
name=self.page_size_query_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.Integer(
|
||||
title='Page size',
|
||||
description=force_str(self.page_size_query_description)
|
||||
)
|
||||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
|
|
|
@ -7,10 +7,8 @@ on the response, such as JSON encoded data or HTML output.
|
|||
REST framework also provides an HTML renderer that renders the browsable API.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import datetime
|
||||
from urllib import parse
|
||||
|
||||
from django import forms
|
||||
from django.conf import settings
|
||||
|
@ -18,14 +16,12 @@ from django.core.exceptions import ImproperlyConfigured
|
|||
from django.core.paginator import Page
|
||||
from django.template import engines, loader
|
||||
from django.urls import NoReverseMatch
|
||||
from django.utils.html import mark_safe
|
||||
from django.utils.http import parse_header_parameters
|
||||
from django.utils.safestring import SafeString
|
||||
|
||||
from rest_framework import VERSION, exceptions, serializers, status
|
||||
from rest_framework.compat import (
|
||||
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema,
|
||||
pygments_css, yaml
|
||||
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, pygments_css, yaml
|
||||
)
|
||||
from rest_framework.exceptions import ParseError
|
||||
from rest_framework.request import is_form_media_type, override_method
|
||||
|
@ -850,57 +846,6 @@ class AdminRenderer(BrowsableAPIRenderer):
|
|||
return
|
||||
|
||||
|
||||
class DocumentationRenderer(BaseRenderer):
|
||||
media_type = 'text/html'
|
||||
format = 'html'
|
||||
charset = 'utf-8'
|
||||
template = 'rest_framework/docs/index.html'
|
||||
error_template = 'rest_framework/docs/error.html'
|
||||
code_style = 'emacs'
|
||||
languages = ['shell', 'javascript', 'python']
|
||||
|
||||
def get_context(self, data, request):
|
||||
return {
|
||||
'document': data,
|
||||
'langs': self.languages,
|
||||
'lang_htmls': ["rest_framework/docs/langs/%s.html" % language for language in self.languages],
|
||||
'lang_intro_htmls': ["rest_framework/docs/langs/%s-intro.html" % language for language in self.languages],
|
||||
'code_style': pygments_css(self.code_style),
|
||||
'request': request
|
||||
}
|
||||
|
||||
def render(self, data, accepted_media_type=None, renderer_context=None):
|
||||
if isinstance(data, coreapi.Document):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_context(data, renderer_context['request'])
|
||||
return template.render(context, request=renderer_context['request'])
|
||||
else:
|
||||
template = loader.get_template(self.error_template)
|
||||
context = {
|
||||
"data": data,
|
||||
"request": renderer_context['request'],
|
||||
"response": renderer_context['response'],
|
||||
"debug": settings.DEBUG,
|
||||
}
|
||||
return template.render(context, request=renderer_context['request'])
|
||||
|
||||
|
||||
class SchemaJSRenderer(BaseRenderer):
|
||||
media_type = 'application/javascript'
|
||||
format = 'javascript'
|
||||
charset = 'utf-8'
|
||||
template = 'rest_framework/schema.js'
|
||||
|
||||
def render(self, data, accepted_media_type=None, renderer_context=None):
|
||||
codec = coreapi.codecs.CoreJSONCodec()
|
||||
schema = base64.b64encode(codec.encode(data)).decode('ascii')
|
||||
|
||||
template = loader.get_template(self.template)
|
||||
context = {'schema': mark_safe(schema)}
|
||||
request = renderer_context['request']
|
||||
return template.render(context, request=request)
|
||||
|
||||
|
||||
class MultiPartRenderer(BaseRenderer):
|
||||
media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
|
||||
format = 'multipart'
|
||||
|
@ -921,139 +866,6 @@ class MultiPartRenderer(BaseRenderer):
|
|||
return encode_multipart(self.BOUNDARY, data)
|
||||
|
||||
|
||||
class CoreJSONRenderer(BaseRenderer):
|
||||
media_type = 'application/coreapi+json'
|
||||
charset = None
|
||||
format = 'corejson'
|
||||
|
||||
def __init__(self):
|
||||
assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not installed.'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
indent = bool(renderer_context.get('indent', 0))
|
||||
codec = coreapi.codecs.CoreJSONCodec()
|
||||
return codec.dump(data, indent=indent)
|
||||
|
||||
|
||||
class _BaseOpenAPIRenderer:
|
||||
def get_schema(self, instance):
|
||||
CLASS_TO_TYPENAME = {
|
||||
coreschema.Object: 'object',
|
||||
coreschema.Array: 'array',
|
||||
coreschema.Number: 'number',
|
||||
coreschema.Integer: 'integer',
|
||||
coreschema.String: 'string',
|
||||
coreschema.Boolean: 'boolean',
|
||||
}
|
||||
|
||||
schema = {}
|
||||
if instance.__class__ in CLASS_TO_TYPENAME:
|
||||
schema['type'] = CLASS_TO_TYPENAME[instance.__class__]
|
||||
schema['title'] = instance.title
|
||||
schema['description'] = instance.description
|
||||
if hasattr(instance, 'enum'):
|
||||
schema['enum'] = instance.enum
|
||||
return schema
|
||||
|
||||
def get_parameters(self, link):
|
||||
parameters = []
|
||||
for field in link.fields:
|
||||
if field.location not in ['path', 'query']:
|
||||
continue
|
||||
parameter = {
|
||||
'name': field.name,
|
||||
'in': field.location,
|
||||
}
|
||||
if field.required:
|
||||
parameter['required'] = True
|
||||
if field.description:
|
||||
parameter['description'] = field.description
|
||||
if field.schema:
|
||||
parameter['schema'] = self.get_schema(field.schema)
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
|
||||
def get_operation(self, link, name, tag):
|
||||
operation_id = "%s_%s" % (tag, name) if tag else name
|
||||
parameters = self.get_parameters(link)
|
||||
|
||||
operation = {
|
||||
'operationId': operation_id,
|
||||
}
|
||||
if link.title:
|
||||
operation['summary'] = link.title
|
||||
if link.description:
|
||||
operation['description'] = link.description
|
||||
if parameters:
|
||||
operation['parameters'] = parameters
|
||||
if tag:
|
||||
operation['tags'] = [tag]
|
||||
return operation
|
||||
|
||||
def get_paths(self, document):
|
||||
paths = {}
|
||||
|
||||
tag = None
|
||||
for name, link in document.links.items():
|
||||
path = parse.urlparse(link.url).path
|
||||
method = link.action.lower()
|
||||
paths.setdefault(path, {})
|
||||
paths[path][method] = self.get_operation(link, name, tag=tag)
|
||||
|
||||
for tag, section in document.data.items():
|
||||
for name, link in section.links.items():
|
||||
path = parse.urlparse(link.url).path
|
||||
method = link.action.lower()
|
||||
paths.setdefault(path, {})
|
||||
paths[path][method] = self.get_operation(link, name, tag=tag)
|
||||
|
||||
return paths
|
||||
|
||||
def get_structure(self, data):
|
||||
return {
|
||||
'openapi': '3.0.0',
|
||||
'info': {
|
||||
'version': '',
|
||||
'title': data.title,
|
||||
'description': data.description
|
||||
},
|
||||
'servers': [{
|
||||
'url': data.url
|
||||
}],
|
||||
'paths': self.get_paths(data)
|
||||
}
|
||||
|
||||
|
||||
class CoreAPIOpenAPIRenderer(_BaseOpenAPIRenderer):
|
||||
media_type = 'application/vnd.oai.openapi'
|
||||
charset = None
|
||||
format = 'openapi'
|
||||
|
||||
def __init__(self):
|
||||
assert coreapi, 'Using CoreAPIOpenAPIRenderer, but `coreapi` is not installed.'
|
||||
assert yaml, 'Using CoreAPIOpenAPIRenderer, but `pyyaml` is not installed.'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
structure = self.get_structure(data)
|
||||
return yaml.dump(structure, default_flow_style=False).encode()
|
||||
|
||||
|
||||
class CoreAPIJSONOpenAPIRenderer(_BaseOpenAPIRenderer):
|
||||
media_type = 'application/vnd.oai.openapi+json'
|
||||
charset = None
|
||||
format = 'openapi-json'
|
||||
ensure_ascii = not api_settings.UNICODE_JSON
|
||||
|
||||
def __init__(self):
|
||||
assert coreapi, 'Using CoreAPIJSONOpenAPIRenderer, but `coreapi` is not installed.'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
structure = self.get_structure(data)
|
||||
return json.dumps(
|
||||
structure, indent=4,
|
||||
ensure_ascii=self.ensure_ascii).encode('utf-8')
|
||||
|
||||
|
||||
class OpenAPIRenderer(BaseRenderer):
|
||||
media_type = 'application/vnd.oai.openapi'
|
||||
charset = None
|
||||
|
@ -1067,6 +879,7 @@ class OpenAPIRenderer(BaseRenderer):
|
|||
class Dumper(yaml.Dumper):
|
||||
def ignore_aliases(self, data):
|
||||
return True
|
||||
|
||||
Dumper.add_representer(SafeString, Dumper.represent_str)
|
||||
Dumper.add_representer(datetime.timedelta, encoders.CustomScalar.represent_timedelta)
|
||||
return yaml.dump(data, default_flow_style=False, sort_keys=False, Dumper=Dumper).encode('utf-8')
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
"""
|
||||
rest_framework.schemas
|
||||
|
||||
schemas:
|
||||
__init__.py
|
||||
generators.py # Top-down schema generation
|
||||
inspectors.py # Per-endpoint view introspection
|
||||
utils.py # Shared helper functions
|
||||
views.py # Houses `SchemaView`, `APIView` subclass.
|
||||
|
||||
We expose a minimal "public" API directly from `schemas`. This covers the
|
||||
basic use-cases:
|
||||
|
||||
from rest_framework.schemas import (
|
||||
AutoSchema,
|
||||
ManualSchema,
|
||||
get_schema_view,
|
||||
SchemaGenerator,
|
||||
)
|
||||
|
||||
Other access should target the submodules directly
|
||||
"""
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from . import coreapi, openapi
|
||||
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
|
||||
from .inspectors import DefaultSchema # noqa
|
||||
|
||||
|
||||
def get_schema_view(
|
||||
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
|
||||
public=False, patterns=None, generator_class=None,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
|
||||
version=None):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
if generator_class is None:
|
||||
if coreapi.is_enabled():
|
||||
generator_class = coreapi.SchemaGenerator
|
||||
else:
|
||||
generator_class = openapi.SchemaGenerator
|
||||
|
||||
generator = generator_class(
|
||||
title=title, url=url, description=description,
|
||||
urlconf=urlconf, patterns=patterns, version=version
|
||||
)
|
||||
|
||||
# Avoid import cycle on APIView
|
||||
from .views import SchemaView
|
||||
return SchemaView.as_view(
|
||||
renderer_classes=renderer_classes,
|
||||
schema_generator=generator,
|
||||
public=public,
|
||||
authentication_classes=authentication_classes,
|
||||
permission_classes=permission_classes,
|
||||
)
|
|
@ -1,626 +0,0 @@
|
|||
import warnings
|
||||
from collections import Counter
|
||||
from urllib import parse
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from rest_framework import RemovedInDRF317Warning, exceptions, serializers
|
||||
from rest_framework.compat import coreapi, coreschema, uritemplate
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
|
||||
def common_path(paths):
|
||||
split_paths = [path.strip('/').split('/') for path in paths]
|
||||
s1 = min(split_paths)
|
||||
s2 = max(split_paths)
|
||||
common = s1
|
||||
for i, c in enumerate(s1):
|
||||
if c != s2[i]:
|
||||
common = s1[:i]
|
||||
break
|
||||
return '/' + '/'.join(common)
|
||||
|
||||
|
||||
def is_custom_action(action):
|
||||
return action not in {
|
||||
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
|
||||
}
|
||||
|
||||
|
||||
def distribute_links(obj):
|
||||
for key, value in obj.items():
|
||||
distribute_links(value)
|
||||
|
||||
for preferred_key, link in obj.links:
|
||||
key = obj.get_available_key(preferred_key)
|
||||
obj[key] = link
|
||||
|
||||
|
||||
INSERT_INTO_COLLISION_FMT = """
|
||||
Schema Naming Collision.
|
||||
|
||||
coreapi.Link for URL path {value_url} cannot be inserted into schema.
|
||||
Position conflicts with coreapi.Link for URL path {target_url}.
|
||||
|
||||
Attempted to insert link with keys: {keys}.
|
||||
|
||||
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
|
||||
to customise schema structure.
|
||||
"""
|
||||
|
||||
|
||||
class LinkNode(dict):
|
||||
def __init__(self):
|
||||
self.links = []
|
||||
self.methods_counter = Counter()
|
||||
super().__init__()
|
||||
|
||||
def get_available_key(self, preferred_key):
|
||||
if preferred_key not in self:
|
||||
return preferred_key
|
||||
|
||||
while True:
|
||||
current_val = self.methods_counter[preferred_key]
|
||||
self.methods_counter[preferred_key] += 1
|
||||
|
||||
key = f'{preferred_key}_{current_val}'
|
||||
if key not in self:
|
||||
return key
|
||||
|
||||
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = LinkNode()
|
||||
target = target[key]
|
||||
|
||||
try:
|
||||
target.links.append((keys[-1], value))
|
||||
except TypeError:
|
||||
msg = INSERT_INTO_COLLISION_FMT.format(
|
||||
value_url=value.url,
|
||||
target_url=target.url,
|
||||
keys=keys
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
"""
|
||||
Original CoreAPI version.
|
||||
"""
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
super().__init__(title, url, description, patterns, urlconf)
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
links = LinkNode()
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
|
||||
return links
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
# Method for generating the link layout....
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "star") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
if is_list_view(subpath, method, view):
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.default_mapping[method.lower()]
|
||||
|
||||
named_path_components = [
|
||||
component for component
|
||||
in subpath.strip('/').split('/')
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if is_custom_action(action):
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
mapped_methods = {
|
||||
# Don't count head mapping, e.g. not part of the schema
|
||||
method for method in view.action_map if method != 'head'
|
||||
}
|
||||
if len(mapped_methods) > 1:
|
||||
action = self.default_mapping[method.lower()]
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
return named_path_components + [action]
|
||||
else:
|
||||
return named_path_components[:-1] + [action]
|
||||
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
Given a list of all paths, return the common prefix which should be
|
||||
discounted when generating a schema structure.
|
||||
|
||||
This will be the longest common string that does not include that last
|
||||
component of the URL, or the last component before a path parameter.
|
||||
|
||||
For example:
|
||||
|
||||
/api/v1/users/
|
||||
/api/v1/users/{pk}/
|
||||
|
||||
The path prefix is '/api/v1'
|
||||
"""
|
||||
prefixes = []
|
||||
for path in paths:
|
||||
components = path.strip('/').split('/')
|
||||
initial_components = []
|
||||
for component in components:
|
||||
if '{' in component:
|
||||
break
|
||||
initial_components.append(component)
|
||||
prefix = '/'.join(initial_components[:-1])
|
||||
if not prefix:
|
||||
# We can just break early in the case that there's at least
|
||||
# one URL that doesn't have a path prefix.
|
||||
return '/'
|
||||
prefixes.append('/' + prefix + '/')
|
||||
return common_path(prefixes)
|
||||
|
||||
# View Inspectors #
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_str(field.label) if field.label else ''
|
||||
description = force_str(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.DictField):
|
||||
return coreschema.Object(
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties={
|
||||
key: field_to_schema(value)
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
},
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
related_field_schema = field_to_schema(field.child_relation)
|
||||
|
||||
return coreschema.Array(
|
||||
items=related_field_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
schema_cls = coreschema.String
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
return schema_cls(title=title, description=description)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices)),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
elif isinstance(field, serializers.DateField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date'
|
||||
)
|
||||
elif isinstance(field, serializers.DateTimeField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date-time'
|
||||
)
|
||||
elif isinstance(field, serializers.JSONField):
|
||||
return coreschema.Object(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
|
||||
return coreschema.String(title=title, description=description)
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
Default inspector for APIView
|
||||
|
||||
Responsible for per-view introspection and schema generation.
|
||||
"""
|
||||
def __init__(self, manual_fields=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `manual_fields`: list of `coreapi.Field` instances that
|
||||
will be added to auto-generated fields, overwriting on `Field.name`
|
||||
"""
|
||||
super().__init__()
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
|
||||
if manual_fields is None:
|
||||
manual_fields = []
|
||||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
fields += self.get_filter_fields(path, method)
|
||||
|
||||
manual_fields = self.get_manual_fields(path, method)
|
||||
fields = self.update_fields(fields, manual_fields)
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method)
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_path_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
view = self.view
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_str(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_str(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def _allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
|
||||
Override to adjust behaviour for your view.
|
||||
|
||||
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
|
||||
to allow changes based on user experience.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def get_filter_fields(self, path, method):
|
||||
if not self._allows_filters(path, method):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(self.view)
|
||||
return fields
|
||||
|
||||
def get_manual_fields(self, path, method):
|
||||
return self._manual_fields
|
||||
|
||||
@staticmethod
|
||||
def update_fields(fields, update_with):
|
||||
"""
|
||||
Update list of coreapi.Field instances, overwriting on `Field.name`.
|
||||
|
||||
Utility function to handle replacing coreapi.Field fields
|
||||
from a list by name. Used to handle `manual_fields`.
|
||||
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances to update
|
||||
* `update_with: list of `coreapi.Field` instances to add or replace.
|
||||
"""
|
||||
if not update_with:
|
||||
return fields
|
||||
|
||||
by_name = {f.name: f for f in fields}
|
||||
for f in update_with:
|
||||
by_name[f.name] = f
|
||||
fields = list(by_name.values())
|
||||
return fields
|
||||
|
||||
def get_encoding(self, path, method):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = {
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
}
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ManualSchema(ViewInspector):
|
||||
"""
|
||||
Allows providing a list of coreapi.Fields,
|
||||
plus an optional description.
|
||||
"""
|
||||
def __init__(self, fields, description='', encoding=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances.
|
||||
* `description`: String description for view. Optional.
|
||||
"""
|
||||
super().__init__()
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
|
||||
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
|
||||
self._fields = fields
|
||||
self._description = description
|
||||
self._encoding = encoding
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=self._encoding,
|
||||
fields=self._fields,
|
||||
description=self._description
|
||||
)
|
||||
|
||||
|
||||
def is_enabled():
|
||||
"""Is CoreAPI Mode enabled?"""
|
||||
if coreapi is not None:
|
||||
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
|
||||
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)
|
|
@ -1,239 +0,0 @@
|
|||
"""
|
||||
generators.py # Top-down schema generation
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
import re
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
from django.urls import URLPattern, URLResolver
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.model_meta import _get_pk
|
||||
|
||||
|
||||
def get_pk_name(model):
|
||||
meta = model._meta.concrete_model._meta
|
||||
return _get_pk(meta).name
|
||||
|
||||
|
||||
def is_api_view(callback):
|
||||
"""
|
||||
Return `True` if the given view callback is a REST framework view/viewset.
|
||||
"""
|
||||
# Avoid import cycle on APIView
|
||||
from rest_framework.views import APIView
|
||||
cls = getattr(callback, 'cls', None)
|
||||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
def endpoint_ordering(endpoint):
|
||||
path, method, callback = endpoint
|
||||
method_priority = {
|
||||
'GET': 0,
|
||||
'POST': 1,
|
||||
'PUT': 2,
|
||||
'PATCH': 3,
|
||||
'DELETE': 4
|
||||
}.get(method, 5)
|
||||
return (method_priority,)
|
||||
|
||||
|
||||
_PATH_PARAMETER_COMPONENT_RE = re.compile(
|
||||
r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
|
||||
)
|
||||
|
||||
|
||||
class EndpointEnumerator:
|
||||
"""
|
||||
A class to determine the available API endpoints that a project exposes.
|
||||
"""
|
||||
def __init__(self, patterns=None, urlconf=None):
|
||||
if patterns is None:
|
||||
if urlconf is None:
|
||||
# Use the default Django URL conf
|
||||
urlconf = settings.ROOT_URLCONF
|
||||
|
||||
# Load the given URLconf module
|
||||
if isinstance(urlconf, str):
|
||||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
patterns = urls.urlpatterns
|
||||
|
||||
self.patterns = patterns
|
||||
|
||||
def get_api_endpoints(self, patterns=None, prefix=''):
|
||||
"""
|
||||
Return a list of all available API endpoints by inspecting the URL conf.
|
||||
"""
|
||||
if patterns is None:
|
||||
patterns = self.patterns
|
||||
|
||||
api_endpoints = []
|
||||
|
||||
for pattern in patterns:
|
||||
path_regex = prefix + str(pattern.pattern)
|
||||
if isinstance(pattern, URLPattern):
|
||||
path = self.get_path_from_regex(path_regex)
|
||||
callback = pattern.callback
|
||||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
endpoint = (path, method, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, URLResolver):
|
||||
nested_endpoints = self.get_api_endpoints(
|
||||
patterns=pattern.url_patterns,
|
||||
prefix=path_regex
|
||||
)
|
||||
api_endpoints.extend(nested_endpoints)
|
||||
|
||||
return sorted(api_endpoints, key=endpoint_ordering)
|
||||
|
||||
def get_path_from_regex(self, path_regex):
|
||||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
# ???: Would it be feasible to adjust this such that we generate the
|
||||
# path, plus the kwargs, plus the type from the converter, such that we
|
||||
# could feed that straight into the parameter schema object?
|
||||
|
||||
path = simplify_regex(path_regex)
|
||||
|
||||
# Strip Django 2.0 converters as they are incompatible with uritemplate format
|
||||
return re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g<parameter>}', path)
|
||||
|
||||
def should_include_endpoint(self, path, callback):
|
||||
"""
|
||||
Return `True` if the given endpoint should be included.
|
||||
"""
|
||||
if not is_api_view(callback):
|
||||
return False # Ignore anything except REST framework views.
|
||||
|
||||
if callback.cls.schema is None:
|
||||
return False
|
||||
|
||||
if 'schema' in callback.initkwargs:
|
||||
if callback.initkwargs['schema'] is None:
|
||||
return False
|
||||
|
||||
if path.endswith('.{format}') or path.endswith('.{format}/'):
|
||||
return False # Ignore .json style URLs.
|
||||
|
||||
return True
|
||||
|
||||
def get_allowed_methods(self, callback):
|
||||
"""
|
||||
Return a list of the valid HTTP methods for this endpoint.
|
||||
"""
|
||||
if hasattr(callback, 'actions'):
|
||||
actions = set(callback.actions)
|
||||
http_method_names = set(callback.cls.http_method_names)
|
||||
methods = [method.upper() for method in actions & http_method_names]
|
||||
else:
|
||||
methods = callback.cls().allowed_methods
|
||||
|
||||
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
endpoint_inspector_cls = EndpointEnumerator
|
||||
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
self.urlconf = urlconf
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.version = version
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def _initialise_endpoints(self):
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
def _get_paths_and_endpoints(self, request):
|
||||
"""
|
||||
Generate (path, method, view) given (path, method, callback) for paths.
|
||||
"""
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
view = self.create_view(callback, method, request)
|
||||
path = self.coerce_path(path, method, view)
|
||||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
return paths, view_endpoints
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
"""
|
||||
if view.request is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except (exceptions.APIException, Http404, PermissionDenied):
|
||||
return False
|
||||
return True
|
|
@ -1,126 +0,0 @@
|
|||
"""
|
||||
inspectors.py # Per-endpoint view introspection
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
import re
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from django.utils.encoding import smart_str
|
||||
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
|
||||
|
||||
class ViewInspector:
|
||||
"""
|
||||
Descriptor class on APIView.
|
||||
|
||||
Provide subclass for per-view schema generation
|
||||
"""
|
||||
|
||||
# Used in _get_description_section()
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
def __init__(self):
|
||||
self.instance_schemas = WeakKeyDictionary()
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""
|
||||
Enables `ViewInspector` as a Python _Descriptor_.
|
||||
|
||||
This is how `view.schema` knows about `view`.
|
||||
|
||||
`__get__` is called when the descriptor is accessed on the owner.
|
||||
(That will be when view.schema is called in our case.)
|
||||
|
||||
`owner` is always the owner class. (An APIView, or subclass for us.)
|
||||
`instance` is the view instance or `None` if accessed from the class,
|
||||
rather than an instance.
|
||||
|
||||
See: https://docs.python.org/3/howto/descriptor.html for info on
|
||||
descriptor usage.
|
||||
"""
|
||||
if instance in self.instance_schemas:
|
||||
return self.instance_schemas[instance]
|
||||
|
||||
self.view = instance
|
||||
return self
|
||||
|
||||
def __set__(self, instance, other):
|
||||
self.instance_schemas[instance] = other
|
||||
if other is not None:
|
||||
other.view = instance
|
||||
|
||||
@property
|
||||
def view(self):
|
||||
"""View property."""
|
||||
assert self._view is not None, (
|
||||
"Schema generation REQUIRES a view instance. (Hint: you accessed "
|
||||
"`schema` from the view class rather than an instance.)"
|
||||
)
|
||||
return self._view
|
||||
|
||||
@view.setter
|
||||
def view(self, value):
|
||||
self._view = value
|
||||
|
||||
@view.deleter
|
||||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_description(self, path, method):
|
||||
"""
|
||||
Determine a path description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_func = getattr(view, method_name, None)
|
||||
method_docstring = method_func.__doc__
|
||||
if method_func and method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return self._get_description_section(view, method.lower(), formatting.dedent(smart_str(method_docstring)))
|
||||
else:
|
||||
return self._get_description_section(view, getattr(view, 'action', method.lower()),
|
||||
view.get_view_description())
|
||||
|
||||
def _get_description_section(self, view, header, description):
|
||||
lines = description.splitlines()
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if self.header_regex.match(line):
|
||||
current_section, separator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
|
||||
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in coerce_method_names:
|
||||
if coerce_method_names[header] in sections:
|
||||
return sections[coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
|
||||
class DefaultSchema(ViewInspector):
|
||||
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
|
||||
def __get__(self, instance, owner):
|
||||
result = super().__get__(instance, owner)
|
||||
if not isinstance(result, DefaultSchema):
|
||||
return result
|
||||
|
||||
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
|
||||
assert issubclass(inspector_class, ViewInspector), (
|
||||
"DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
|
||||
)
|
||||
inspector = inspector_class()
|
||||
inspector.view = instance
|
||||
return inspector
|
|
@ -1,721 +0,0 @@
|
|||
import re
|
||||
import warnings
|
||||
from decimal import Decimal
|
||||
from operator import attrgetter
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from django.core.validators import (
|
||||
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
|
||||
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
|
||||
)
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from rest_framework import exceptions, renderers, serializers
|
||||
from rest_framework.compat import inflection, uritemplate
|
||||
from rest_framework.fields import _UnvalidatedField, empty
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
|
||||
def get_info(self):
|
||||
# Title and version are required by openapi specification 3.x
|
||||
info = {
|
||||
'title': self.title or '',
|
||||
'version': self.version or ''
|
||||
}
|
||||
|
||||
if self.description is not None:
|
||||
info['description'] = self.description
|
||||
|
||||
return info
|
||||
|
||||
def check_duplicate_operation_id(self, paths):
|
||||
ids = {}
|
||||
for route in paths:
|
||||
for method in paths[route]:
|
||||
if 'operationId' not in paths[route][method]:
|
||||
continue
|
||||
operation_id = paths[route][method]['operationId']
|
||||
if operation_id in ids:
|
||||
warnings.warn(
|
||||
'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n'
|
||||
'\tRoute: {route1}, Method: {method1}\n'
|
||||
'\tRoute: {route2}, Method: {method2}\n'
|
||||
'\tAn operationId has to be unique across your schema. Your schema may not work in other tools.'
|
||||
.format(
|
||||
route1=ids[operation_id]['route'],
|
||||
method1=ids[operation_id]['method'],
|
||||
route2=route,
|
||||
method2=method,
|
||||
operation_id=operation_id
|
||||
)
|
||||
)
|
||||
ids[operation_id] = {
|
||||
'route': route,
|
||||
'method': method
|
||||
}
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a OpenAPI schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
components_schemas = {}
|
||||
|
||||
# Iterate endpoints generating per method path operations.
|
||||
paths = {}
|
||||
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
|
||||
operation = view.schema.get_operation(path, method)
|
||||
components = view.schema.get_components(path, method)
|
||||
for k in components.keys():
|
||||
if k not in components_schemas:
|
||||
continue
|
||||
if components_schemas[k] == components[k]:
|
||||
continue
|
||||
warnings.warn(f'Schema component "{k}" has been overridden with a different value.')
|
||||
|
||||
components_schemas.update(components)
|
||||
|
||||
# Normalise path for any provided mount url.
|
||||
if path.startswith('/'):
|
||||
path = path[1:]
|
||||
path = urljoin(self.url or '/', path)
|
||||
|
||||
paths.setdefault(path, {})
|
||||
paths[path][method.lower()] = operation
|
||||
|
||||
self.check_duplicate_operation_id(paths)
|
||||
|
||||
# Compile final schema.
|
||||
schema = {
|
||||
'openapi': '3.0.2',
|
||||
'info': self.get_info(),
|
||||
'paths': paths,
|
||||
}
|
||||
|
||||
if len(components_schemas) > 0:
|
||||
schema['components'] = {
|
||||
'schemas': components_schemas
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
# View Inspectors
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
|
||||
def __init__(self, tags=None, operation_id_base=None, component_name=None):
|
||||
"""
|
||||
:param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name.
|
||||
:param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name.
|
||||
"""
|
||||
if tags and not all(isinstance(tag, str) for tag in tags):
|
||||
raise ValueError('tags must be a list or tuple of string.')
|
||||
self._tags = tags
|
||||
self.operation_id_base = operation_id_base
|
||||
self.component_name = component_name
|
||||
super().__init__()
|
||||
|
||||
request_media_types = []
|
||||
response_media_types = []
|
||||
|
||||
method_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partialUpdate',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
def get_operation(self, path, method):
|
||||
operation = {}
|
||||
|
||||
operation['operationId'] = self.get_operation_id(path, method)
|
||||
operation['description'] = self.get_description(path, method)
|
||||
|
||||
parameters = []
|
||||
parameters += self.get_path_parameters(path, method)
|
||||
parameters += self.get_pagination_parameters(path, method)
|
||||
parameters += self.get_filter_parameters(path, method)
|
||||
operation['parameters'] = parameters
|
||||
|
||||
request_body = self.get_request_body(path, method)
|
||||
if request_body:
|
||||
operation['requestBody'] = request_body
|
||||
operation['responses'] = self.get_responses(path, method)
|
||||
operation['tags'] = self.get_tags(path, method)
|
||||
|
||||
return operation
|
||||
|
||||
def get_component_name(self, serializer):
|
||||
"""
|
||||
Compute the component's name from the serializer.
|
||||
Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
|
||||
"""
|
||||
if self.component_name is not None:
|
||||
return self.component_name
|
||||
|
||||
# use the serializer's class name as the component name.
|
||||
component_name = serializer.__class__.__name__
|
||||
# We remove the "serializer" string from the class name.
|
||||
pattern = re.compile("serializer", re.IGNORECASE)
|
||||
component_name = pattern.sub("", component_name)
|
||||
|
||||
if component_name == "":
|
||||
raise Exception(
|
||||
'"{}" is an invalid class name for schema generation. '
|
||||
'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
|
||||
.format(serializer.__class__.__name__)
|
||||
)
|
||||
|
||||
return component_name
|
||||
|
||||
def get_components(self, path, method):
|
||||
"""
|
||||
Return components with their properties from the serializer.
|
||||
"""
|
||||
|
||||
if method.lower() == 'delete':
|
||||
return {}
|
||||
|
||||
request_serializer = self.get_request_serializer(path, method)
|
||||
response_serializer = self.get_response_serializer(path, method)
|
||||
|
||||
components = {}
|
||||
|
||||
if isinstance(request_serializer, serializers.Serializer):
|
||||
component_name = self.get_component_name(request_serializer)
|
||||
content = self.map_serializer(request_serializer)
|
||||
components.setdefault(component_name, content)
|
||||
|
||||
if isinstance(response_serializer, serializers.Serializer):
|
||||
component_name = self.get_component_name(response_serializer)
|
||||
content = self.map_serializer(response_serializer)
|
||||
components.setdefault(component_name, content)
|
||||
|
||||
return components
|
||||
|
||||
def _to_camel_case(self, snake_str):
|
||||
components = snake_str.split('_')
|
||||
# We capitalize the first letter of each component except the first one
|
||||
# with the 'title' method and join them together.
|
||||
return components[0] + ''.join(x.title() for x in components[1:])
|
||||
|
||||
def get_operation_id_base(self, path, method, action):
|
||||
"""
|
||||
Compute the base part for operation ID from the model, serializer or view name.
|
||||
"""
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
|
||||
if self.operation_id_base is not None:
|
||||
name = self.operation_id_base
|
||||
|
||||
# Try to deduce the ID from the view's model
|
||||
elif model is not None:
|
||||
name = model.__name__
|
||||
|
||||
# Try with the serializer class name
|
||||
elif self.get_serializer(path, method) is not None:
|
||||
name = self.get_serializer(path, method).__class__.__name__
|
||||
if name.endswith('Serializer'):
|
||||
name = name[:-10]
|
||||
|
||||
# Fallback to the view name
|
||||
else:
|
||||
name = self.view.__class__.__name__
|
||||
if name.endswith('APIView'):
|
||||
name = name[:-7]
|
||||
elif name.endswith('View'):
|
||||
name = name[:-4]
|
||||
|
||||
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
|
||||
# comes at the end of the name
|
||||
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
|
||||
name = name[:-len(action)]
|
||||
|
||||
if action == 'list':
|
||||
assert inflection, '`inflection` must be installed for OpenAPI schema support.'
|
||||
name = inflection.pluralize(name)
|
||||
|
||||
return name
|
||||
|
||||
def get_operation_id(self, path, method):
|
||||
"""
|
||||
Compute an operation ID from the view type and get_operation_id_base method.
|
||||
"""
|
||||
method_name = getattr(self.view, 'action', method.lower())
|
||||
if is_list_view(path, method, self.view):
|
||||
action = 'list'
|
||||
elif method_name not in self.method_mapping:
|
||||
action = self._to_camel_case(method_name)
|
||||
else:
|
||||
action = self.method_mapping[method.lower()]
|
||||
|
||||
name = self.get_operation_id_base(path, method, action)
|
||||
|
||||
return action + name
|
||||
|
||||
def get_path_parameters(self, path, method):
|
||||
"""
|
||||
Return a list of parameters from templated path variables.
|
||||
"""
|
||||
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
|
||||
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
parameters = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
description = ''
|
||||
if model is not None: # TODO: test this.
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_str(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
parameter = {
|
||||
"name": variable,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"description": description,
|
||||
'schema': {
|
||||
'type': 'string', # TODO: integer, pattern, ...
|
||||
},
|
||||
}
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def get_filter_parameters(self, path, method):
|
||||
if not self.allows_filters(path, method):
|
||||
return []
|
||||
parameters = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
parameters += filter_backend().get_schema_operation_parameters(self.view)
|
||||
return parameters
|
||||
|
||||
def allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def get_pagination_parameters(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
paginator = self.get_paginator()
|
||||
if not paginator:
|
||||
return []
|
||||
|
||||
return paginator.get_schema_operation_parameters(view)
|
||||
|
||||
def map_choicefield(self, field):
|
||||
choices = list(dict.fromkeys(field.choices)) # preserve order and remove duplicates
|
||||
if all(isinstance(choice, bool) for choice in choices):
|
||||
type = 'boolean'
|
||||
elif all(isinstance(choice, int) for choice in choices):
|
||||
type = 'integer'
|
||||
elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer`
|
||||
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
|
||||
type = 'number'
|
||||
elif all(isinstance(choice, str) for choice in choices):
|
||||
type = 'string'
|
||||
else:
|
||||
type = None
|
||||
|
||||
mapping = {
|
||||
# The value of `enum` keyword MUST be an array and SHOULD be unique.
|
||||
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20
|
||||
'enum': choices
|
||||
}
|
||||
|
||||
# If We figured out `type` then and only then we should set it. It must be a string.
|
||||
# Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type
|
||||
# It is optional but it can not be null.
|
||||
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
|
||||
if type:
|
||||
mapping['type'] = type
|
||||
return mapping
|
||||
|
||||
def map_field(self, field):
|
||||
|
||||
# Nested Serializers, `many` or not.
|
||||
if isinstance(field, serializers.ListSerializer):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self.map_serializer(field.child)
|
||||
}
|
||||
if isinstance(field, serializers.Serializer):
|
||||
data = self.map_serializer(field)
|
||||
data['type'] = 'object'
|
||||
return data
|
||||
|
||||
# Related fields.
|
||||
if isinstance(field, serializers.ManyRelatedField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self.map_field(field.child_relation)
|
||||
}
|
||||
if isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
if getattr(field, "pk_field", False):
|
||||
return self.map_field(field=field.pk_field)
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
return {'type': 'integer'}
|
||||
|
||||
# ChoiceFields (single and multiple).
|
||||
# Q:
|
||||
# - Is 'type' required?
|
||||
# - can we determine the TYPE of a choicefield?
|
||||
if isinstance(field, serializers.MultipleChoiceField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self.map_choicefield(field)
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.ChoiceField):
|
||||
return self.map_choicefield(field)
|
||||
|
||||
# ListField.
|
||||
if isinstance(field, serializers.ListField):
|
||||
mapping = {
|
||||
'type': 'array',
|
||||
'items': {},
|
||||
}
|
||||
if not isinstance(field.child, _UnvalidatedField):
|
||||
mapping['items'] = self.map_field(field.child)
|
||||
return mapping
|
||||
|
||||
# DateField and DateTimeField type is string
|
||||
if isinstance(field, serializers.DateField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'date',
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.DateTimeField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'date-time',
|
||||
}
|
||||
|
||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
|
||||
if isinstance(field, serializers.EmailField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'email'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.URLField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'uri'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.UUIDField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'uuid'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.IPAddressField):
|
||||
content = {
|
||||
'type': 'string',
|
||||
}
|
||||
if field.protocol != 'both':
|
||||
content['format'] = field.protocol
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.DecimalField):
|
||||
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
|
||||
content = {
|
||||
'type': 'string',
|
||||
'format': 'decimal',
|
||||
}
|
||||
else:
|
||||
content = {
|
||||
'type': 'number'
|
||||
}
|
||||
|
||||
if field.decimal_places:
|
||||
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
|
||||
if field.max_whole_digits:
|
||||
content['maximum'] = int(field.max_whole_digits * '9') + 1
|
||||
content['minimum'] = -content['maximum']
|
||||
self._map_min_max(field, content)
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.FloatField):
|
||||
content = {
|
||||
'type': 'number',
|
||||
}
|
||||
self._map_min_max(field, content)
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.IntegerField):
|
||||
content = {
|
||||
'type': 'integer'
|
||||
}
|
||||
self._map_min_max(field, content)
|
||||
# 2147483647 is max for int32_size, so we use int64 for format
|
||||
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
|
||||
content['format'] = 'int64'
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.FileField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'binary'
|
||||
}
|
||||
|
||||
# Simplest cases, default to 'string' type:
|
||||
FIELD_CLASS_SCHEMA_TYPE = {
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.JSONField: 'object',
|
||||
serializers.DictField: 'object',
|
||||
serializers.HStoreField: 'object',
|
||||
}
|
||||
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
|
||||
|
||||
def _map_min_max(self, field, content):
|
||||
if field.max_value:
|
||||
content['maximum'] = field.max_value
|
||||
if field.min_value:
|
||||
content['minimum'] = field.min_value
|
||||
|
||||
def map_serializer(self, serializer):
|
||||
# Assuming we have a valid serializer instance.
|
||||
required = []
|
||||
properties = {}
|
||||
|
||||
for field in serializer.fields.values():
|
||||
if isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
if field.required and not serializer.partial:
|
||||
required.append(self.get_field_name(field))
|
||||
|
||||
schema = self.map_field(field)
|
||||
if field.read_only:
|
||||
schema['readOnly'] = True
|
||||
if field.write_only:
|
||||
schema['writeOnly'] = True
|
||||
if field.allow_null:
|
||||
schema['nullable'] = True
|
||||
if field.default is not None and field.default != empty and not callable(field.default):
|
||||
schema['default'] = field.default
|
||||
if field.help_text:
|
||||
schema['description'] = str(field.help_text)
|
||||
self.map_field_validators(field, schema)
|
||||
|
||||
properties[self.get_field_name(field)] = schema
|
||||
|
||||
result = {
|
||||
'type': 'object',
|
||||
'properties': properties
|
||||
}
|
||||
if required:
|
||||
result['required'] = required
|
||||
|
||||
return result
|
||||
|
||||
def map_field_validators(self, field, schema):
|
||||
"""
|
||||
map field validators
|
||||
"""
|
||||
for v in field.validators:
|
||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||
if isinstance(v, EmailValidator):
|
||||
schema['format'] = 'email'
|
||||
if isinstance(v, URLValidator):
|
||||
schema['format'] = 'uri'
|
||||
if isinstance(v, RegexValidator):
|
||||
# In Python, the token \Z does what \z does in other engines.
|
||||
# https://stackoverflow.com/questions/53283160
|
||||
schema['pattern'] = v.regex.pattern.replace('\\Z', '\\z')
|
||||
elif isinstance(v, MaxLengthValidator):
|
||||
attr_name = 'maxLength'
|
||||
if isinstance(field, serializers.ListField):
|
||||
attr_name = 'maxItems'
|
||||
schema[attr_name] = v.limit_value
|
||||
elif isinstance(v, MinLengthValidator):
|
||||
attr_name = 'minLength'
|
||||
if isinstance(field, serializers.ListField):
|
||||
attr_name = 'minItems'
|
||||
schema[attr_name] = v.limit_value
|
||||
elif isinstance(v, MaxValueValidator):
|
||||
schema['maximum'] = v.limit_value
|
||||
elif isinstance(v, MinValueValidator):
|
||||
schema['minimum'] = v.limit_value
|
||||
elif isinstance(v, DecimalValidator) and \
|
||||
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
|
||||
if v.decimal_places:
|
||||
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
|
||||
if v.max_digits:
|
||||
digits = v.max_digits
|
||||
if v.decimal_places is not None and v.decimal_places > 0:
|
||||
digits -= v.decimal_places
|
||||
schema['maximum'] = int(digits * '9') + 1
|
||||
schema['minimum'] = -schema['maximum']
|
||||
|
||||
def get_field_name(self, field):
|
||||
"""
|
||||
Override this method if you want to change schema field name.
|
||||
For example, convert snake_case field name to camelCase.
|
||||
"""
|
||||
return field.field_name
|
||||
|
||||
def get_paginator(self):
|
||||
pagination_class = getattr(self.view, 'pagination_class', None)
|
||||
if pagination_class:
|
||||
return pagination_class()
|
||||
return None
|
||||
|
||||
def map_parsers(self, path, method):
|
||||
return list(map(attrgetter('media_type'), self.view.parser_classes))
|
||||
|
||||
def map_renderers(self, path, method):
|
||||
media_types = []
|
||||
for renderer in self.view.renderer_classes:
|
||||
# BrowsableAPIRenderer not relevant to OpenAPI spec
|
||||
if issubclass(renderer, renderers.BrowsableAPIRenderer):
|
||||
continue
|
||||
media_types.append(renderer.media_type)
|
||||
return media_types
|
||||
|
||||
def get_serializer(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return None
|
||||
|
||||
try:
|
||||
return view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
return None
|
||||
|
||||
def get_request_serializer(self, path, method):
|
||||
"""
|
||||
Override this method if your view uses a different serializer for
|
||||
handling request body.
|
||||
"""
|
||||
return self.get_serializer(path, method)
|
||||
|
||||
def get_response_serializer(self, path, method):
|
||||
"""
|
||||
Override this method if your view uses a different serializer for
|
||||
populating response data.
|
||||
"""
|
||||
return self.get_serializer(path, method)
|
||||
|
||||
def get_reference(self, serializer):
|
||||
return {'$ref': f'#/components/schemas/{self.get_component_name(serializer)}'}
|
||||
|
||||
def get_request_body(self, path, method):
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return {}
|
||||
|
||||
self.request_media_types = self.map_parsers(path, method)
|
||||
|
||||
serializer = self.get_request_serializer(path, method)
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
item_schema = {}
|
||||
else:
|
||||
item_schema = self.get_reference(serializer)
|
||||
|
||||
return {
|
||||
'content': {
|
||||
ct: {'schema': item_schema}
|
||||
for ct in self.request_media_types
|
||||
}
|
||||
}
|
||||
|
||||
def get_responses(self, path, method):
|
||||
if method == 'DELETE':
|
||||
return {
|
||||
'204': {
|
||||
'description': ''
|
||||
}
|
||||
}
|
||||
|
||||
self.response_media_types = self.map_renderers(path, method)
|
||||
|
||||
serializer = self.get_response_serializer(path, method)
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
item_schema = {}
|
||||
else:
|
||||
item_schema = self.get_reference(serializer)
|
||||
|
||||
if is_list_view(path, method, self.view):
|
||||
response_schema = {
|
||||
'type': 'array',
|
||||
'items': item_schema,
|
||||
}
|
||||
paginator = self.get_paginator()
|
||||
if paginator:
|
||||
response_schema = paginator.get_paginated_response_schema(response_schema)
|
||||
else:
|
||||
response_schema = item_schema
|
||||
status_code = '201' if method == 'POST' else '200'
|
||||
return {
|
||||
status_code: {
|
||||
'content': {
|
||||
ct: {'schema': response_schema}
|
||||
for ct in self.response_media_types
|
||||
},
|
||||
# description is a mandatory property,
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
|
||||
# TODO: put something meaningful into it
|
||||
'description': ""
|
||||
}
|
||||
}
|
||||
|
||||
def get_tags(self, path, method):
|
||||
# If user have specified tags, use them.
|
||||
if self._tags:
|
||||
return self._tags
|
||||
|
||||
# First element of a specific path could be valid tag. This is a fallback solution.
|
||||
# PUT, PATCH, GET(Retrieve), DELETE: /user_profile/{id}/ tags = [user-profile]
|
||||
# POST, GET(List): /user_profile/ tags = [user-profile]
|
||||
if path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return [path.split('/')[0].replace('_', '-')]
|
|
@ -1,41 +0,0 @@
|
|||
"""
|
||||
utils.py # Shared helper functions
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework.mixins import RetrieveModelMixin
|
||||
|
||||
|
||||
def is_list_view(path, method, view):
|
||||
"""
|
||||
Return True if the given path/method appears to represent a list view.
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have an explicitly defined action, which we can inspect.
|
||||
return view.action == 'list'
|
||||
|
||||
if method.lower() != 'get':
|
||||
return False
|
||||
if isinstance(view, RetrieveModelMixin):
|
||||
return False
|
||||
path_components = path.strip('/').split('/')
|
||||
if path_components and '{' in path_components[-1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
|
@ -1,48 +0,0 @@
|
|||
"""
|
||||
views.py # Houses `SchemaView`, `APIView` subclass.
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
from rest_framework import exceptions, renderers
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas import coreapi
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class SchemaView(APIView):
|
||||
_ignore_model_permissions = True
|
||||
schema = None # exclude from schema
|
||||
renderer_classes = None
|
||||
schema_generator = None
|
||||
public = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.renderer_classes is None:
|
||||
if coreapi.is_enabled():
|
||||
self.renderer_classes = [
|
||||
renderers.CoreAPIOpenAPIRenderer,
|
||||
renderers.CoreJSONRenderer
|
||||
]
|
||||
else:
|
||||
self.renderer_classes = [
|
||||
renderers.OpenAPIRenderer,
|
||||
renderers.JSONOpenAPIRenderer,
|
||||
]
|
||||
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
|
||||
self.renderer_classes += [renderers.BrowsableAPIRenderer]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = self.schema_generator.get_schema(request, self.public)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
def handle_exception(self, exc):
|
||||
# Schema renderers do not render exceptions, so re-perform content
|
||||
# negotiation with default renderers.
|
||||
self.renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
|
||||
neg = self.perform_content_negotiation(self.request, force=True)
|
||||
self.request.accepted_renderer, self.request.accepted_media_type = neg
|
||||
return super().handle_exception(exc)
|
File diff suppressed because it is too large
Load Diff
|
@ -1,20 +0,0 @@
|
|||
import pytest
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from rest_framework import renderers
|
||||
from rest_framework.schemas import coreapi, get_schema_view, openapi
|
||||
|
||||
|
||||
class GetSchemaViewTests(TestCase):
|
||||
"""For the get_schema_view() helper."""
|
||||
def test_openapi(self):
|
||||
schema_view = get_schema_view(title="With OpenAPI")
|
||||
assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator)
|
||||
assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes
|
||||
|
||||
@pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed')
|
||||
def test_coreapi(self):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
|
||||
schema_view = get_schema_view(title="With CoreAPI")
|
||||
assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator)
|
||||
assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes
|
|
@ -1,154 +0,0 @@
|
|||
import io
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
from django.urls import path
|
||||
|
||||
from rest_framework.compat import coreapi, uritemplate, yaml
|
||||
from rest_framework.management.commands import generateschema
|
||||
from rest_framework.utils import formatting, json
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class FooView(APIView):
|
||||
def get(self, request):
|
||||
pass
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path('', FooView.as_view())
|
||||
]
|
||||
|
||||
|
||||
class CustomSchemaGenerator:
|
||||
SCHEMA = {"key": "value"}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get_schema(self, **kwargs):
|
||||
return self.SCHEMA
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
@pytest.mark.skipif(not uritemplate, reason='uritemplate is not installed')
|
||||
class GenerateSchemaTests(TestCase):
|
||||
"""Tests for management command generateschema."""
|
||||
|
||||
def setUp(self):
|
||||
self.out = io.StringIO()
|
||||
|
||||
def test_command_detects_schema_generation_mode(self):
|
||||
"""Switching between CoreAPI & OpenAPI"""
|
||||
command = generateschema.Command()
|
||||
assert command.get_mode() == generateschema.OPENAPI_MODE
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
|
||||
assert command.get_mode() == generateschema.COREAPI_MODE
|
||||
|
||||
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
|
||||
def test_renders_default_schema_with_custom_title_url_and_description(self):
|
||||
call_command('generateschema',
|
||||
'--title=ExampleAPI',
|
||||
'--url=http://api.example.com',
|
||||
'--description=Example description',
|
||||
stdout=self.out)
|
||||
# Check valid YAML was output.
|
||||
schema = yaml.safe_load(self.out.getvalue())
|
||||
assert schema['openapi'] == '3.0.2'
|
||||
|
||||
def test_renders_openapi_json_schema(self):
|
||||
call_command('generateschema',
|
||||
'--format=openapi-json',
|
||||
stdout=self.out)
|
||||
# Check valid JSON was output.
|
||||
out_json = json.loads(self.out.getvalue())
|
||||
assert out_json['openapi'] == '3.0.2'
|
||||
|
||||
def test_accepts_custom_schema_generator(self):
|
||||
call_command('generateschema',
|
||||
f'--generator_class={__name__}.{CustomSchemaGenerator.__name__}',
|
||||
stdout=self.out)
|
||||
out_json = yaml.safe_load(self.out.getvalue())
|
||||
assert out_json == CustomSchemaGenerator.SCHEMA
|
||||
|
||||
def test_writes_schema_to_file_on_parameter(self):
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
call_command('generateschema', f'--file={path}', stdout=self.out)
|
||||
# nothing on stdout
|
||||
assert not self.out.getvalue()
|
||||
|
||||
call_command('generateschema', stdout=self.out)
|
||||
expected_out = self.out.getvalue()
|
||||
# file output identical to stdout output
|
||||
with os.fdopen(fd) as fh:
|
||||
assert expected_out and fh.read() == expected_out
|
||||
finally:
|
||||
os.remove(path)
|
||||
|
||||
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
|
||||
@pytest.mark.skipif(coreapi is None, reason='coreapi is required.')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self):
|
||||
expected_out = """info:
|
||||
description: Example description
|
||||
title: ExampleAPI
|
||||
version: ''
|
||||
openapi: 3.0.0
|
||||
paths:
|
||||
/:
|
||||
get:
|
||||
operationId: list
|
||||
servers:
|
||||
- url: http://api.example.com/
|
||||
"""
|
||||
call_command('generateschema',
|
||||
'--title=ExampleAPI',
|
||||
'--url=http://api.example.com',
|
||||
'--description=Example description',
|
||||
stdout=self.out)
|
||||
|
||||
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
|
||||
|
||||
@pytest.mark.skipif(coreapi is None, reason='coreapi is required.')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_coreapi_renders_openapi_json_schema(self):
|
||||
expected_out = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"version": "",
|
||||
"title": "",
|
||||
"description": ""
|
||||
},
|
||||
"servers": [
|
||||
{
|
||||
"url": ""
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"/": {
|
||||
"get": {
|
||||
"operationId": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
call_command('generateschema',
|
||||
'--format=openapi-json',
|
||||
stdout=self.out)
|
||||
out_json = json.loads(self.out.getvalue())
|
||||
|
||||
self.assertDictEqual(out_json, expected_out)
|
||||
|
||||
@pytest.mark.skipif(coreapi is None, reason='coreapi is required.')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_renders_corejson_schema(self):
|
||||
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
|
||||
call_command('generateschema',
|
||||
'--format=corejson',
|
||||
stdout=self.out)
|
||||
self.assertIn(expected_out, self.out.getvalue())
|
File diff suppressed because it is too large
Load Diff
|
@ -1,250 +0,0 @@
|
|||
import uuid
|
||||
from datetime import timedelta
|
||||
|
||||
from django.core.validators import (
|
||||
DecimalValidator, MaxLengthValidator, MaxValueValidator,
|
||||
MinLengthValidator, MinValueValidator, RegexValidator
|
||||
)
|
||||
from django.db import models
|
||||
|
||||
from rest_framework import generics, permissions, serializers
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas.openapi import AutoSchema
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import GenericViewSet, ViewSet
|
||||
|
||||
|
||||
class ExampleListView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DocStringExampleListView(APIView):
|
||||
"""
|
||||
get: A description of my GET operation.
|
||||
post: A description of my POST operation.
|
||||
"""
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DocStringExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
"""
|
||||
A description of my GET operation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# Generics.
|
||||
class ExampleSerializer(serializers.Serializer):
|
||||
date = serializers.DateField()
|
||||
datetime = serializers.DateTimeField()
|
||||
duration = serializers.DurationField(default=timedelta())
|
||||
hstore = serializers.HStoreField()
|
||||
uuid_field = serializers.UUIDField(default=uuid.uuid4)
|
||||
|
||||
|
||||
class ExampleGenericAPIView(generics.GenericAPIView):
|
||||
serializer_class = ExampleSerializer
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class ExampleGenericViewSet(GenericViewSet):
|
||||
serializer_class = ExampleSerializer
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
@action(detail=False)
|
||||
def new(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@action(detail=False)
|
||||
def old(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# Validators and/or equivalent Field attributes.
|
||||
class ExampleValidatedSerializer(serializers.Serializer):
|
||||
integer = serializers.IntegerField(
|
||||
validators=(
|
||||
MaxValueValidator(limit_value=99),
|
||||
MinValueValidator(limit_value=-11),
|
||||
)
|
||||
)
|
||||
string = serializers.CharField(
|
||||
validators=(
|
||||
MaxLengthValidator(limit_value=10),
|
||||
MinLengthValidator(limit_value=2),
|
||||
)
|
||||
)
|
||||
regex = serializers.CharField(
|
||||
validators=(
|
||||
RegexValidator(regex=r'[ABC]12{3}'),
|
||||
),
|
||||
help_text='must have an A, B, or C followed by 1222'
|
||||
)
|
||||
lst = serializers.ListField(
|
||||
validators=(
|
||||
MaxLengthValidator(limit_value=10),
|
||||
MinLengthValidator(limit_value=2),
|
||||
)
|
||||
)
|
||||
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False)
|
||||
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False,
|
||||
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
|
||||
decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True)
|
||||
decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True,
|
||||
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
|
||||
decimal5 = serializers.DecimalField(max_digits=6, decimal_places=2)
|
||||
email = serializers.EmailField(default='foo@bar.com')
|
||||
url = serializers.URLField(default='http://www.example.com', allow_null=True)
|
||||
uuid = serializers.UUIDField()
|
||||
ip4 = serializers.IPAddressField(protocol='ipv4')
|
||||
ip6 = serializers.IPAddressField(protocol='ipv6')
|
||||
ip = serializers.IPAddressField()
|
||||
duration = serializers.DurationField(
|
||||
validators=(
|
||||
MinValueValidator(timedelta(seconds=10)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ExampleValidatedAPIView(generics.GenericAPIView):
|
||||
serializer_class = ExampleValidatedSerializer
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
serializer = self.get_serializer(integer=33, string='hello', regex='foo', decimal1=3.55,
|
||||
decimal2=5.33, email='a@b.co',
|
||||
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
|
||||
ip='192.168.1.1')
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
# Serializer with model.
|
||||
class OpenAPIExample(models.Model):
|
||||
first_name = models.CharField(max_length=30)
|
||||
|
||||
|
||||
class ExampleSerializerModel(serializers.Serializer):
|
||||
date = serializers.DateField()
|
||||
datetime = serializers.DateTimeField()
|
||||
hstore = serializers.HStoreField()
|
||||
uuid_field = serializers.UUIDField(default=uuid.uuid4)
|
||||
|
||||
class Meta:
|
||||
model = OpenAPIExample
|
||||
|
||||
|
||||
class ExampleOperationIdDuplicate1(generics.GenericAPIView):
|
||||
serializer_class = ExampleSerializerModel
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleOperationIdDuplicate2(generics.GenericAPIView):
|
||||
serializer_class = ExampleSerializerModel
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleGenericAPIViewModel(generics.GenericAPIView):
|
||||
serializer_class = ExampleSerializerModel
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class ExampleAutoSchemaComponentName(generics.GenericAPIView):
|
||||
serializer_class = ExampleSerializerModel
|
||||
schema = AutoSchema(component_name="Ulysses")
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class ExampleAutoSchemaDuplicate1(generics.GenericAPIView):
|
||||
serializer_class = ExampleValidatedSerializer
|
||||
schema = AutoSchema(component_name="Duplicate")
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class ExampleAutoSchemaDuplicate2(generics.GenericAPIView):
|
||||
serializer_class = ExampleSerializerModel
|
||||
schema = AutoSchema(component_name="Duplicate")
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class ExampleViewSet(ViewSet):
|
||||
serializer_class = ExampleSerializerModel
|
||||
|
||||
def list(self, request):
|
||||
pass
|
||||
|
||||
def create(self, request):
|
||||
pass
|
||||
|
||||
def retrieve(self, request, pk=None):
|
||||
pass
|
||||
|
||||
def update(self, request, pk=None):
|
||||
pass
|
||||
|
||||
def partial_update(self, request, pk=None):
|
||||
pass
|
||||
|
||||
def destroy(self, request, pk=None):
|
||||
pass
|
|
@ -5,18 +5,16 @@ import pytest
|
|||
from django.core.cache import cache
|
||||
from django.db import models
|
||||
from django.http.request import HttpRequest
|
||||
from django.template import loader
|
||||
from django.test import TestCase, override_settings
|
||||
from django.urls import include, path, re_path
|
||||
from django.utils.safestring import SafeText
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import permissions, serializers, status
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.renderers import (
|
||||
AdminRenderer, BaseRenderer, BrowsableAPIRenderer, DocumentationRenderer,
|
||||
HTMLFormRenderer, JSONRenderer, SchemaJSRenderer, StaticHTMLRenderer
|
||||
AdminRenderer, BaseRenderer, BrowsableAPIRenderer, HTMLFormRenderer,
|
||||
JSONRenderer, StaticHTMLRenderer
|
||||
)
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
@ -871,61 +869,3 @@ class AdminRendererTests(TestCase):
|
|||
self.assertEqual(results[1]['url'], '/example')
|
||||
self.assertEqual(results[2]['url'], None)
|
||||
self.assertNotIn('url', results[3])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
class TestDocumentationRenderer(TestCase):
|
||||
|
||||
def test_document_with_link_named_data(self):
|
||||
"""
|
||||
Ref #5395: Doc's `document.data` would fail with a Link named "data".
|
||||
As per #4972, use templatetag instead.
|
||||
"""
|
||||
document = coreapi.Document(
|
||||
title='Data Endpoint API',
|
||||
url='https://api.example.org/',
|
||||
content={
|
||||
'data': coreapi.Link(
|
||||
url='/data/',
|
||||
action='get',
|
||||
fields=[],
|
||||
description='Return data.'
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
factory = APIRequestFactory()
|
||||
request = factory.get('/')
|
||||
|
||||
renderer = DocumentationRenderer()
|
||||
|
||||
html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
|
||||
assert '<h1>Data Endpoint API</h1>' in html
|
||||
|
||||
def test_shell_code_example_rendering(self):
|
||||
template = loader.get_template('rest_framework/docs/langs/shell.html')
|
||||
context = {
|
||||
'document': coreapi.Document(url='https://api.example.org/'),
|
||||
'link_key': 'testcases > list',
|
||||
'link': coreapi.Link(url='/data/', action='get', fields=[]),
|
||||
}
|
||||
html = template.render(context)
|
||||
assert 'testcases<span class="w"> </span>list' in html
|
||||
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
class TestSchemaJSRenderer(TestCase):
|
||||
|
||||
def test_schemajs_output(self):
|
||||
"""
|
||||
Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates,
|
||||
and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix.
|
||||
"""
|
||||
factory = APIRequestFactory()
|
||||
request = factory.get('/')
|
||||
|
||||
renderer = SchemaJSRenderer()
|
||||
|
||||
output = renderer.render('data', renderer_context={"request": request})
|
||||
assert "'ImRhdGEi'" in output
|
||||
assert "'b'ImRhdGEi''" not in output
|
||||
|
|
Loading…
Reference in New Issue
Block a user