mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-24 00:04:16 +03:00
Added OpenAPI Schema Generation. (#6532)
Co-authored-by: Lucidiot <lucidiot@protonmail.com> Co-authored-by: dongfangtianyu <dongfangtianyu@qq.com>
This commit is contained in:
parent
a91e6a0e69
commit
37f210a455
|
@ -37,6 +37,9 @@ class BaseFilterBackend:
|
|||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
||||
class SearchFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the search.
|
||||
|
@ -156,6 +159,19 @@ class SearchFilter(BaseFilterBackend):
|
|||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.search_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.search_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class OrderingFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the ordering.
|
||||
|
@ -287,6 +303,19 @@ class OrderingFilter(BaseFilterBackend):
|
|||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.ordering_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.ordering_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class DjangoObjectPermissionsFilter(BaseFilterBackend):
|
||||
"""
|
||||
|
|
|
@ -1,41 +1,56 @@
|
|||
from django.core.management.base import BaseCommand
|
||||
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.renderers import (
|
||||
CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer
|
||||
)
|
||||
from rest_framework.schemas.generators import SchemaGenerator
|
||||
from rest_framework import renderers
|
||||
from rest_framework.schemas import coreapi
|
||||
from rest_framework.schemas.openapi import SchemaGenerator
|
||||
|
||||
OPENAPI_MODE = 'openapi'
|
||||
COREAPI_MODE = 'coreapi'
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Generates configured API schema for project."
|
||||
|
||||
def get_mode(self):
|
||||
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--title', dest="title", default=None, type=str)
|
||||
parser.add_argument('--title', dest="title", default='', type=str)
|
||||
parser.add_argument('--url', dest="url", default=None, type=str)
|
||||
parser.add_argument('--description', dest="description", default=None, type=str)
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
|
||||
else:
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
assert coreapi is not None, 'coreapi must be installed.'
|
||||
|
||||
generator = SchemaGenerator(
|
||||
generator_class = self.get_generator_class()
|
||||
generator = generator_class(
|
||||
url=options['url'],
|
||||
title=options['title'],
|
||||
description=options['description']
|
||||
)
|
||||
|
||||
schema = generator.get_schema(request=None, public=True)
|
||||
|
||||
renderer = self.get_renderer(options['format'])
|
||||
output = renderer.render(schema, renderer_context={})
|
||||
self.stdout.write(output.decode())
|
||||
|
||||
def get_renderer(self, format):
|
||||
renderer_cls = {
|
||||
'corejson': CoreJSONRenderer,
|
||||
'openapi': OpenAPIRenderer,
|
||||
'openapi-json': JSONOpenAPIRenderer,
|
||||
}[format]
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
renderer_cls = {
|
||||
'corejson': renderers.CoreJSONRenderer,
|
||||
'openapi': renderers.CoreAPIOpenAPIRenderer,
|
||||
'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer,
|
||||
}[format]
|
||||
return renderer_cls()
|
||||
|
||||
renderer_cls = {
|
||||
'openapi': renderers.OpenAPIRenderer,
|
||||
'openapi-json': renderers.JSONOpenAPIRenderer,
|
||||
}[format]
|
||||
return renderer_cls()
|
||||
|
||||
def get_generator_class(self):
|
||||
if self.get_mode() == COREAPI_MODE:
|
||||
return coreapi.SchemaGenerator
|
||||
return SchemaGenerator
|
||||
|
|
|
@ -148,6 +148,9 @@ class BasePagination:
|
|||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
||||
class PageNumberPagination(BasePagination):
|
||||
"""
|
||||
|
@ -301,6 +304,32 @@ class PageNumberPagination(BasePagination):
|
|||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.page_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.page_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
parameters.append(
|
||||
{
|
||||
'name': self.page_size_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.page_size_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
)
|
||||
return parameters
|
||||
|
||||
|
||||
class LimitOffsetPagination(BasePagination):
|
||||
"""
|
||||
|
@ -430,6 +459,15 @@ class LimitOffsetPagination(BasePagination):
|
|||
context = self.get_html_context()
|
||||
return template.render(context)
|
||||
|
||||
def get_count(self, queryset):
|
||||
"""
|
||||
Determine an object count, supporting either querysets or regular lists.
|
||||
"""
|
||||
try:
|
||||
return queryset.count()
|
||||
except (AttributeError, TypeError):
|
||||
return len(queryset)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
|
@ -454,14 +492,28 @@ class LimitOffsetPagination(BasePagination):
|
|||
)
|
||||
]
|
||||
|
||||
def get_count(self, queryset):
|
||||
"""
|
||||
Determine an object count, supporting either querysets or regular lists.
|
||||
"""
|
||||
try:
|
||||
return queryset.count()
|
||||
except (AttributeError, TypeError):
|
||||
return len(queryset)
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.limit_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.limit_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
{
|
||||
'name': self.offset_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.offset_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
]
|
||||
return parameters
|
||||
|
||||
|
||||
class CursorPagination(BasePagination):
|
||||
|
@ -816,3 +868,29 @@ class CursorPagination(BasePagination):
|
|||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.cursor_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.cursor_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
}
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
parameters.append(
|
||||
{
|
||||
'name': self.page_size_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.page_size_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
}
|
||||
)
|
||||
return parameters
|
||||
|
|
|
@ -1013,28 +1013,49 @@ class _BaseOpenAPIRenderer:
|
|||
}
|
||||
|
||||
|
||||
class OpenAPIRenderer(_BaseOpenAPIRenderer):
|
||||
class CoreAPIOpenAPIRenderer(_BaseOpenAPIRenderer):
|
||||
media_type = 'application/vnd.oai.openapi'
|
||||
charset = None
|
||||
format = 'openapi'
|
||||
|
||||
def __init__(self):
|
||||
assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.'
|
||||
assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.'
|
||||
assert coreapi, 'Using CoreAPIOpenAPIRenderer, but `coreapi` is not installed.'
|
||||
assert yaml, 'Using CoreAPIOpenAPIRenderer, but `pyyaml` is not installed.'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
structure = self.get_structure(data)
|
||||
return yaml.dump(structure, default_flow_style=False).encode()
|
||||
|
||||
|
||||
class JSONOpenAPIRenderer(_BaseOpenAPIRenderer):
|
||||
class CoreAPIJSONOpenAPIRenderer(_BaseOpenAPIRenderer):
|
||||
media_type = 'application/vnd.oai.openapi+json'
|
||||
charset = None
|
||||
format = 'openapi-json'
|
||||
|
||||
def __init__(self):
|
||||
assert coreapi, 'Using JSONOpenAPIRenderer, but `coreapi` is not installed.'
|
||||
assert coreapi, 'Using CoreAPIJSONOpenAPIRenderer, but `coreapi` is not installed.'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
structure = self.get_structure(data)
|
||||
return json.dumps(structure, indent=4).encode()
|
||||
return json.dumps(structure, indent=4).encode('utf-8')
|
||||
|
||||
|
||||
class OpenAPIRenderer(BaseRenderer):
|
||||
media_type = 'application/vnd.oai.openapi'
|
||||
charset = None
|
||||
format = 'openapi'
|
||||
|
||||
def __init__(self):
|
||||
assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
return yaml.dump(data, default_flow_style=False).encode('utf-8')
|
||||
|
||||
|
||||
class JSONOpenAPIRenderer(BaseRenderer):
|
||||
media_type = 'application/vnd.oai.openapi+json'
|
||||
charset = None
|
||||
format = 'openapi-json'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
return json.dumps(data, indent=2).encode('utf-8')
|
||||
|
|
|
@ -22,24 +22,32 @@ Other access should target the submodules directly
|
|||
"""
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from .generators import SchemaGenerator
|
||||
from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa
|
||||
from . import coreapi, openapi
|
||||
from .inspectors import DefaultSchema # noqa
|
||||
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
|
||||
|
||||
|
||||
def get_schema_view(
|
||||
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
|
||||
public=False, patterns=None, generator_class=SchemaGenerator,
|
||||
public=False, patterns=None, generator_class=None,
|
||||
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
|
||||
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
# Avoid import cycle on APIView
|
||||
from .views import SchemaView
|
||||
if generator_class is None:
|
||||
if coreapi.is_enabled():
|
||||
generator_class = coreapi.SchemaGenerator
|
||||
else:
|
||||
generator_class = openapi.SchemaGenerator
|
||||
|
||||
generator = generator_class(
|
||||
title=title, url=url, description=description,
|
||||
urlconf=urlconf, patterns=patterns,
|
||||
)
|
||||
|
||||
# Avoid import cycle on APIView
|
||||
from .views import SchemaView
|
||||
return SchemaView.as_view(
|
||||
renderer_classes=renderer_classes,
|
||||
schema_generator=generator,
|
||||
|
|
616
rest_framework/schemas/coreapi.py
Normal file
616
rest_framework/schemas/coreapi.py
Normal file
|
@ -0,0 +1,616 @@
|
|||
import re
|
||||
import warnings
|
||||
from collections import Counter, OrderedDict
|
||||
from urllib import parse
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import coreapi, coreschema, uritemplate
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
# Used in _get_description_section()
|
||||
# TODO: ???: move up to base.
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
# Generator #
|
||||
# TODO: Pull some of this into base.
|
||||
|
||||
|
||||
def is_custom_action(action):
|
||||
return action not in {
|
||||
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
|
||||
}
|
||||
|
||||
|
||||
def distribute_links(obj):
|
||||
for key, value in obj.items():
|
||||
distribute_links(value)
|
||||
|
||||
for preferred_key, link in obj.links:
|
||||
key = obj.get_available_key(preferred_key)
|
||||
obj[key] = link
|
||||
|
||||
|
||||
INSERT_INTO_COLLISION_FMT = """
|
||||
Schema Naming Collision.
|
||||
|
||||
coreapi.Link for URL path {value_url} cannot be inserted into schema.
|
||||
Position conflicts with coreapi.Link for URL path {target_url}.
|
||||
|
||||
Attempted to insert link with keys: {keys}.
|
||||
|
||||
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
|
||||
to customise schema structure.
|
||||
"""
|
||||
|
||||
|
||||
class LinkNode(OrderedDict):
|
||||
def __init__(self):
|
||||
self.links = []
|
||||
self.methods_counter = Counter()
|
||||
super(LinkNode, self).__init__()
|
||||
|
||||
def get_available_key(self, preferred_key):
|
||||
if preferred_key not in self:
|
||||
return preferred_key
|
||||
|
||||
while True:
|
||||
current_val = self.methods_counter[preferred_key]
|
||||
self.methods_counter[preferred_key] += 1
|
||||
|
||||
key = '{}_{}'.format(preferred_key, current_val)
|
||||
if key not in self:
|
||||
return key
|
||||
|
||||
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = LinkNode()
|
||||
target = target[key]
|
||||
|
||||
try:
|
||||
target.links.append((keys[-1], value))
|
||||
except TypeError:
|
||||
msg = INSERT_INTO_COLLISION_FMT.format(
|
||||
value_url=value.url,
|
||||
target_url=target.url,
|
||||
keys=keys
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
"""
|
||||
Original CoreAPI version.
|
||||
"""
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
links = LinkNode()
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
|
||||
return links
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
# Method for generating the link layout....
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "star") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
if is_list_view(subpath, method, view):
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.default_mapping[method.lower()]
|
||||
|
||||
named_path_components = [
|
||||
component for component
|
||||
in subpath.strip('/').split('/')
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if is_custom_action(action):
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
if len(view.action_map) > 1:
|
||||
action = self.default_mapping[method.lower()]
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
return named_path_components + [action]
|
||||
else:
|
||||
return named_path_components[:-1] + [action]
|
||||
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
||||
# View Inspectors #
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_text(field.label) if field.label else ''
|
||||
description = force_text(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.DictField):
|
||||
return coreschema.Object(
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties=OrderedDict([
|
||||
(key, field_to_schema(value))
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
]),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
related_field_schema = field_to_schema(field.child_relation)
|
||||
|
||||
return coreschema.Array(
|
||||
items=related_field_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
schema_cls = coreschema.String
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
return schema_cls(title=title, description=description)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices)),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
elif isinstance(field, serializers.DateField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date'
|
||||
)
|
||||
elif isinstance(field, serializers.DateTimeField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date-time'
|
||||
)
|
||||
elif isinstance(field, serializers.JSONField):
|
||||
return coreschema.Object(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
|
||||
return coreschema.String(title=title, description=description)
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
Default inspector for APIView
|
||||
|
||||
Responsible for per-view introspection and schema generation.
|
||||
"""
|
||||
def __init__(self, manual_fields=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `manual_fields`: list of `coreapi.Field` instances that
|
||||
will be added to auto-generated fields, overwriting on `Field.name`
|
||||
"""
|
||||
super(AutoSchema, self).__init__()
|
||||
if manual_fields is None:
|
||||
manual_fields = []
|
||||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
fields += self.get_filter_fields(path, method)
|
||||
|
||||
manual_fields = self.get_manual_fields(path, method)
|
||||
fields = self.update_fields(fields, manual_fields)
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method)
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_description(self, path, method):
|
||||
"""
|
||||
Determine a link description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
|
||||
else:
|
||||
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
|
||||
|
||||
def _get_description_section(self, view, header, description):
|
||||
lines = [line for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if header_regex.match(line):
|
||||
current_section, seperator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
|
||||
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in coerce_method_names:
|
||||
if coerce_method_names[header] in sections:
|
||||
return sections[coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
def get_path_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
view = self.view
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_text(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def _allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
|
||||
Override to adjust behaviour for your view.
|
||||
|
||||
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
|
||||
to allow changes based on user experience.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def get_filter_fields(self, path, method):
|
||||
if not self._allows_filters(path, method):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(self.view)
|
||||
return fields
|
||||
|
||||
def get_manual_fields(self, path, method):
|
||||
return self._manual_fields
|
||||
|
||||
@staticmethod
|
||||
def update_fields(fields, update_with):
|
||||
"""
|
||||
Update list of coreapi.Field instances, overwriting on `Field.name`.
|
||||
|
||||
Utility function to handle replacing coreapi.Field fields
|
||||
from a list by name. Used to handle `manual_fields`.
|
||||
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances to update
|
||||
* `update_with: list of `coreapi.Field` instances to add or replace.
|
||||
"""
|
||||
if not update_with:
|
||||
return fields
|
||||
|
||||
by_name = OrderedDict((f.name, f) for f in fields)
|
||||
for f in update_with:
|
||||
by_name[f.name] = f
|
||||
fields = list(by_name.values())
|
||||
return fields
|
||||
|
||||
def get_encoding(self, path, method):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = {
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
}
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ManualSchema(ViewInspector):
|
||||
"""
|
||||
Allows providing a list of coreapi.Fields,
|
||||
plus an optional description.
|
||||
"""
|
||||
def __init__(self, fields, description='', encoding=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances.
|
||||
* `description`: String description for view. Optional.
|
||||
"""
|
||||
super(ManualSchema, self).__init__()
|
||||
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
|
||||
self._fields = fields
|
||||
self._description = description
|
||||
self._encoding = encoding
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=self._encoding,
|
||||
fields=self._fields,
|
||||
description=self._description
|
||||
)
|
||||
|
||||
|
||||
def is_enabled():
|
||||
"""Is CoreAPI Mode enabled?"""
|
||||
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)
|
|
@ -4,7 +4,6 @@ generators.py # Top-down schema generation
|
|||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
import re
|
||||
from collections import Counter, OrderedDict
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
|
@ -13,15 +12,11 @@ from django.core.exceptions import PermissionDenied
|
|||
from django.http import Http404
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.compat import (
|
||||
URLPattern, URLResolver, coreapi, coreschema, get_original_route
|
||||
)
|
||||
from rest_framework.compat import URLPattern, URLResolver, get_original_route
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.model_meta import _get_pk
|
||||
|
||||
from .utils import is_list_view
|
||||
|
||||
|
||||
def common_path(paths):
|
||||
split_paths = [path.strip('/').split('/') for path in paths]
|
||||
|
@ -50,78 +45,6 @@ def is_api_view(callback):
|
|||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
INSERT_INTO_COLLISION_FMT = """
|
||||
Schema Naming Collision.
|
||||
|
||||
coreapi.Link for URL path {value_url} cannot be inserted into schema.
|
||||
Position conflicts with coreapi.Link for URL path {target_url}.
|
||||
|
||||
Attempted to insert link with keys: {keys}.
|
||||
|
||||
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
|
||||
to customise schema structure.
|
||||
"""
|
||||
|
||||
|
||||
class LinkNode(OrderedDict):
|
||||
def __init__(self):
|
||||
self.links = []
|
||||
self.methods_counter = Counter()
|
||||
super().__init__()
|
||||
|
||||
def get_available_key(self, preferred_key):
|
||||
if preferred_key not in self:
|
||||
return preferred_key
|
||||
|
||||
while True:
|
||||
current_val = self.methods_counter[preferred_key]
|
||||
self.methods_counter[preferred_key] += 1
|
||||
|
||||
key = '{}_{}'.format(preferred_key, current_val)
|
||||
if key not in self:
|
||||
return key
|
||||
|
||||
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = LinkNode()
|
||||
target = target[key]
|
||||
|
||||
try:
|
||||
target.links.append((keys[-1], value))
|
||||
except TypeError:
|
||||
msg = INSERT_INTO_COLLISION_FMT.format(
|
||||
value_url=value.url,
|
||||
target_url=target.url,
|
||||
keys=keys
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def distribute_links(obj):
|
||||
for key, value in obj.items():
|
||||
distribute_links(value)
|
||||
|
||||
for preferred_key, link in obj.links:
|
||||
key = obj.get_available_key(preferred_key)
|
||||
obj[key] = link
|
||||
|
||||
|
||||
def is_custom_action(action):
|
||||
return action not in {
|
||||
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
|
||||
}
|
||||
|
||||
|
||||
def endpoint_ordering(endpoint):
|
||||
path, method, callback = endpoint
|
||||
method_priority = {
|
||||
|
@ -190,6 +113,10 @@ class EndpointEnumerator:
|
|||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
# ???: Would it be feasible to adjust this such that we generate the
|
||||
# path, plus the kwargs, plus the type from the convertor, such that we
|
||||
# could feed that straight into the parameter schema object?
|
||||
|
||||
path = simplify_regex(path_regex)
|
||||
|
||||
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
|
||||
|
@ -228,35 +155,18 @@ class EndpointEnumerator:
|
|||
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
|
||||
|
||||
|
||||
class SchemaGenerator:
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
class BaseSchemaGenerator(object):
|
||||
endpoint_inspector_cls = EndpointEnumerator
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
|
@ -266,36 +176,15 @@ class SchemaGenerator:
|
|||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
def _initialise_endpoints(self):
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
def get_links(self, request=None):
|
||||
def _get_paths_and_endpoints(self, request):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
Generate (path, method, view) given (path, method, callback) for paths.
|
||||
"""
|
||||
links = LinkNode()
|
||||
|
||||
# Generate (path, method, view) given (path, method, callback).
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
|
@ -304,22 +193,48 @@ class SchemaGenerator:
|
|||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
return paths, view_endpoints
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
return links
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
# Methods used when we generate a view instance from the raw callback...
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
|
@ -352,29 +267,6 @@ class SchemaGenerator:
|
|||
prefixes.append('/' + prefix + '/')
|
||||
return common_path(prefixes)
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
|
@ -387,64 +279,3 @@ class SchemaGenerator:
|
|||
except (exceptions.APIException, Http404, PermissionDenied):
|
||||
return False
|
||||
return True
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
# Method for generating the link layout....
|
||||
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "star") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
if is_list_view(subpath, method, view):
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.default_mapping[method.lower()]
|
||||
|
||||
named_path_components = [
|
||||
component for component
|
||||
in subpath.strip('/').split('/')
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if is_custom_action(action):
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
if len(view.action_map) > 1:
|
||||
action = self.default_mapping[method.lower()]
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
return named_path_components + [action]
|
||||
else:
|
||||
return named_path_components[:-1] + [action]
|
||||
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
|
|
@ -3,125 +3,9 @@ inspectors.py # Per-endpoint view introspection
|
|||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from urllib import parse
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import coreapi, coreschema, uritemplate
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
|
||||
from .utils import is_list_view
|
||||
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_text(field.label) if field.label else ''
|
||||
description = force_text(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.DictField):
|
||||
return coreschema.Object(
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties=OrderedDict([
|
||||
(key, field_to_schema(value))
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
]),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
related_field_schema = field_to_schema(field.child_relation)
|
||||
|
||||
return coreschema.Array(
|
||||
items=related_field_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
schema_cls = coreschema.String
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
return schema_cls(title=title, description=description)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices)),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
elif isinstance(field, serializers.DateField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date'
|
||||
)
|
||||
elif isinstance(field, serializers.DateTimeField):
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='date-time'
|
||||
)
|
||||
elif isinstance(field, serializers.JSONField):
|
||||
return coreschema.Object(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
|
||||
return coreschema.String(title=title, description=description)
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
||||
|
||||
|
||||
class ViewInspector:
|
||||
|
@ -178,320 +62,6 @@ class ViewInspector:
|
|||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
raise NotImplementedError(".get_link() must be overridden.")
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
Default inspector for APIView
|
||||
|
||||
Responsible for per-view introspection and schema generation.
|
||||
"""
|
||||
def __init__(self, manual_fields=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `manual_fields`: list of `coreapi.Field` instances that
|
||||
will be added to auto-generated fields, overwriting on `Field.name`
|
||||
"""
|
||||
super().__init__()
|
||||
if manual_fields is None:
|
||||
manual_fields = []
|
||||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
fields += self.get_filter_fields(path, method)
|
||||
|
||||
manual_fields = self.get_manual_fields(path, method)
|
||||
fields = self.update_fields(fields, manual_fields)
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method)
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_description(self, path, method):
|
||||
"""
|
||||
Determine a link description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
|
||||
else:
|
||||
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
|
||||
|
||||
def _get_description_section(self, view, header, description):
|
||||
lines = [line for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if header_regex.match(line):
|
||||
current_section, seperator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
|
||||
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in coerce_method_names:
|
||||
if coerce_method_names[header] in sections:
|
||||
return sections[coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
def get_path_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
view = self.view
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_text(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def _allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
|
||||
Override to adjust behaviour for your view.
|
||||
|
||||
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
|
||||
to allow changes based on user experience.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def get_filter_fields(self, path, method):
|
||||
if not self._allows_filters(path, method):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(self.view)
|
||||
return fields
|
||||
|
||||
def get_manual_fields(self, path, method):
|
||||
return self._manual_fields
|
||||
|
||||
@staticmethod
|
||||
def update_fields(fields, update_with):
|
||||
"""
|
||||
Update list of coreapi.Field instances, overwriting on `Field.name`.
|
||||
|
||||
Utility function to handle replacing coreapi.Field fields
|
||||
from a list by name. Used to handle `manual_fields`.
|
||||
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances to update
|
||||
* `update_with: list of `coreapi.Field` instances to add or replace.
|
||||
"""
|
||||
if not update_with:
|
||||
return fields
|
||||
|
||||
by_name = OrderedDict((f.name, f) for f in fields)
|
||||
for f in update_with:
|
||||
by_name[f.name] = f
|
||||
return list(by_name.values())
|
||||
|
||||
def get_encoding(self, path, method):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = {
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
}
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ManualSchema(ViewInspector):
|
||||
"""
|
||||
Allows providing a list of coreapi.Fields,
|
||||
plus an optional description.
|
||||
"""
|
||||
def __init__(self, fields, description='', encoding=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances.
|
||||
* `description`: String description for view. Optional.
|
||||
"""
|
||||
super().__init__()
|
||||
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
|
||||
self._fields = fields
|
||||
self._description = description
|
||||
self._encoding = encoding
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=parse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=self._encoding,
|
||||
fields=self._fields,
|
||||
description=self._description
|
||||
)
|
||||
|
||||
|
||||
class DefaultSchema(ViewInspector):
|
||||
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
|
||||
|
|
377
rest_framework/schemas/openapi.py
Normal file
377
rest_framework/schemas/openapi.py
Normal file
|
@ -0,0 +1,377 @@
|
|||
import warnings
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_text
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import uritemplate
|
||||
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
# Generator
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
|
||||
def get_info(self):
|
||||
info = {
|
||||
'title': self.title,
|
||||
'version': 'TODO',
|
||||
}
|
||||
|
||||
if self.description is not None:
|
||||
info['description'] = self.description
|
||||
|
||||
return info
|
||||
|
||||
def get_paths(self, request=None):
|
||||
result = {}
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
operation = view.schema.get_operation(path, method)
|
||||
subpath = '/' + path[len(prefix):]
|
||||
result.setdefault(subpath, {})
|
||||
result[subpath][method.lower()] = operation
|
||||
|
||||
return result
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a OpenAPI schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
paths = self.get_paths(None if public else request)
|
||||
if not paths:
|
||||
return None
|
||||
|
||||
schema = {
|
||||
'openapi': '3.0.2',
|
||||
'info': self.get_info(),
|
||||
'paths': paths,
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
# View Inspectors
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
|
||||
content_types = ['application/json']
|
||||
method_mapping = {
|
||||
'get': 'Retrieve',
|
||||
'post': 'Create',
|
||||
'put': 'Update',
|
||||
'patch': 'PartialUpdate',
|
||||
'delete': 'Destroy',
|
||||
}
|
||||
|
||||
def get_operation(self, path, method):
|
||||
operation = {}
|
||||
|
||||
operation['operationId'] = self._get_operation_id(path, method)
|
||||
|
||||
parameters = []
|
||||
parameters += self._get_path_parameters(path, method)
|
||||
parameters += self._get_pagination_parameters(path, method)
|
||||
parameters += self._get_filter_parameters(path, method)
|
||||
operation['parameters'] = parameters
|
||||
|
||||
request_body = self._get_request_body(path, method)
|
||||
if request_body:
|
||||
operation['requestBody'] = request_body
|
||||
operation['responses'] = self._get_responses(path, method)
|
||||
|
||||
return operation
|
||||
|
||||
def _get_operation_id(self, path, method):
|
||||
"""
|
||||
Compute an operation ID from the model, serializer or view name.
|
||||
"""
|
||||
method_name = getattr(self.view, 'action', method.lower())
|
||||
if is_list_view(path, method, self.view):
|
||||
action = 'List'
|
||||
elif method_name not in self.method_mapping:
|
||||
action = method_name
|
||||
else:
|
||||
action = self.method_mapping[method.lower()]
|
||||
|
||||
# Try to deduce the ID from the view's model
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
if model is not None:
|
||||
name = model.__name__
|
||||
|
||||
# Try with the serializer class name
|
||||
elif hasattr(self.view, 'get_serializer_class'):
|
||||
name = self.view.get_serializer_class().__name__
|
||||
if name.endswith('Serializer'):
|
||||
name = name[:-10]
|
||||
|
||||
# Fallback to the view name
|
||||
else:
|
||||
name = self.view.__class__.__name__
|
||||
if name.endswith('APIView'):
|
||||
name = name[:-7]
|
||||
elif name.endswith('View'):
|
||||
name = name[:-4]
|
||||
if name.endswith(action): # ListView, UpdateAPIView, ThingDelete ...
|
||||
name = name[:-len(action)]
|
||||
|
||||
if action == 'List' and not name.endswith('s'): # ListThings instead of ListThing
|
||||
name += 's'
|
||||
|
||||
return action + name
|
||||
|
||||
def _get_path_parameters(self, path, method):
|
||||
"""
|
||||
Return a list of parameters from templated path variables.
|
||||
"""
|
||||
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
|
||||
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
parameters = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
description = ''
|
||||
if model is not None: # TODO: test this.
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
parameter = {
|
||||
"name": variable,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"description": description,
|
||||
'schema': {
|
||||
'type': 'string', # TODO: integer, pattern, ...
|
||||
},
|
||||
}
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def _get_filter_parameters(self, path, method):
|
||||
if not self._allows_filters(path, method):
|
||||
return []
|
||||
parameters = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
parameters += filter_backend().get_schema_operation_parameters(self.view)
|
||||
return parameters
|
||||
|
||||
def _allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def _get_pagination_parameters(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_operation_parameters(view)
|
||||
|
||||
def _map_field(self, field):
|
||||
|
||||
# Nested Serializers, `many` or not.
|
||||
if isinstance(field, serializers.ListSerializer):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self._map_serializer(field.child)
|
||||
}
|
||||
if isinstance(field, serializers.Serializer):
|
||||
data = self._map_serializer(field)
|
||||
data['type'] = 'object'
|
||||
return data
|
||||
|
||||
# Related fields.
|
||||
if isinstance(field, serializers.ManyRelatedField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self._map_field(field.child_relation)
|
||||
}
|
||||
if isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
return {'type': 'integer'}
|
||||
|
||||
# ChoiceFields (single and multiple).
|
||||
# Q:
|
||||
# - Is 'type' required?
|
||||
# - can we determine the TYPE of a choicefield?
|
||||
if isinstance(field, serializers.MultipleChoiceField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'enum': list(field.choices)
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.ChoiceField):
|
||||
return {
|
||||
'enum': list(field.choices),
|
||||
}
|
||||
|
||||
# ListField.
|
||||
if isinstance(field, serializers.ListField):
|
||||
return {
|
||||
'type': 'array',
|
||||
}
|
||||
|
||||
# DateField and DateTimeField type is string
|
||||
if isinstance(field, serializers.DateField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'date',
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.DateTimeField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'date-time',
|
||||
}
|
||||
|
||||
# Simplest cases, default to 'string' type:
|
||||
FIELD_CLASS_SCHEMA_TYPE = {
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.DecimalField: 'number',
|
||||
serializers.FloatField: 'number',
|
||||
serializers.IntegerField: 'integer',
|
||||
|
||||
serializers.JSONField: 'object',
|
||||
serializers.DictField: 'object',
|
||||
}
|
||||
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
|
||||
|
||||
def _map_serializer(self, serializer):
|
||||
# Assuming we have a valid serializer instance.
|
||||
# TODO:
|
||||
# - field is Nested or List serializer.
|
||||
# - Handle read_only/write_only for request/response differences.
|
||||
# - could do this with readOnly/writeOnly and then filter dict.
|
||||
required = []
|
||||
properties = {}
|
||||
|
||||
for field in serializer.fields.values():
|
||||
if isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
if field.required:
|
||||
required.append(field.field_name)
|
||||
|
||||
schema = self._map_field(field)
|
||||
if field.read_only:
|
||||
schema['readOnly'] = True
|
||||
if field.write_only:
|
||||
schema['writeOnly'] = True
|
||||
if field.allow_null:
|
||||
schema['nullable'] = True
|
||||
|
||||
properties[field.field_name] = schema
|
||||
return {
|
||||
'required': required,
|
||||
'properties': properties,
|
||||
}
|
||||
|
||||
def _get_request_body(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return {}
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return {}
|
||||
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return {}
|
||||
|
||||
content = self._map_serializer(serializer)
|
||||
# No required fields for PATCH
|
||||
if method == 'PATCH':
|
||||
del content['required']
|
||||
# No read_only fields for request.
|
||||
for name, schema in content['properties'].copy().items():
|
||||
if 'readOnly' in schema:
|
||||
del content['properties'][name]
|
||||
|
||||
return {
|
||||
'content': {
|
||||
ct: {'schema': content}
|
||||
for ct in self.content_types
|
||||
}
|
||||
}
|
||||
|
||||
def _get_responses(self, path, method):
|
||||
# TODO: Handle multiple codes.
|
||||
content = {}
|
||||
view = self.view
|
||||
if hasattr(view, 'get_serializer'):
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if isinstance(serializer, serializers.Serializer):
|
||||
content = self._map_serializer(serializer)
|
||||
# No write_only fields for response.
|
||||
for name, schema in content['properties'].copy().items():
|
||||
if 'writeOnly' in schema:
|
||||
del content['properties'][name]
|
||||
content['required'] = [f for f in content['required'] if f != name]
|
||||
|
||||
return {
|
||||
'200': {
|
||||
'content': {
|
||||
ct: {'schema': content}
|
||||
for ct in self.content_types
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,6 +3,9 @@ utils.py # Shared helper functions
|
|||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
from django.db import models
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.mixins import RetrieveModelMixin
|
||||
|
||||
|
||||
|
@ -22,3 +25,17 @@ def is_list_view(path, method, view):
|
|||
if path_components and '{' in path_components[-1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@ See schemas.__init__.py for package overview.
|
|||
"""
|
||||
from rest_framework import exceptions, renderers
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas import coreapi
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
@ -19,10 +20,16 @@ class SchemaView(APIView):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.renderer_classes is None:
|
||||
self.renderer_classes = [
|
||||
renderers.OpenAPIRenderer,
|
||||
renderers.CoreJSONRenderer
|
||||
]
|
||||
if coreapi.is_enabled():
|
||||
self.renderer_classes = [
|
||||
renderers.CoreAPIOpenAPIRenderer,
|
||||
renderers.CoreJSONRenderer
|
||||
]
|
||||
else:
|
||||
self.renderer_classes = [
|
||||
renderers.OpenAPIRenderer,
|
||||
renderers.JSONOpenAPIRenderer,
|
||||
]
|
||||
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
|
||||
self.renderer_classes += [renderers.BrowsableAPIRenderer]
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ DEFAULTS = {
|
|||
'DEFAULT_FILTER_BACKENDS': (),
|
||||
|
||||
# Schema
|
||||
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
|
||||
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
|
||||
|
||||
# Throttling
|
||||
'DEFAULT_THROTTLE_RATES': {
|
||||
|
|
0
tests/schemas/__init__.py
Normal file
0
tests/schemas/__init__.py
Normal file
|
@ -16,15 +16,16 @@ from rest_framework.routers import DefaultRouter, SimpleRouter
|
|||
from rest_framework.schemas import (
|
||||
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
|
||||
)
|
||||
from rest_framework.schemas.coreapi import field_to_schema
|
||||
from rest_framework.schemas.generators import EndpointEnumerator
|
||||
from rest_framework.schemas.inspectors import field_to_schema
|
||||
from rest_framework.schemas.utils import is_list_view
|
||||
from rest_framework.test import APIClient, APIRequestFactory
|
||||
from rest_framework.utils import formatting
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
|
||||
from .models import BasicModel, ForeignKeySource, ManyToManySource
|
||||
from . import views
|
||||
from ..models import BasicModel, ForeignKeySource, ManyToManySource
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
@ -133,11 +134,12 @@ class ExampleViewSet(ModelViewSet):
|
|||
pass
|
||||
|
||||
|
||||
if coreapi:
|
||||
schema_view = get_schema_view(title='Example API')
|
||||
else:
|
||||
def schema_view(request):
|
||||
pass
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
|
||||
if coreapi:
|
||||
schema_view = get_schema_view(title='Example API')
|
||||
else:
|
||||
def schema_view(request):
|
||||
pass
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register('example', ExampleViewSet, basename='example')
|
||||
|
@ -148,7 +150,7 @@ urlpatterns = [
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_schemas')
|
||||
@override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestRouterGeneratedSchema(TestCase):
|
||||
def test_anonymous_request(self):
|
||||
client = APIClient()
|
||||
|
@ -400,12 +402,13 @@ class ExampleDetailView(APIView):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGenerator(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url(r'^example/?$', ExampleListView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
url(r'^example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -453,12 +456,13 @@ class TestSchemaGenerator(TestCase):
|
|||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@unittest.skipUnless(path, 'needs Django 2')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorDjango2(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
path('example/', ExampleListView.as_view()),
|
||||
path('example/<int:pk>/', ExampleDetailView.as_view()),
|
||||
path('example/<int:pk>/sub/', ExampleDetailView.as_view()),
|
||||
path('example/', views.ExampleListView.as_view()),
|
||||
path('example/<int:pk>/', views.ExampleDetailView.as_view()),
|
||||
path('example/<int:pk>/sub/', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -505,12 +509,13 @@ class TestSchemaGeneratorDjango2(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorNotAtRoot(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url(r'^api/v1/example/?$', ExampleListView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -558,6 +563,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
|
||||
def setUp(self):
|
||||
router = DefaultRouter()
|
||||
|
@ -622,13 +628,14 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
|
||||
def setUp(self):
|
||||
router = DefaultRouter()
|
||||
router.register('example1', Http404ExampleViewSet, basename='example1')
|
||||
router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
|
||||
self.patterns = [
|
||||
url('^example/?$', ExampleListView.as_view()),
|
||||
url('^example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^', include(router.urls))
|
||||
]
|
||||
|
||||
|
@ -668,6 +675,7 @@ class ForeignKeySourceView(generics.CreateAPIView):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithForeignKey(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
|
@ -713,6 +721,7 @@ class ManyToManySourceView(generics.CreateAPIView):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithManyToMany(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
|
@ -747,6 +756,7 @@ class TestSchemaGeneratorWithManyToMany(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class Test4605Regression(TestCase):
|
||||
def test_4605_regression(self):
|
||||
generator = SchemaGenerator()
|
||||
|
@ -762,6 +772,7 @@ class CustomViewInspector(AutoSchema):
|
|||
pass
|
||||
|
||||
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestAutoSchema(TestCase):
|
||||
|
||||
def test_apiview_schema_descriptor(self):
|
||||
|
@ -777,7 +788,7 @@ class TestAutoSchema(TestCase):
|
|||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
def test_set_custom_inspector_class_via_settings(self):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}):
|
||||
view = APIView()
|
||||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
|
@ -971,6 +982,7 @@ class TestAutoSchema(TestCase):
|
|||
self.assertEqual(field_to_schema(case[0]), case[1])
|
||||
|
||||
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_docstring_is_not_stripped_by_get_description():
|
||||
class ExampleDocstringAPIView(APIView):
|
||||
"""
|
||||
|
@ -1007,25 +1019,25 @@ def test_docstring_is_not_stripped_by_get_description():
|
|||
|
||||
|
||||
# Views for SchemaGenerationExclusionTests
|
||||
class ExcludedAPIView(APIView):
|
||||
schema = None
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
|
||||
class ExcludedAPIView(APIView):
|
||||
schema = None
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
def get(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@api_view(['GET'])
|
||||
@schema(None)
|
||||
def excluded_fbv(request):
|
||||
pass
|
||||
|
||||
@api_view(['GET'])
|
||||
def included_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
@schema(None)
|
||||
def excluded_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
def included_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class SchemaGenerationExclusionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
|
@ -1078,11 +1090,6 @@ class SchemaGenerationExclusionTests(TestCase):
|
|||
assert should_include == expected
|
||||
|
||||
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
class BasicModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
|
@ -1118,11 +1125,16 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename
|
|||
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestURLNamingCollisions(TestCase):
|
||||
"""
|
||||
Ref: https://github.com/encode/django-rest-framework/issues/4704
|
||||
"""
|
||||
def test_manually_routing_nested_routes(self):
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
patterns = [
|
||||
url(r'^test', simple_fbv),
|
||||
url(r'^test/list/', simple_fbv),
|
||||
|
@ -1228,6 +1240,10 @@ class TestURLNamingCollisions(TestCase):
|
|||
|
||||
def test_url_under_same_key_not_replaced_another(self):
|
||||
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
patterns = [
|
||||
url(r'^test/list/', simple_fbv),
|
||||
url(r'^test/(?P<pk>\d+)/list/', simple_fbv),
|
||||
|
@ -1302,10 +1318,8 @@ def test_head_and_options_methods_are_excluded():
|
|||
assert inspector.get_allowed_methods(callback) == ["GET"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
class TestAutoSchemaAllowsFilters:
|
||||
class MockAPIView(APIView):
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
class MockAPIView(APIView):
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
|
||||
def _test(self, method):
|
||||
view = self.MockAPIView()
|
20
tests/schemas/test_get_schema_view.py
Normal file
20
tests/schemas/test_get_schema_view.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from rest_framework import renderers
|
||||
from rest_framework.schemas import coreapi, get_schema_view, openapi
|
||||
|
||||
|
||||
class GetSchemaViewTests(TestCase):
|
||||
"""For the get_schema_view() helper."""
|
||||
def test_openapi(self):
|
||||
schema_view = get_schema_view(title="With OpenAPI")
|
||||
assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator)
|
||||
assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes
|
||||
|
||||
@pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed')
|
||||
def test_coreapi(self):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
|
||||
schema_view = get_schema_view(title="With CoreAPI")
|
||||
assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator)
|
||||
assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes
|
|
@ -6,7 +6,8 @@ from django.core.management import call_command
|
|||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.compat import uritemplate, yaml
|
||||
from rest_framework.management.commands import generateschema
|
||||
from rest_framework.utils import formatting, json
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
@ -21,15 +22,43 @@ urlpatterns = [
|
|||
]
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF='tests.test_generateschema')
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
@pytest.mark.skipif(not uritemplate, reason='uritemplate is not installed')
|
||||
class GenerateSchemaTests(TestCase):
|
||||
"""Tests for management command generateschema."""
|
||||
|
||||
def setUp(self):
|
||||
self.out = io.StringIO()
|
||||
|
||||
def test_command_detects_schema_generation_mode(self):
|
||||
"""Switching between CoreAPI & OpenAPI"""
|
||||
command = generateschema.Command()
|
||||
assert command.get_mode() == generateschema.OPENAPI_MODE
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
|
||||
assert command.get_mode() == generateschema.COREAPI_MODE
|
||||
|
||||
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
|
||||
def test_renders_default_schema_with_custom_title_url_and_description(self):
|
||||
call_command('generateschema',
|
||||
'--title=SampleAPI',
|
||||
'--url=http://api.sample.com',
|
||||
'--description=Sample description',
|
||||
stdout=self.out)
|
||||
# Check valid YAML was output.
|
||||
schema = yaml.load(self.out.getvalue())
|
||||
assert schema['openapi'] == '3.0.2'
|
||||
|
||||
def test_renders_openapi_json_schema(self):
|
||||
call_command('generateschema',
|
||||
'--format=openapi-json',
|
||||
stdout=self.out)
|
||||
# Check valid JSON was output.
|
||||
out_json = json.loads(self.out.getvalue())
|
||||
assert out_json['openapi'] == '3.0.2'
|
||||
|
||||
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self):
|
||||
expected_out = """info:
|
||||
description: Sample description
|
||||
title: SampleAPI
|
||||
|
@ -50,7 +79,8 @@ class GenerateSchemaTests(TestCase):
|
|||
|
||||
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
|
||||
|
||||
def test_renders_openapi_json_schema(self):
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_coreapi_renders_openapi_json_schema(self):
|
||||
expected_out = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
|
@ -78,6 +108,7 @@ class GenerateSchemaTests(TestCase):
|
|||
|
||||
self.assertDictEqual(out_json, expected_out)
|
||||
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_renders_corejson_schema(self):
|
||||
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
|
||||
call_command('generateschema',
|
245
tests/schemas/test_openapi.py
Normal file
245
tests/schemas/test_openapi.py
Normal file
|
@ -0,0 +1,245 @@
|
|||
import pytest
|
||||
from django.conf.urls import url
|
||||
from django.test import RequestFactory, TestCase, override_settings
|
||||
|
||||
from rest_framework import filters, generics, pagination, routers, serializers
|
||||
from rest_framework.compat import uritemplate
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
|
||||
|
||||
from . import views
|
||||
|
||||
|
||||
def create_request(path):
|
||||
factory = RequestFactory()
|
||||
request = Request(factory.get(path))
|
||||
return request
|
||||
|
||||
|
||||
def create_view(view_cls, method, request):
|
||||
generator = SchemaGenerator()
|
||||
view = generator.create_view(view_cls.as_view(), method, request)
|
||||
return view
|
||||
|
||||
|
||||
class TestBasics(TestCase):
|
||||
def dummy_view(request):
|
||||
pass
|
||||
|
||||
def test_filters(self):
|
||||
classes = [filters.SearchFilter, filters.OrderingFilter]
|
||||
for c in classes:
|
||||
f = c()
|
||||
assert f.get_schema_operation_parameters(self.dummy_view)
|
||||
|
||||
def test_pagination(self):
|
||||
classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination]
|
||||
for c in classes:
|
||||
f = c()
|
||||
assert f.get_schema_operation_parameters(self.dummy_view)
|
||||
|
||||
|
||||
@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
|
||||
class TestOperationIntrospection(TestCase):
|
||||
|
||||
def test_path_without_parameters(self):
|
||||
path = '/example/'
|
||||
method = 'GET'
|
||||
|
||||
view = create_view(
|
||||
views.ExampleListView,
|
||||
method,
|
||||
create_request(path)
|
||||
)
|
||||
inspector = AutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
operation = inspector.get_operation(path, method)
|
||||
assert operation == {
|
||||
'operationId': 'ListExamples',
|
||||
'parameters': [],
|
||||
'responses': {'200': {'content': {'application/json': {'schema': {}}}}},
|
||||
}
|
||||
|
||||
def test_path_with_id_parameter(self):
|
||||
path = '/example/{id}/'
|
||||
method = 'GET'
|
||||
|
||||
view = create_view(
|
||||
views.ExampleDetailView,
|
||||
method,
|
||||
create_request(path)
|
||||
)
|
||||
inspector = AutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
parameters = inspector._get_path_parameters(path, method)
|
||||
assert parameters == [{
|
||||
'description': '',
|
||||
'in': 'path',
|
||||
'name': 'id',
|
||||
'required': True,
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
}]
|
||||
|
||||
def test_request_body(self):
|
||||
path = '/'
|
||||
method = 'POST'
|
||||
|
||||
class Serializer(serializers.Serializer):
|
||||
text = serializers.CharField()
|
||||
read_only = serializers.CharField(read_only=True)
|
||||
|
||||
class View(generics.GenericAPIView):
|
||||
serializer_class = Serializer
|
||||
|
||||
view = create_view(
|
||||
View,
|
||||
method,
|
||||
create_request(path)
|
||||
)
|
||||
inspector = AutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
request_body = inspector._get_request_body(path, method)
|
||||
assert request_body['content']['application/json']['schema']['required'] == ['text']
|
||||
assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text']
|
||||
|
||||
def test_response_body_generation(self):
|
||||
path = '/'
|
||||
method = 'POST'
|
||||
|
||||
class Serializer(serializers.Serializer):
|
||||
text = serializers.CharField()
|
||||
write_only = serializers.CharField(write_only=True)
|
||||
|
||||
class View(generics.GenericAPIView):
|
||||
serializer_class = Serializer
|
||||
|
||||
view = create_view(
|
||||
View,
|
||||
method,
|
||||
create_request(path)
|
||||
)
|
||||
inspector = AutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
responses = inspector._get_responses(path, method)
|
||||
assert responses['200']['content']['application/json']['schema']['required'] == ['text']
|
||||
assert list(responses['200']['content']['application/json']['schema']['properties'].keys()) == ['text']
|
||||
|
||||
def test_response_body_nested_serializer(self):
|
||||
path = '/'
|
||||
method = 'POST'
|
||||
|
||||
class NestedSerializer(serializers.Serializer):
|
||||
number = serializers.IntegerField()
|
||||
|
||||
class Serializer(serializers.Serializer):
|
||||
text = serializers.CharField()
|
||||
nested = NestedSerializer()
|
||||
|
||||
class View(generics.GenericAPIView):
|
||||
serializer_class = Serializer
|
||||
|
||||
view = create_view(
|
||||
View,
|
||||
method,
|
||||
create_request(path),
|
||||
)
|
||||
inspector = AutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
responses = inspector._get_responses(path, method)
|
||||
schema = responses['200']['content']['application/json']['schema']
|
||||
assert sorted(schema['required']) == ['nested', 'text']
|
||||
assert sorted(list(schema['properties'].keys())) == ['nested', 'text']
|
||||
assert schema['properties']['nested']['type'] == 'object'
|
||||
assert list(schema['properties']['nested']['properties'].keys()) == ['number']
|
||||
assert schema['properties']['nested']['required'] == ['number']
|
||||
|
||||
def test_operation_id_generation(self):
|
||||
path = '/'
|
||||
method = 'GET'
|
||||
|
||||
view = create_view(
|
||||
views.ExampleGenericAPIView,
|
||||
method,
|
||||
create_request(path),
|
||||
)
|
||||
inspector = AutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
operationId = inspector._get_operation_id(path, method)
|
||||
assert operationId == 'ListExamples'
|
||||
|
||||
def test_repeat_operation_ids(self):
|
||||
router = routers.SimpleRouter()
|
||||
router.register('account', views.ExampleGenericViewSet, basename="account")
|
||||
urlpatterns = router.urls
|
||||
|
||||
generator = SchemaGenerator(patterns=urlpatterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
schema_str = str(schema)
|
||||
print(schema_str)
|
||||
assert schema_str.count("operationId") == 2
|
||||
assert schema_str.count("newExample") == 1
|
||||
assert schema_str.count("oldExample") == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'})
|
||||
class TestGenerator(TestCase):
|
||||
|
||||
def test_override_settings(self):
|
||||
assert isinstance(views.ExampleListView.schema, AutoSchema)
|
||||
|
||||
def test_paths_construction(self):
|
||||
"""Construction of the `paths` key."""
|
||||
patterns = [
|
||||
url(r'^example/?$', views.ExampleListView.as_view()),
|
||||
]
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
generator._initialise_endpoints()
|
||||
|
||||
paths = generator.get_paths()
|
||||
|
||||
assert '/example/' in paths
|
||||
example_operations = paths['/example/']
|
||||
assert len(example_operations) == 2
|
||||
assert 'get' in example_operations
|
||||
assert 'post' in example_operations
|
||||
|
||||
def test_schema_construction(self):
|
||||
"""Construction of the top level dictionary."""
|
||||
patterns = [
|
||||
url(r'^example/?$', views.ExampleListView.as_view()),
|
||||
]
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
|
||||
assert 'openapi' in schema
|
||||
assert 'paths' in schema
|
||||
|
||||
def test_serializer_datefield(self):
|
||||
patterns = [
|
||||
url(r'^example/?$', views.ExampleGenericViewSet.as_view({"get": "get"})),
|
||||
]
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
|
||||
response = schema['paths']['/example/']['get']['responses']
|
||||
response_schema = response['200']['content']['application/json']['schema']['properties']
|
||||
|
||||
assert response_schema['date']['type'] == response_schema['datetime']['type'] == 'string'
|
||||
|
||||
assert response_schema['date']['format'] == 'date'
|
||||
assert response_schema['datetime']['format'] == 'date-time'
|
58
tests/schemas/views.py
Normal file
58
tests/schemas/views.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
from rest_framework import generics, permissions, serializers
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
|
||||
class ExampleListView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# Generics.
|
||||
class ExampleSerializer(serializers.Serializer):
|
||||
date = serializers.DateField()
|
||||
datetime = serializers.DateTimeField()
|
||||
|
||||
|
||||
class ExampleGenericAPIView(generics.GenericAPIView):
|
||||
serializer_class = ExampleSerializer
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class ExampleGenericViewSet(GenericViewSet):
|
||||
serializer_class = ExampleSerializer
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
|
||||
serializer = self.get_serializer(data=now.date(), datetime=now)
|
||||
return Response(serializer.data)
|
||||
|
||||
@action(detail=False)
|
||||
def new(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@action(detail=False)
|
||||
def old(self, *args, **kwargs):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user