Added OpenAPI Schema Generation. (#6532)

Co-authored-by: Lucidiot <lucidiot@protonmail.com>
Co-authored-by: dongfangtianyu <dongfangtianyu@qq.com>
This commit is contained in:
Carlton Gibson 2019-05-13 16:07:03 +02:00 committed by GitHub
parent a91e6a0e69
commit 37f210a455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1671 additions and 734 deletions

View File

@ -37,6 +37,9 @@ class BaseFilterBackend:
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return []
def get_schema_operation_parameters(self, view):
return []
class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search.
@ -156,6 +159,19 @@ class SearchFilter(BaseFilterBackend):
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.search_param,
'required': False,
'in': 'query',
'description': force_text(self.search_description),
'schema': {
'type': 'string',
},
},
]
class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering.
@ -287,6 +303,19 @@ class OrderingFilter(BaseFilterBackend):
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.ordering_param,
'required': False,
'in': 'query',
'description': force_text(self.ordering_description),
'schema': {
'type': 'string',
},
},
]
class DjangoObjectPermissionsFilter(BaseFilterBackend):
"""

View File

@ -1,41 +1,56 @@
from django.core.management.base import BaseCommand
from rest_framework.compat import coreapi
from rest_framework.renderers import (
CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer
)
from rest_framework.schemas.generators import SchemaGenerator
from rest_framework import renderers
from rest_framework.schemas import coreapi
from rest_framework.schemas.openapi import SchemaGenerator
OPENAPI_MODE = 'openapi'
COREAPI_MODE = 'coreapi'
class Command(BaseCommand):
help = "Generates configured API schema for project."
def get_mode(self):
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
def add_arguments(self, parser):
parser.add_argument('--title', dest="title", default=None, type=str)
parser.add_argument('--title', dest="title", default='', type=str)
parser.add_argument('--url', dest="url", default=None, type=str)
parser.add_argument('--description', dest="description", default=None, type=str)
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
if self.get_mode() == COREAPI_MODE:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
else:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
def handle(self, *args, **options):
assert coreapi is not None, 'coreapi must be installed.'
generator = SchemaGenerator(
generator_class = self.get_generator_class()
generator = generator_class(
url=options['url'],
title=options['title'],
description=options['description']
)
schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer(options['format'])
output = renderer.render(schema, renderer_context={})
self.stdout.write(output.decode())
def get_renderer(self, format):
renderer_cls = {
'corejson': CoreJSONRenderer,
'openapi': OpenAPIRenderer,
'openapi-json': JSONOpenAPIRenderer,
}[format]
if self.get_mode() == COREAPI_MODE:
renderer_cls = {
'corejson': renderers.CoreJSONRenderer,
'openapi': renderers.CoreAPIOpenAPIRenderer,
'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer,
}[format]
return renderer_cls()
renderer_cls = {
'openapi': renderers.OpenAPIRenderer,
'openapi-json': renderers.JSONOpenAPIRenderer,
}[format]
return renderer_cls()
def get_generator_class(self):
if self.get_mode() == COREAPI_MODE:
return coreapi.SchemaGenerator
return SchemaGenerator

View File

@ -148,6 +148,9 @@ class BasePagination:
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
return []
def get_schema_operation_parameters(self, view):
return []
class PageNumberPagination(BasePagination):
"""
@ -301,6 +304,32 @@ class PageNumberPagination(BasePagination):
)
return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.page_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_query_description),
'schema': {
'type': 'integer',
},
},
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_size_query_description),
'schema': {
'type': 'integer',
},
},
)
return parameters
class LimitOffsetPagination(BasePagination):
"""
@ -430,6 +459,15 @@ class LimitOffsetPagination(BasePagination):
context = self.get_html_context()
return template.render(context)
def get_count(self, queryset):
"""
Determine an object count, supporting either querysets or regular lists.
"""
try:
return queryset.count()
except (AttributeError, TypeError):
return len(queryset)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
@ -454,14 +492,28 @@ class LimitOffsetPagination(BasePagination):
)
]
def get_count(self, queryset):
"""
Determine an object count, supporting either querysets or regular lists.
"""
try:
return queryset.count()
except (AttributeError, TypeError):
return len(queryset)
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.limit_query_param,
'required': False,
'in': 'query',
'description': force_text(self.limit_query_description),
'schema': {
'type': 'integer',
},
},
{
'name': self.offset_query_param,
'required': False,
'in': 'query',
'description': force_text(self.offset_query_description),
'schema': {
'type': 'integer',
},
},
]
return parameters
class CursorPagination(BasePagination):
@ -816,3 +868,29 @@ class CursorPagination(BasePagination):
)
)
return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.cursor_query_param,
'required': False,
'in': 'query',
'description': force_text(self.cursor_query_description),
'schema': {
'type': 'integer',
},
}
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_size_query_description),
'schema': {
'type': 'integer',
},
}
)
return parameters

View File

@ -1013,28 +1013,49 @@ class _BaseOpenAPIRenderer:
}
class OpenAPIRenderer(_BaseOpenAPIRenderer):
class CoreAPIOpenAPIRenderer(_BaseOpenAPIRenderer):
media_type = 'application/vnd.oai.openapi'
charset = None
format = 'openapi'
def __init__(self):
assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.'
assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.'
assert coreapi, 'Using CoreAPIOpenAPIRenderer, but `coreapi` is not installed.'
assert yaml, 'Using CoreAPIOpenAPIRenderer, but `pyyaml` is not installed.'
def render(self, data, media_type=None, renderer_context=None):
structure = self.get_structure(data)
return yaml.dump(structure, default_flow_style=False).encode()
class JSONOpenAPIRenderer(_BaseOpenAPIRenderer):
class CoreAPIJSONOpenAPIRenderer(_BaseOpenAPIRenderer):
media_type = 'application/vnd.oai.openapi+json'
charset = None
format = 'openapi-json'
def __init__(self):
assert coreapi, 'Using JSONOpenAPIRenderer, but `coreapi` is not installed.'
assert coreapi, 'Using CoreAPIJSONOpenAPIRenderer, but `coreapi` is not installed.'
def render(self, data, media_type=None, renderer_context=None):
structure = self.get_structure(data)
return json.dumps(structure, indent=4).encode()
return json.dumps(structure, indent=4).encode('utf-8')
class OpenAPIRenderer(BaseRenderer):
media_type = 'application/vnd.oai.openapi'
charset = None
format = 'openapi'
def __init__(self):
assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.'
def render(self, data, media_type=None, renderer_context=None):
return yaml.dump(data, default_flow_style=False).encode('utf-8')
class JSONOpenAPIRenderer(BaseRenderer):
media_type = 'application/vnd.oai.openapi+json'
charset = None
format = 'openapi-json'
def render(self, data, media_type=None, renderer_context=None):
return json.dumps(data, indent=2).encode('utf-8')

View File

@ -22,24 +22,32 @@ Other access should target the submodules directly
"""
from rest_framework.settings import api_settings
from .generators import SchemaGenerator
from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa
from . import coreapi, openapi
from .inspectors import DefaultSchema # noqa
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=SchemaGenerator,
public=False, patterns=None, generator_class=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
"""
Return a schema view.
"""
# Avoid import cycle on APIView
from .views import SchemaView
if generator_class is None:
if coreapi.is_enabled():
generator_class = coreapi.SchemaGenerator
else:
generator_class = openapi.SchemaGenerator
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
)
# Avoid import cycle on APIView
from .views import SchemaView
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,

View File

@ -0,0 +1,616 @@
import re
import warnings
from collections import Counter, OrderedDict
from urllib import parse
from django.db import models
from django.utils.encoding import force_text, smart_text
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
# Used in _get_description_section()
# TODO: ???: move up to base.
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
# Generator #
# TODO: Pull some of this into base.
def is_custom_action(action):
return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
}
def distribute_links(obj):
for key, value in obj.items():
distribute_links(value)
for preferred_key, link in obj.links:
key = obj.get_available_key(preferred_key)
obj[key] = link
INSERT_INTO_COLLISION_FMT = """
Schema Naming Collision.
coreapi.Link for URL path {value_url} cannot be inserted into schema.
Position conflicts with coreapi.Link for URL path {target_url}.
Attempted to insert link with keys: {keys}.
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
to customise schema structure.
"""
class LinkNode(OrderedDict):
def __init__(self):
self.links = []
self.methods_counter = Counter()
super(LinkNode, self).__init__()
def get_available_key(self, preferred_key):
if preferred_key not in self:
return preferred_key
while True:
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val)
if key not in self:
return key
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
"""
for key in keys[:-1]:
if key not in target:
target[key] = LinkNode()
target = target[key]
try:
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url,
target_url=target.url,
keys=keys
)
raise ValueError(msg)
class SchemaGenerator(BaseSchemaGenerator):
"""
Original CoreAPI version.
"""
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = LinkNode()
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
self._initialise_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]
# View Inspectors #
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.DictField):
return coreschema.Object(
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
items=related_field_schema,
title=title,
description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
return schema_cls(title=title, description=description)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view introspection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super(AutoSchema, self).__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
Override to adjust behaviour for your view.
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_filter_fields(self, path, method):
if not self._allows_filters(path, method):
return []
fields = []
for filter_backend in self.view.filter_backends:
fields += filter_backend().get_schema_fields(self.view)
return fields
def get_manual_fields(self, path, method):
return self._manual_fields
@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
fields = list(by_name.values())
return fields
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = {
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
}
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description='', encoding=None):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `description`: String description for view. Optional.
"""
super(ManualSchema, self).__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
description=self._description
)
def is_enabled():
"""Is CoreAPI Mode enabled?"""
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)

View File

@ -4,7 +4,6 @@ generators.py # Top-down schema generation
See schemas.__init__.py for package overview.
"""
import re
from collections import Counter, OrderedDict
from importlib import import_module
from django.conf import settings
@ -13,15 +12,11 @@ from django.core.exceptions import PermissionDenied
from django.http import Http404
from rest_framework import exceptions
from rest_framework.compat import (
URLPattern, URLResolver, coreapi, coreschema, get_original_route
)
from rest_framework.compat import URLPattern, URLResolver, get_original_route
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import _get_pk
from .utils import is_list_view
def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
@ -50,78 +45,6 @@ def is_api_view(callback):
return (cls is not None) and issubclass(cls, APIView)
INSERT_INTO_COLLISION_FMT = """
Schema Naming Collision.
coreapi.Link for URL path {value_url} cannot be inserted into schema.
Position conflicts with coreapi.Link for URL path {target_url}.
Attempted to insert link with keys: {keys}.
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
to customise schema structure.
"""
class LinkNode(OrderedDict):
def __init__(self):
self.links = []
self.methods_counter = Counter()
super().__init__()
def get_available_key(self, preferred_key):
if preferred_key not in self:
return preferred_key
while True:
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val)
if key not in self:
return key
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
"""
for key in keys[:-1]:
if key not in target:
target[key] = LinkNode()
target = target[key]
try:
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url,
target_url=target.url,
keys=keys
)
raise ValueError(msg)
def distribute_links(obj):
for key, value in obj.items():
distribute_links(value)
for preferred_key, link in obj.links:
key = obj.get_available_key(preferred_key)
obj[key] = link
def is_custom_action(action):
return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
}
def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
@ -190,6 +113,10 @@ class EndpointEnumerator:
"""
Given a URL conf regex, return a URI template string.
"""
# ???: Would it be feasible to adjust this such that we generate the
# path, plus the kwargs, plus the type from the convertor, such that we
# could feed that straight into the parameter schema object?
path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
@ -228,35 +155,18 @@ class EndpointEnumerator:
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class SchemaGenerator:
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
class BaseSchemaGenerator(object):
endpoint_inspector_cls = EndpointEnumerator
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
# 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
if url and not url.endswith('/'):
url += '/'
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns
@ -266,36 +176,15 @@ class SchemaGenerator:
self.url = url
self.endpoints = None
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
def _initialise_endpoints(self):
if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
def get_links(self, request=None):
def _get_paths_and_endpoints(self, request):
"""
Return a dictionary containing all the links that should be
included in the API schema.
Generate (path, method, view) given (path, method, callback) for paths.
"""
links = LinkNode()
# Generate (path, method, view) given (path, method, callback).
paths = []
view_endpoints = []
for path, method, callback in self.endpoints:
@ -304,22 +193,48 @@ class SchemaGenerator:
paths.append(path)
view_endpoints.append((path, method, view))
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
return paths, view_endpoints
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
return links
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
# Methods used when we generate a view instance from the raw callback...
if request is not None:
view.request = clone_request(request, method)
return view
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_schema(self, request=None, public=False):
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
def determine_path_prefix(self, paths):
"""
@ -352,29 +267,6 @@ class SchemaGenerator:
prefixes.append('/' + prefix + '/')
return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def has_view_permissions(self, path, method, view):
"""
Return `True` if the incoming request has the correct view permissions.
@ -387,64 +279,3 @@ class SchemaGenerator:
except (exceptions.APIException, Http404, PermissionDenied):
return False
return True
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]

View File

@ -3,125 +3,9 @@ inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview.
"""
import re
import warnings
from collections import OrderedDict
from urllib import parse
from weakref import WeakKeyDictionary
from django.db import models
from django.utils.encoding import force_text, smart_text
from django.utils.translation import gettext_lazy as _
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .utils import is_list_view
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.DictField):
return coreschema.Object(
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
items=related_field_schema,
title=title,
description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
return schema_cls(title=title, description=description)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)
class ViewInspector:
@ -178,320 +62,6 @@ class ViewInspector:
def view(self):
self._view = None
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
raise NotImplementedError(".get_link() must be overridden.")
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view introspection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super().__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
Override to adjust behaviour for your view.
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_filter_fields(self, path, method):
if not self._allows_filters(path, method):
return []
fields = []
for filter_backend in self.view.filter_backends:
fields += filter_backend().get_schema_fields(self.view)
return fields
def get_manual_fields(self, path, method):
return self._manual_fields
@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
return list(by_name.values())
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = {
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
}
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description='', encoding=None):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `description`: String description for view. Optional.
"""
super().__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
description=self._description
)
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""

View File

@ -0,0 +1,377 @@
import warnings
from django.db import models
from django.utils.encoding import force_text
from rest_framework import exceptions, serializers
from rest_framework.compat import uritemplate
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
# Generator
class SchemaGenerator(BaseSchemaGenerator):
def get_info(self):
info = {
'title': self.title,
'version': 'TODO',
}
if self.description is not None:
info['description'] = self.description
return info
def get_paths(self, request=None):
result = {}
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
operation = view.schema.get_operation(path, method)
subpath = '/' + path[len(prefix):]
result.setdefault(subpath, {})
result[subpath][method.lower()] = operation
return result
def get_schema(self, request=None, public=False):
"""
Generate a OpenAPI schema.
"""
self._initialise_endpoints()
paths = self.get_paths(None if public else request)
if not paths:
return None
schema = {
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
}
return schema
# View Inspectors
class AutoSchema(ViewInspector):
content_types = ['application/json']
method_mapping = {
'get': 'Retrieve',
'post': 'Create',
'put': 'Update',
'patch': 'PartialUpdate',
'delete': 'Destroy',
}
def get_operation(self, path, method):
operation = {}
operation['operationId'] = self._get_operation_id(path, method)
parameters = []
parameters += self._get_path_parameters(path, method)
parameters += self._get_pagination_parameters(path, method)
parameters += self._get_filter_parameters(path, method)
operation['parameters'] = parameters
request_body = self._get_request_body(path, method)
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
return operation
def _get_operation_id(self, path, method):
"""
Compute an operation ID from the model, serializer or view name.
"""
method_name = getattr(self.view, 'action', method.lower())
if is_list_view(path, method, self.view):
action = 'List'
elif method_name not in self.method_mapping:
action = method_name
else:
action = self.method_mapping[method.lower()]
# Try to deduce the ID from the view's model
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
if model is not None:
name = model.__name__
# Try with the serializer class name
elif hasattr(self.view, 'get_serializer_class'):
name = self.view.get_serializer_class().__name__
if name.endswith('Serializer'):
name = name[:-10]
# Fallback to the view name
else:
name = self.view.__class__.__name__
if name.endswith('APIView'):
name = name[:-7]
elif name.endswith('View'):
name = name[:-4]
if name.endswith(action): # ListView, UpdateAPIView, ThingDelete ...
name = name[:-len(action)]
if action == 'List' and not name.endswith('s'): # ListThings instead of ListThing
name += 's'
return action + name
def _get_path_parameters(self, path, method):
"""
Return a list of parameters from templated path variables.
"""
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
parameters = []
for variable in uritemplate.variables(path):
description = ''
if model is not None: # TODO: test this.
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
parameter = {
"name": variable,
"in": "path",
"required": True,
"description": description,
'schema': {
'type': 'string', # TODO: integer, pattern, ...
},
}
parameters.append(parameter)
return parameters
def _get_filter_parameters(self, path, method):
if not self._allows_filters(path, method):
return []
parameters = []
for filter_backend in self.view.filter_backends:
parameters += filter_backend().get_schema_operation_parameters(self.view)
return parameters
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def _get_pagination_parameters(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_operation_parameters(view)
def _map_field(self, field):
# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self._map_serializer(field.child)
}
if isinstance(field, serializers.Serializer):
data = self._map_serializer(field)
data['type'] = 'object'
return data
# Related fields.
if isinstance(field, serializers.ManyRelatedField):
return {
'type': 'array',
'items': self._map_field(field.child_relation)
}
if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
return {'type': 'integer'}
# ChoiceFields (single and multiple).
# Q:
# - Is 'type' required?
# - can we determine the TYPE of a choicefield?
if isinstance(field, serializers.MultipleChoiceField):
return {
'type': 'array',
'items': {
'enum': list(field.choices)
},
}
if isinstance(field, serializers.ChoiceField):
return {
'enum': list(field.choices),
}
# ListField.
if isinstance(field, serializers.ListField):
return {
'type': 'array',
}
# DateField and DateTimeField type is string
if isinstance(field, serializers.DateField):
return {
'type': 'string',
'format': 'date',
}
if isinstance(field, serializers.DateTimeField):
return {
'type': 'string',
'format': 'date-time',
}
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
serializers.DecimalField: 'number',
serializers.FloatField: 'number',
serializers.IntegerField: 'integer',
serializers.JSONField: 'object',
serializers.DictField: 'object',
}
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
def _map_serializer(self, serializer):
# Assuming we have a valid serializer instance.
# TODO:
# - field is Nested or List serializer.
# - Handle read_only/write_only for request/response differences.
# - could do this with readOnly/writeOnly and then filter dict.
required = []
properties = {}
for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField):
continue
if field.required:
required.append(field.field_name)
schema = self._map_field(field)
if field.read_only:
schema['readOnly'] = True
if field.write_only:
schema['writeOnly'] = True
if field.allow_null:
schema['nullable'] = True
properties[field.field_name] = schema
return {
'required': required,
'properties': properties,
}
def _get_request_body(self, path, method):
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return {}
if not hasattr(view, 'get_serializer'):
return {}
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if not isinstance(serializer, serializers.Serializer):
return {}
content = self._map_serializer(serializer)
# No required fields for PATCH
if method == 'PATCH':
del content['required']
# No read_only fields for request.
for name, schema in content['properties'].copy().items():
if 'readOnly' in schema:
del content['properties'][name]
return {
'content': {
ct: {'schema': content}
for ct in self.content_types
}
}
def _get_responses(self, path, method):
# TODO: Handle multiple codes.
content = {}
view = self.view
if hasattr(view, 'get_serializer'):
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.Serializer):
content = self._map_serializer(serializer)
# No write_only fields for response.
for name, schema in content['properties'].copy().items():
if 'writeOnly' in schema:
del content['properties'][name]
content['required'] = [f for f in content['required'] if f != name]
return {
'200': {
'content': {
ct: {'schema': content}
for ct in self.content_types
}
}
}

View File

@ -3,6 +3,9 @@ utils.py # Shared helper functions
See schemas.__init__.py for package overview.
"""
from django.db import models
from django.utils.translation import ugettext_lazy as _
from rest_framework.mixins import RetrieveModelMixin
@ -22,3 +25,17 @@ def is_list_view(path, method, view):
if path_components and '{' in path_components[-1]:
return False
return True
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)

View File

@ -5,6 +5,7 @@ See schemas.__init__.py for package overview.
"""
from rest_framework import exceptions, renderers
from rest_framework.response import Response
from rest_framework.schemas import coreapi
from rest_framework.settings import api_settings
from rest_framework.views import APIView
@ -19,10 +20,16 @@ class SchemaView(APIView):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.renderer_classes is None:
self.renderer_classes = [
renderers.OpenAPIRenderer,
renderers.CoreJSONRenderer
]
if coreapi.is_enabled():
self.renderer_classes = [
renderers.CoreAPIOpenAPIRenderer,
renderers.CoreJSONRenderer
]
else:
self.renderer_classes = [
renderers.OpenAPIRenderer,
renderers.JSONOpenAPIRenderer,
]
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes += [renderers.BrowsableAPIRenderer]

View File

@ -52,7 +52,7 @@ DEFAULTS = {
'DEFAULT_FILTER_BACKENDS': (),
# Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
# Throttling
'DEFAULT_THROTTLE_RATES': {

View File

View File

@ -16,15 +16,16 @@ from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.schemas import (
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
)
from rest_framework.schemas.coreapi import field_to_schema
from rest_framework.schemas.generators import EndpointEnumerator
from rest_framework.schemas.inspectors import field_to_schema
from rest_framework.schemas.utils import is_list_view
from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.utils import formatting
from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from .models import BasicModel, ForeignKeySource, ManyToManySource
from . import views
from ..models import BasicModel, ForeignKeySource, ManyToManySource
factory = APIRequestFactory()
@ -133,11 +134,12 @@ class ExampleViewSet(ModelViewSet):
pass
if coreapi:
schema_view = get_schema_view(title='Example API')
else:
def schema_view(request):
pass
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
if coreapi:
schema_view = get_schema_view(title='Example API')
else:
def schema_view(request):
pass
router = DefaultRouter()
router.register('example', ExampleViewSet, basename='example')
@ -148,7 +150,7 @@ urlpatterns = [
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(ROOT_URLCONF='tests.test_schemas')
@override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestRouterGeneratedSchema(TestCase):
def test_anonymous_request(self):
client = APIClient()
@ -400,12 +402,13 @@ class ExampleDetailView(APIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGenerator(TestCase):
def setUp(self):
self.patterns = [
url(r'^example/?$', ExampleListView.as_view()),
url(r'^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
url(r'^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
url(r'^example/?$', views.ExampleListView.as_view()),
url(r'^example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@ -453,12 +456,13 @@ class TestSchemaGenerator(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@unittest.skipUnless(path, 'needs Django 2')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorDjango2(TestCase):
def setUp(self):
self.patterns = [
path('example/', ExampleListView.as_view()),
path('example/<int:pk>/', ExampleDetailView.as_view()),
path('example/<int:pk>/sub/', ExampleDetailView.as_view()),
path('example/', views.ExampleListView.as_view()),
path('example/<int:pk>/', views.ExampleDetailView.as_view()),
path('example/<int:pk>/sub/', views.ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@ -505,12 +509,13 @@ class TestSchemaGeneratorDjango2(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorNotAtRoot(TestCase):
def setUp(self):
self.patterns = [
url(r'^api/v1/example/?$', ExampleListView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@ -558,6 +563,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
def setUp(self):
router = DefaultRouter()
@ -622,13 +628,14 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
def setUp(self):
router = DefaultRouter()
router.register('example1', Http404ExampleViewSet, basename='example1')
router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
self.patterns = [
url('^example/?$', ExampleListView.as_view()),
url('^example/?$', views.ExampleListView.as_view()),
url(r'^', include(router.urls))
]
@ -668,6 +675,7 @@ class ForeignKeySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithForeignKey(TestCase):
def setUp(self):
self.patterns = [
@ -713,6 +721,7 @@ class ManyToManySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithManyToMany(TestCase):
def setUp(self):
self.patterns = [
@ -747,6 +756,7 @@ class TestSchemaGeneratorWithManyToMany(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class Test4605Regression(TestCase):
def test_4605_regression(self):
generator = SchemaGenerator()
@ -762,6 +772,7 @@ class CustomViewInspector(AutoSchema):
pass
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestAutoSchema(TestCase):
def test_apiview_schema_descriptor(self):
@ -777,7 +788,7 @@ class TestAutoSchema(TestCase):
assert isinstance(view.schema, CustomViewInspector)
def test_set_custom_inspector_class_via_settings(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}):
view = APIView()
assert isinstance(view.schema, CustomViewInspector)
@ -971,6 +982,7 @@ class TestAutoSchema(TestCase):
self.assertEqual(field_to_schema(case[0]), case[1])
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_docstring_is_not_stripped_by_get_description():
class ExampleDocstringAPIView(APIView):
"""
@ -1007,25 +1019,25 @@ def test_docstring_is_not_stripped_by_get_description():
# Views for SchemaGenerationExclusionTests
class ExcludedAPIView(APIView):
schema = None
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
class ExcludedAPIView(APIView):
schema = None
def get(self, request, *args, **kwargs):
def get(self, request, *args, **kwargs):
pass
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class SchemaGenerationExclusionTests(TestCase):
def setUp(self):
self.patterns = [
@ -1078,11 +1090,6 @@ class SchemaGenerationExclusionTests(TestCase):
assert should_include == expected
@api_view(["GET"])
def simple_fbv(request):
pass
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
@ -1118,11 +1125,16 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestURLNamingCollisions(TestCase):
"""
Ref: https://github.com/encode/django-rest-framework/issues/4704
"""
def test_manually_routing_nested_routes(self):
@api_view(["GET"])
def simple_fbv(request):
pass
patterns = [
url(r'^test', simple_fbv),
url(r'^test/list/', simple_fbv),
@ -1228,6 +1240,10 @@ class TestURLNamingCollisions(TestCase):
def test_url_under_same_key_not_replaced_another(self):
@api_view(["GET"])
def simple_fbv(request):
pass
patterns = [
url(r'^test/list/', simple_fbv),
url(r'^test/(?P<pk>\d+)/list/', simple_fbv),
@ -1302,10 +1318,8 @@ def test_head_and_options_methods_are_excluded():
assert inspector.get_allowed_methods(callback) == ["GET"]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestAutoSchemaAllowsFilters:
class MockAPIView(APIView):
filter_backends = [filters.OrderingFilter]
class MockAPIView(APIView):
filter_backends = [filters.OrderingFilter]
def _test(self, method):
view = self.MockAPIView()

View File

@ -0,0 +1,20 @@
import pytest
from django.test import TestCase, override_settings
from rest_framework import renderers
from rest_framework.schemas import coreapi, get_schema_view, openapi
class GetSchemaViewTests(TestCase):
"""For the get_schema_view() helper."""
def test_openapi(self):
schema_view = get_schema_view(title="With OpenAPI")
assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator)
assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes
@pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed')
def test_coreapi(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
schema_view = get_schema_view(title="With CoreAPI")
assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator)
assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes

View File

@ -6,7 +6,8 @@ from django.core.management import call_command
from django.test import TestCase
from django.test.utils import override_settings
from rest_framework.compat import coreapi
from rest_framework.compat import uritemplate, yaml
from rest_framework.management.commands import generateschema
from rest_framework.utils import formatting, json
from rest_framework.views import APIView
@ -21,15 +22,43 @@ urlpatterns = [
]
@override_settings(ROOT_URLCONF='tests.test_generateschema')
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
@override_settings(ROOT_URLCONF=__name__)
@pytest.mark.skipif(not uritemplate, reason='uritemplate is not installed')
class GenerateSchemaTests(TestCase):
"""Tests for management command generateschema."""
def setUp(self):
self.out = io.StringIO()
def test_command_detects_schema_generation_mode(self):
"""Switching between CoreAPI & OpenAPI"""
command = generateschema.Command()
assert command.get_mode() == generateschema.OPENAPI_MODE
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
assert command.get_mode() == generateschema.COREAPI_MODE
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
def test_renders_default_schema_with_custom_title_url_and_description(self):
call_command('generateschema',
'--title=SampleAPI',
'--url=http://api.sample.com',
'--description=Sample description',
stdout=self.out)
# Check valid YAML was output.
schema = yaml.load(self.out.getvalue())
assert schema['openapi'] == '3.0.2'
def test_renders_openapi_json_schema(self):
call_command('generateschema',
'--format=openapi-json',
stdout=self.out)
# Check valid JSON was output.
out_json = json.loads(self.out.getvalue())
assert out_json['openapi'] == '3.0.2'
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info:
description: Sample description
title: SampleAPI
@ -50,7 +79,8 @@ class GenerateSchemaTests(TestCase):
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self):
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_coreapi_renders_openapi_json_schema(self):
expected_out = {
"openapi": "3.0.0",
"info": {
@ -78,6 +108,7 @@ class GenerateSchemaTests(TestCase):
self.assertDictEqual(out_json, expected_out)
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema',

View File

@ -0,0 +1,245 @@
import pytest
from django.conf.urls import url
from django.test import RequestFactory, TestCase, override_settings
from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.compat import uritemplate
from rest_framework.request import Request
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
from . import views
def create_request(path):
factory = RequestFactory()
request = Request(factory.get(path))
return request
def create_view(view_cls, method, request):
generator = SchemaGenerator()
view = generator.create_view(view_cls.as_view(), method, request)
return view
class TestBasics(TestCase):
def dummy_view(request):
pass
def test_filters(self):
classes = [filters.SearchFilter, filters.OrderingFilter]
for c in classes:
f = c()
assert f.get_schema_operation_parameters(self.dummy_view)
def test_pagination(self):
classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination]
for c in classes:
f = c()
assert f.get_schema_operation_parameters(self.dummy_view)
@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
class TestOperationIntrospection(TestCase):
def test_path_without_parameters(self):
path = '/example/'
method = 'GET'
view = create_view(
views.ExampleListView,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
operation = inspector.get_operation(path, method)
assert operation == {
'operationId': 'ListExamples',
'parameters': [],
'responses': {'200': {'content': {'application/json': {'schema': {}}}}},
}
def test_path_with_id_parameter(self):
path = '/example/{id}/'
method = 'GET'
view = create_view(
views.ExampleDetailView,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
parameters = inspector._get_path_parameters(path, method)
assert parameters == [{
'description': '',
'in': 'path',
'name': 'id',
'required': True,
'schema': {
'type': 'string',
},
}]
def test_request_body(self):
path = '/'
method = 'POST'
class Serializer(serializers.Serializer):
text = serializers.CharField()
read_only = serializers.CharField(read_only=True)
class View(generics.GenericAPIView):
serializer_class = Serializer
view = create_view(
View,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
assert request_body['content']['application/json']['schema']['required'] == ['text']
assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text']
def test_response_body_generation(self):
path = '/'
method = 'POST'
class Serializer(serializers.Serializer):
text = serializers.CharField()
write_only = serializers.CharField(write_only=True)
class View(generics.GenericAPIView):
serializer_class = Serializer
view = create_view(
View,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
assert responses['200']['content']['application/json']['schema']['required'] == ['text']
assert list(responses['200']['content']['application/json']['schema']['properties'].keys()) == ['text']
def test_response_body_nested_serializer(self):
path = '/'
method = 'POST'
class NestedSerializer(serializers.Serializer):
number = serializers.IntegerField()
class Serializer(serializers.Serializer):
text = serializers.CharField()
nested = NestedSerializer()
class View(generics.GenericAPIView):
serializer_class = Serializer
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
schema = responses['200']['content']['application/json']['schema']
assert sorted(schema['required']) == ['nested', 'text']
assert sorted(list(schema['properties'].keys())) == ['nested', 'text']
assert schema['properties']['nested']['type'] == 'object'
assert list(schema['properties']['nested']['properties'].keys()) == ['number']
assert schema['properties']['nested']['required'] == ['number']
def test_operation_id_generation(self):
path = '/'
method = 'GET'
view = create_view(
views.ExampleGenericAPIView,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
operationId = inspector._get_operation_id(path, method)
assert operationId == 'ListExamples'
def test_repeat_operation_ids(self):
router = routers.SimpleRouter()
router.register('account', views.ExampleGenericViewSet, basename="account")
urlpatterns = router.urls
generator = SchemaGenerator(patterns=urlpatterns)
request = create_request('/')
schema = generator.get_schema(request=request)
schema_str = str(schema)
print(schema_str)
assert schema_str.count("operationId") == 2
assert schema_str.count("newExample") == 1
assert schema_str.count("oldExample") == 1
@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'})
class TestGenerator(TestCase):
def test_override_settings(self):
assert isinstance(views.ExampleListView.schema, AutoSchema)
def test_paths_construction(self):
"""Construction of the `paths` key."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
generator._initialise_endpoints()
paths = generator.get_paths()
assert '/example/' in paths
example_operations = paths['/example/']
assert len(example_operations) == 2
assert 'get' in example_operations
assert 'post' in example_operations
def test_schema_construction(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert 'openapi' in schema
assert 'paths' in schema
def test_serializer_datefield(self):
patterns = [
url(r'^example/?$', views.ExampleGenericViewSet.as_view({"get": "get"})),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
response = schema['paths']['/example/']['get']['responses']
response_schema = response['200']['content']['application/json']['schema']['properties']
assert response_schema['date']['type'] == response_schema['datetime']['type'] == 'string'
assert response_schema['date']['format'] == 'date'
assert response_schema['datetime']['format'] == 'date-time'

58
tests/schemas/views.py Normal file
View File

@ -0,0 +1,58 @@
from rest_framework import generics, permissions, serializers
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet
class ExampleListView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class ExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
# Generics.
class ExampleSerializer(serializers.Serializer):
date = serializers.DateField()
datetime = serializers.DateTimeField()
class ExampleGenericAPIView(generics.GenericAPIView):
serializer_class = ExampleSerializer
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)
class ExampleGenericViewSet(GenericViewSet):
serializer_class = ExampleSerializer
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)
@action(detail=False)
def new(self, *args, **kwargs):
pass
@action(detail=False)
def old(self, *args, **kwargs):
pass