mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-30 01:49:50 +03:00
Added OpenAPI Schema Generation.
This commit is contained in:
parent
d2d1888217
commit
c0a31ed0a3
|
@ -40,6 +40,9 @@ class BaseFilterBackend(object):
|
|||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
||||
class SearchFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the search.
|
||||
|
@ -159,6 +162,19 @@ class SearchFilter(BaseFilterBackend):
|
|||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.search_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.search_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class OrderingFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the ordering.
|
||||
|
@ -290,6 +306,19 @@ class OrderingFilter(BaseFilterBackend):
|
|||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.ordering_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.ordering_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class DjangoObjectPermissionsFilter(BaseFilterBackend):
|
||||
"""
|
||||
|
|
|
@ -1,25 +1,21 @@
|
|||
from django.core.management.base import BaseCommand
|
||||
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.renderers import (
|
||||
CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer
|
||||
)
|
||||
from rest_framework.schemas.generators import SchemaGenerator
|
||||
from rest_framework.compat import yaml
|
||||
from rest_framework.schemas.generators import OpenAPISchemaGenerator
|
||||
from rest_framework.utils import json
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Generates configured API schema for project."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--title', dest="title", default=None, type=str)
|
||||
parser.add_argument('--title', dest="title", default='', type=str)
|
||||
parser.add_argument('--url', dest="url", default=None, type=str)
|
||||
parser.add_argument('--description', dest="description", default=None, type=str)
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
|
||||
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
assert coreapi is not None, 'coreapi must be installed.'
|
||||
|
||||
generator = SchemaGenerator(
|
||||
generator = OpenAPISchemaGenerator(
|
||||
url=options['url'],
|
||||
title=options['title'],
|
||||
description=options['description']
|
||||
|
@ -27,15 +23,10 @@ class Command(BaseCommand):
|
|||
|
||||
schema = generator.get_schema(request=None, public=True)
|
||||
|
||||
renderer = self.get_renderer(options['format'])
|
||||
output = renderer.render(schema, renderer_context={})
|
||||
self.stdout.write(output.decode('utf-8'))
|
||||
# TODO: Handle via renderer? More options?
|
||||
if options['format'] == 'openapi':
|
||||
output = yaml.dump(schema, default_flow_style=False)
|
||||
else:
|
||||
output = json.dumps(schema, indent=2)
|
||||
|
||||
def get_renderer(self, format):
|
||||
renderer_cls = {
|
||||
'corejson': CoreJSONRenderer,
|
||||
'openapi': OpenAPIRenderer,
|
||||
'openapi-json': JSONOpenAPIRenderer,
|
||||
}[format]
|
||||
|
||||
return renderer_cls()
|
||||
self.stdout.write(output)
|
||||
|
|
|
@ -152,6 +152,9 @@ class BasePagination(object):
|
|||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return []
|
||||
|
||||
|
||||
class PageNumberPagination(BasePagination):
|
||||
"""
|
||||
|
@ -305,6 +308,32 @@ class PageNumberPagination(BasePagination):
|
|||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.page_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.page_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
parameters.append(
|
||||
{
|
||||
'name': self.page_size_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.page_size_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
)
|
||||
return parameters
|
||||
|
||||
|
||||
class LimitOffsetPagination(BasePagination):
|
||||
"""
|
||||
|
@ -434,6 +463,15 @@ class LimitOffsetPagination(BasePagination):
|
|||
context = self.get_html_context()
|
||||
return template.render(context)
|
||||
|
||||
def get_count(self, queryset):
|
||||
"""
|
||||
Determine an object count, supporting either querysets or regular lists.
|
||||
"""
|
||||
try:
|
||||
return queryset.count()
|
||||
except (AttributeError, TypeError):
|
||||
return len(queryset)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
|
@ -458,14 +496,28 @@ class LimitOffsetPagination(BasePagination):
|
|||
)
|
||||
]
|
||||
|
||||
def get_count(self, queryset):
|
||||
"""
|
||||
Determine an object count, supporting either querysets or regular lists.
|
||||
"""
|
||||
try:
|
||||
return queryset.count()
|
||||
except (AttributeError, TypeError):
|
||||
return len(queryset)
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.limit_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.limit_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
{
|
||||
'name': self.offset_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.offset_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
},
|
||||
]
|
||||
return parameters
|
||||
|
||||
|
||||
class CursorPagination(BasePagination):
|
||||
|
@ -820,3 +872,29 @@ class CursorPagination(BasePagination):
|
|||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
parameters = [
|
||||
{
|
||||
'name': self.cursor_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.cursor_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
}
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
parameters.append(
|
||||
{
|
||||
'name': self.page_size_query_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_text(self.page_size_query_description),
|
||||
'schema': {
|
||||
'type': 'integer',
|
||||
},
|
||||
}
|
||||
)
|
||||
return parameters
|
||||
|
|
|
@ -193,6 +193,10 @@ class EndpointEnumerator(object):
|
|||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
# ???: Would it be feasible to adjust this such that we generate the
|
||||
# path, plus the kwargs, plus the type from the convertor, such that we
|
||||
# could feed that straight into the parameter schema object?
|
||||
|
||||
path = simplify_regex(path_regex)
|
||||
|
||||
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
|
||||
|
@ -232,35 +236,18 @@ class EndpointEnumerator(object):
|
|||
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
class BaseSchemaGenerator(object):
|
||||
endpoint_inspector_cls = EndpointEnumerator
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
|
@ -270,36 +257,15 @@ class SchemaGenerator(object):
|
|||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
def _initialise_endpoints(self):
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
def get_links(self, request=None):
|
||||
def _get_paths_and_endpoints(self, request):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
Generate (path, method, view) given (path, method, callback) for paths.
|
||||
"""
|
||||
links = LinkNode()
|
||||
|
||||
# Generate (path, method, view) given (path, method, callback).
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
|
@ -308,22 +274,48 @@ class SchemaGenerator(object):
|
|||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
return paths, view_endpoints
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
return links
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
# Methods used when we generate a view instance from the raw callback...
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
|
@ -356,29 +348,6 @@ class SchemaGenerator(object):
|
|||
prefixes.append('/' + prefix + '/')
|
||||
return common_path(prefixes)
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
|
@ -392,23 +361,77 @@ class SchemaGenerator(object):
|
|||
return False
|
||||
return True
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
"""
|
||||
Original CoreAPI version.
|
||||
"""
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
links = LinkNode()
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
|
||||
return links
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
# Method for generating the link layout....
|
||||
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
|
@ -452,3 +475,55 @@ class SchemaGenerator(object):
|
|||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
||||
|
||||
class OpenAPISchemaGenerator(BaseSchemaGenerator):
|
||||
|
||||
def get_info(self):
|
||||
info = {
|
||||
'title': self.title,
|
||||
'version': 'TODO',
|
||||
}
|
||||
|
||||
if self.description is not None:
|
||||
info['description'] = self.description
|
||||
|
||||
return info
|
||||
|
||||
def get_paths(self, request=None):
|
||||
result = {}
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
operation = view.schema.get_operation(path, method)
|
||||
subpath = '/' + path[len(prefix):]
|
||||
result.setdefault(subpath, {})
|
||||
result[subpath][method.lower()] = operation
|
||||
|
||||
return result
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a OpenAPI schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
paths = self.get_paths(None if public else request)
|
||||
if not paths:
|
||||
return None
|
||||
|
||||
schema = {
|
||||
'openapi': '3.0.2',
|
||||
'info': self.get_info(),
|
||||
'paths': paths,
|
||||
}
|
||||
|
||||
return schema
|
||||
|
|
|
@ -179,20 +179,6 @@ class ViewInspector(object):
|
|||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
raise NotImplementedError(".get_link() must be overridden.")
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
|
@ -213,6 +199,17 @@ class AutoSchema(ViewInspector):
|
|||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
|
@ -509,3 +506,246 @@ class DefaultSchema(ViewInspector):
|
|||
inspector = inspector_class()
|
||||
inspector.view = instance
|
||||
return inspector
|
||||
|
||||
|
||||
class OpenAPIAutoSchema(ViewInspector):
|
||||
|
||||
content_types = ['application/json']
|
||||
|
||||
def get_operation(self, path, method):
|
||||
operation = {}
|
||||
|
||||
parameters = []
|
||||
parameters += self._get_path_parameters(path, method)
|
||||
parameters += self._get_pagination_parameters(path, method)
|
||||
parameters += self._get_filter_parameters(path, method)
|
||||
operation['parameters'] = parameters
|
||||
|
||||
request_body = self._get_request_body(path, method)
|
||||
if request_body:
|
||||
operation['requestBody'] = request_body
|
||||
operation['responses'] = self._get_responses(path, method)
|
||||
|
||||
return operation
|
||||
|
||||
def _get_path_parameters(self, path, method):
|
||||
"""
|
||||
Return a list of parameters from templated path variables.
|
||||
"""
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
parameters = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
description = ''
|
||||
if model is not None: # TODO: test this.
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
parameter = {
|
||||
"name": variable,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"description": description,
|
||||
'schema': {
|
||||
'type': 'string', # TODO: integer, pattern, ...
|
||||
},
|
||||
}
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def _get_filter_parameters(self, path, method):
|
||||
if not self._allows_filters(path, method):
|
||||
return []
|
||||
parameters = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
parameters += filter_backend().get_schema_operation_parameters(self.view)
|
||||
return parameters
|
||||
|
||||
def _allows_filters(self, path, method):
|
||||
"""
|
||||
Determine whether to include filter Fields in schema.
|
||||
|
||||
Default implementation looks for ModelViewSet or GenericAPIView
|
||||
actions/methods that cause filtering on the default implementation.
|
||||
"""
|
||||
if getattr(self.view, 'filter_backends', None) is None:
|
||||
return False
|
||||
if hasattr(self.view, 'action'):
|
||||
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
|
||||
return method.lower() in ["get", "put", "patch", "delete"]
|
||||
|
||||
def _get_pagination_parameters(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_operation_parameters(view)
|
||||
|
||||
def _map_field(self, field):
|
||||
|
||||
# Nested Serializers, `many` or not.
|
||||
if isinstance(field, serializers.ListSerializer):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self._map_serializer(field.child)
|
||||
}
|
||||
if isinstance(field, serializers.Serializer):
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': self._map_serializer(field)
|
||||
}
|
||||
|
||||
# Related fields.
|
||||
if isinstance(field, serializers.ManyRelatedField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self._map_field(field.child_relation)
|
||||
}
|
||||
if isinstance(field, serializers.PrimaryKeyRelatedField):
|
||||
model = getattr(field.queryset, 'model', None)
|
||||
if model is not None:
|
||||
model_field = model._meta.pk
|
||||
if isinstance(model_field, models.AutoField):
|
||||
return {'type': 'integer'}
|
||||
|
||||
# ChoiceFields (single and multiple).
|
||||
# Q:
|
||||
# - Is 'type' required?
|
||||
# - can we determine the TYPE of a choicefield?
|
||||
if isinstance(field, serializers.MultipleChoiceField):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'enum': list(field.choices)
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.ChoiceField):
|
||||
return {
|
||||
'enum': list(field.choices),
|
||||
}
|
||||
|
||||
# ListField.
|
||||
if isinstance(field, serializers.ListField):
|
||||
return {
|
||||
'type': 'array',
|
||||
}
|
||||
|
||||
# Simplest cases, default to 'string' type:
|
||||
FIELD_CLASS_SCHEMA_TYPE = {
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.DecimalField: 'number',
|
||||
serializers.FloatField: 'number',
|
||||
serializers.IntegerField: 'integer',
|
||||
serializers.DateField: 'date',
|
||||
serializers.DateTimeField: 'date-time',
|
||||
serializers.JSONField: 'object',
|
||||
serializers.DictField: 'object',
|
||||
}
|
||||
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
|
||||
|
||||
def _map_serializer(self, serializer):
|
||||
# Assuming we have a valid serializer instance.
|
||||
# TODO:
|
||||
# - field is Nested or List serializer.
|
||||
# - Handle read_only/write_only for request/response differences.
|
||||
# - could do this with readOnly/writeOnly and then filter dict.
|
||||
required = []
|
||||
properties = {}
|
||||
|
||||
for field in serializer.fields.values():
|
||||
if isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
if field.required:
|
||||
required.append(field.field_name)
|
||||
|
||||
schema = self._map_field(field)
|
||||
if field.read_only:
|
||||
schema['readOnly'] = True
|
||||
if field.write_only:
|
||||
schema['writeOnly'] = True
|
||||
if field.allow_null:
|
||||
schema['nullable'] = True
|
||||
|
||||
properties[field.field_name] = schema
|
||||
return {
|
||||
'required': required,
|
||||
'properties': properties,
|
||||
}
|
||||
|
||||
def _get_request_body(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return {}
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return {}
|
||||
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return {}
|
||||
|
||||
content = self._map_serializer(serializer)
|
||||
# No required fields for PATCH
|
||||
if method == 'PATCH':
|
||||
del content['required']
|
||||
# No read_only fields for request.
|
||||
for name, schema in content['properties'].items():
|
||||
if 'readOnly' in schema:
|
||||
del content['properties']['name']
|
||||
|
||||
return {
|
||||
'content': {ct: content for ct in self.content_types}
|
||||
}
|
||||
|
||||
def _get_responses(self, path, method):
|
||||
# TODO: Handle multiple codes.
|
||||
content = {}
|
||||
view = self.view
|
||||
if hasattr(view, 'get_serializer'):
|
||||
try:
|
||||
serializer = view.get_serializer()
|
||||
except exceptions.APIException:
|
||||
serializer = None
|
||||
warnings.warn('{}.get_serializer() raised an exception during '
|
||||
'schema generation. Serializer fields will not be '
|
||||
'generated for {} {}.'
|
||||
.format(view.__class__.__name__, method, path))
|
||||
|
||||
if isinstance(serializer, serializers.Serializer):
|
||||
content = self._map_serializer(serializer)
|
||||
# No write_only fields for response.
|
||||
for name, schema in content['properties'].items():
|
||||
if 'writeOnly' in schema:
|
||||
del content['properties']['name']
|
||||
|
||||
return {
|
||||
'200': {
|
||||
'content': {ct: content for ct in self.content_types}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ DEFAULTS = {
|
|||
'DEFAULT_FILTER_BACKENDS': (),
|
||||
|
||||
# Schema
|
||||
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
|
||||
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.inspectors.OpenAPIAutoSchema',
|
||||
|
||||
# Throttling
|
||||
'DEFAULT_THROTTLE_RATES': {
|
||||
|
|
0
tests/schemas/__init__.py
Normal file
0
tests/schemas/__init__.py
Normal file
|
@ -24,7 +24,8 @@ from rest_framework.utils import formatting
|
|||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
|
||||
from .models import BasicModel, ForeignKeySource, ManyToManySource
|
||||
from . import views
|
||||
from ..models import BasicModel, ForeignKeySource, ManyToManySource
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
@ -148,7 +149,7 @@ urlpatterns = [
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_schemas')
|
||||
@override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestRouterGeneratedSchema(TestCase):
|
||||
def test_anonymous_request(self):
|
||||
client = APIClient()
|
||||
|
@ -382,30 +383,14 @@ class MethodLimitedViewSet(ExampleViewSet):
|
|||
http_method_names = ['get', 'head', 'options']
|
||||
|
||||
|
||||
class ExampleListView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGenerator(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url(r'^example/?$', ExampleListView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
url(r'^example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -453,12 +438,13 @@ class TestSchemaGenerator(TestCase):
|
|||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@unittest.skipUnless(path, 'needs Django 2')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorDjango2(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
path('example/', ExampleListView.as_view()),
|
||||
path('example/<int:pk>/', ExampleDetailView.as_view()),
|
||||
path('example/<int:pk>/sub/', ExampleDetailView.as_view()),
|
||||
path('example/', views.ExampleListView.as_view()),
|
||||
path('example/<int:pk>/', views.ExampleDetailView.as_view()),
|
||||
path('example/<int:pk>/sub/', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -505,12 +491,13 @@ class TestSchemaGeneratorDjango2(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorNotAtRoot(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url(r'^api/v1/example/?$', ExampleListView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -558,6 +545,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
|
||||
def setUp(self):
|
||||
router = DefaultRouter()
|
||||
|
@ -622,13 +610,14 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
|
||||
def setUp(self):
|
||||
router = DefaultRouter()
|
||||
router.register('example1', Http404ExampleViewSet, basename='example1')
|
||||
router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
|
||||
self.patterns = [
|
||||
url('^example/?$', ExampleListView.as_view()),
|
||||
url('^example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^', include(router.urls))
|
||||
]
|
||||
|
||||
|
@ -668,6 +657,7 @@ class ForeignKeySourceView(generics.CreateAPIView):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithForeignKey(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
|
@ -713,6 +703,7 @@ class ManyToManySourceView(generics.CreateAPIView):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestSchemaGeneratorWithManyToMany(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
|
@ -747,6 +738,7 @@ class TestSchemaGeneratorWithManyToMany(TestCase):
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class Test4605Regression(TestCase):
|
||||
def test_4605_regression(self):
|
||||
generator = SchemaGenerator()
|
||||
|
@ -762,6 +754,7 @@ class CustomViewInspector(AutoSchema):
|
|||
pass
|
||||
|
||||
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestAutoSchema(TestCase):
|
||||
|
||||
def test_apiview_schema_descriptor(self):
|
||||
|
@ -777,7 +770,7 @@ class TestAutoSchema(TestCase):
|
|||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
def test_set_custom_inspector_class_via_settings(self):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}):
|
||||
view = APIView()
|
||||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
|
@ -971,6 +964,7 @@ class TestAutoSchema(TestCase):
|
|||
self.assertEqual(field_to_schema(case[0]), case[1])
|
||||
|
||||
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
def test_docstring_is_not_stripped_by_get_description():
|
||||
class ExampleDocstringAPIView(APIView):
|
||||
"""
|
||||
|
@ -1014,20 +1008,19 @@ class ExcludedAPIView(APIView):
|
|||
pass
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
@schema(None)
|
||||
def excluded_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
def included_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class SchemaGenerationExclusionTests(TestCase):
|
||||
def setUp(self):
|
||||
@api_view(['GET'])
|
||||
@schema(None)
|
||||
def excluded_fbv(request):
|
||||
pass
|
||||
|
||||
@api_view(['GET'])
|
||||
def included_fbv(request):
|
||||
pass
|
||||
|
||||
self.patterns = [
|
||||
url('^excluded-cbv/$', ExcludedAPIView.as_view()),
|
||||
url('^excluded-fbv/$', excluded_fbv),
|
||||
|
@ -1078,11 +1071,6 @@ class SchemaGenerationExclusionTests(TestCase):
|
|||
assert should_include == expected
|
||||
|
||||
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
class BasicModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
|
@ -1118,11 +1106,16 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename
|
|||
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestURLNamingCollisions(TestCase):
|
||||
"""
|
||||
Ref: https://github.com/encode/django-rest-framework/issues/4704
|
||||
"""
|
||||
def test_manually_routing_nested_routes(self):
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
patterns = [
|
||||
url(r'^test', simple_fbv),
|
||||
url(r'^test/list/', simple_fbv),
|
||||
|
@ -1228,6 +1221,10 @@ class TestURLNamingCollisions(TestCase):
|
|||
|
||||
def test_url_under_same_key_not_replaced_another(self):
|
||||
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
patterns = [
|
||||
url(r'^test/list/', simple_fbv),
|
||||
url(r'^test/(?P<pk>\d+)/list/', simple_fbv),
|
||||
|
@ -1303,7 +1300,8 @@ def test_head_and_options_methods_are_excluded():
|
|||
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
class TestAutoSchemaAllowsFilters(object):
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
|
||||
class TestAutoSchemaAllowsFilters(TestCase):
|
||||
class MockAPIView(APIView):
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
|
118
tests/schemas/test_openapi.py
Normal file
118
tests/schemas/test_openapi.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
from django.conf.urls import url
|
||||
from django.test import RequestFactory, TestCase, override_settings
|
||||
|
||||
from rest_framework import filters, pagination
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.schemas.generators import OpenAPISchemaGenerator
|
||||
from rest_framework.schemas.inspectors import OpenAPIAutoSchema
|
||||
|
||||
from . import views
|
||||
|
||||
|
||||
def create_request(path):
|
||||
factory = RequestFactory()
|
||||
request = Request(factory.get(path))
|
||||
return request
|
||||
|
||||
|
||||
def create_view(view_cls, method, request):
|
||||
generator = OpenAPISchemaGenerator()
|
||||
view = generator.create_view(view_cls.as_view(), method, request)
|
||||
return view
|
||||
|
||||
|
||||
class TestBasics(TestCase):
|
||||
def dummy_view(request):
|
||||
pass
|
||||
|
||||
def test_filters(self):
|
||||
classes = [filters.SearchFilter, filters.OrderingFilter]
|
||||
for c in classes:
|
||||
f = c()
|
||||
assert f.get_schema_operation_parameters(self.dummy_view)
|
||||
|
||||
def test_pagination(self):
|
||||
classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination]
|
||||
for c in classes:
|
||||
f = c()
|
||||
assert f.get_schema_operation_parameters(self.dummy_view)
|
||||
|
||||
|
||||
class TestOperationIntrospection(TestCase):
|
||||
|
||||
def test_path_without_parameters(self):
|
||||
path = '/example/'
|
||||
method = 'GET'
|
||||
|
||||
view = create_view(
|
||||
views.ExampleListView,
|
||||
method,
|
||||
create_request(path)
|
||||
)
|
||||
inspector = OpenAPIAutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
operation = inspector.get_operation(path, method)
|
||||
assert operation == {
|
||||
'parameters': [],
|
||||
'responses': {'200': {'content': {'application/json': {}}}},
|
||||
}
|
||||
|
||||
def test_path_with_id_parameter(self):
|
||||
path = '/example/{id}/'
|
||||
method = 'GET'
|
||||
|
||||
view = create_view(
|
||||
views.ExampleDetailView,
|
||||
method,
|
||||
create_request(path)
|
||||
)
|
||||
inspector = OpenAPIAutoSchema()
|
||||
inspector.view = view
|
||||
|
||||
parameters = inspector._get_path_parameters(path, method)
|
||||
assert parameters == [{
|
||||
'description': '',
|
||||
'in': 'path',
|
||||
'name': 'id',
|
||||
'required': True,
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
}]
|
||||
|
||||
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.inspectors.OpenAPIAutoSchema'})
|
||||
class TestGenerator(TestCase):
|
||||
|
||||
def test_override_settings(self):
|
||||
assert isinstance(views.ExampleListView.schema, OpenAPIAutoSchema)
|
||||
|
||||
def test_paths_construction(self):
|
||||
"""Construction of the `paths` key."""
|
||||
patterns = [
|
||||
url(r'^example/?$', views.ExampleListView.as_view()),
|
||||
]
|
||||
generator = OpenAPISchemaGenerator(patterns=patterns)
|
||||
generator._initialise_endpoints()
|
||||
|
||||
paths = generator.get_paths()
|
||||
|
||||
assert '/example/' in paths
|
||||
example_operations = paths['/example/']
|
||||
assert len(example_operations) == 2
|
||||
assert 'get' in example_operations
|
||||
assert 'post' in example_operations
|
||||
|
||||
def test_schema_construction(self):
|
||||
"""Construction of the top level dictionary."""
|
||||
patterns = [
|
||||
url(r'^example/?$', views.ExampleListView.as_view()),
|
||||
]
|
||||
generator = OpenAPISchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
|
||||
assert 'openapi' in schema
|
||||
assert 'paths' in schema
|
19
tests/schemas/views.py
Normal file
19
tests/schemas/views.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from rest_framework import permissions
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class ExampleListView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
|
@ -7,8 +7,8 @@ from django.test import TestCase
|
|||
from django.test.utils import override_settings
|
||||
from django.utils import six
|
||||
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.utils import formatting, json
|
||||
from rest_framework.compat import coreapi, yaml
|
||||
from rest_framework.utils import json
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
|
@ -31,58 +31,21 @@ class GenerateSchemaTests(TestCase):
|
|||
self.out = six.StringIO()
|
||||
|
||||
@pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.')
|
||||
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
|
||||
def test_renders_default_schema_with_custom_title_url_and_description(self):
|
||||
expected_out = """info:
|
||||
description: Sample description
|
||||
title: SampleAPI
|
||||
version: ''
|
||||
openapi: 3.0.0
|
||||
paths:
|
||||
/:
|
||||
get:
|
||||
operationId: list
|
||||
servers:
|
||||
- url: http://api.sample.com/
|
||||
"""
|
||||
call_command('generateschema',
|
||||
'--title=SampleAPI',
|
||||
'--url=http://api.sample.com',
|
||||
'--description=Sample description',
|
||||
stdout=self.out)
|
||||
|
||||
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
|
||||
# Check valid YAML was output.
|
||||
schema = yaml.load(self.out.getvalue())
|
||||
assert schema['openapi'] == '3.0.2'
|
||||
|
||||
def test_renders_openapi_json_schema(self):
|
||||
expected_out = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"version": "",
|
||||
"title": "",
|
||||
"description": ""
|
||||
},
|
||||
"servers": [
|
||||
{
|
||||
"url": ""
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"/": {
|
||||
"get": {
|
||||
"operationId": "list"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
call_command('generateschema',
|
||||
'--format=openapi-json',
|
||||
stdout=self.out)
|
||||
# Check valid JSON was output.
|
||||
out_json = json.loads(self.out.getvalue())
|
||||
|
||||
self.assertDictEqual(out_json, expected_out)
|
||||
|
||||
def test_renders_corejson_schema(self):
|
||||
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
|
||||
call_command('generateschema',
|
||||
'--format=corejson',
|
||||
stdout=self.out)
|
||||
self.assertIn(expected_out, self.out.getvalue())
|
||||
assert out_json['openapi'] == '3.0.2'
|
||||
|
|
Loading…
Reference in New Issue
Block a user