Added OpenAPI Schema Generation.

This commit is contained in:
Carlton Gibson 2019-03-19 15:51:59 +01:00
parent d2d1888217
commit c0a31ed0a3
11 changed files with 738 additions and 227 deletions

View File

@ -40,6 +40,9 @@ class BaseFilterBackend(object):
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [] return []
def get_schema_operation_parameters(self, view):
return []
class SearchFilter(BaseFilterBackend): class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search. # The URL query parameter used for the search.
@ -159,6 +162,19 @@ class SearchFilter(BaseFilterBackend):
) )
] ]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.search_param,
'required': False,
'in': 'query',
'description': force_text(self.search_description),
'schema': {
'type': 'string',
},
},
]
class OrderingFilter(BaseFilterBackend): class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering. # The URL query parameter used for the ordering.
@ -290,6 +306,19 @@ class OrderingFilter(BaseFilterBackend):
) )
] ]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.ordering_param,
'required': False,
'in': 'query',
'description': force_text(self.ordering_description),
'schema': {
'type': 'string',
},
},
]
class DjangoObjectPermissionsFilter(BaseFilterBackend): class DjangoObjectPermissionsFilter(BaseFilterBackend):
""" """

View File

@ -1,25 +1,21 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from rest_framework.compat import coreapi from rest_framework.compat import yaml
from rest_framework.renderers import ( from rest_framework.schemas.generators import OpenAPISchemaGenerator
CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer from rest_framework.utils import json
)
from rest_framework.schemas.generators import SchemaGenerator
class Command(BaseCommand): class Command(BaseCommand):
help = "Generates configured API schema for project." help = "Generates configured API schema for project."
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('--title', dest="title", default=None, type=str) parser.add_argument('--title', dest="title", default='', type=str)
parser.add_argument('--url', dest="url", default=None, type=str) parser.add_argument('--url', dest="url", default=None, type=str)
parser.add_argument('--description', dest="description", default=None, type=str) parser.add_argument('--description', dest="description", default=None, type=str)
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
def handle(self, *args, **options): def handle(self, *args, **options):
assert coreapi is not None, 'coreapi must be installed.' generator = OpenAPISchemaGenerator(
generator = SchemaGenerator(
url=options['url'], url=options['url'],
title=options['title'], title=options['title'],
description=options['description'] description=options['description']
@ -27,15 +23,10 @@ class Command(BaseCommand):
schema = generator.get_schema(request=None, public=True) schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer(options['format']) # TODO: Handle via renderer? More options?
output = renderer.render(schema, renderer_context={}) if options['format'] == 'openapi':
self.stdout.write(output.decode('utf-8')) output = yaml.dump(schema, default_flow_style=False)
else:
output = json.dumps(schema, indent=2)
def get_renderer(self, format): self.stdout.write(output)
renderer_cls = {
'corejson': CoreJSONRenderer,
'openapi': OpenAPIRenderer,
'openapi-json': JSONOpenAPIRenderer,
}[format]
return renderer_cls()

View File

@ -152,6 +152,9 @@ class BasePagination(object):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
return [] return []
def get_schema_operation_parameters(self, view):
return []
class PageNumberPagination(BasePagination): class PageNumberPagination(BasePagination):
""" """
@ -305,6 +308,32 @@ class PageNumberPagination(BasePagination):
) )
return fields return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.page_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_query_description),
'schema': {
'type': 'integer',
},
},
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_size_query_description),
'schema': {
'type': 'integer',
},
},
)
return parameters
class LimitOffsetPagination(BasePagination): class LimitOffsetPagination(BasePagination):
""" """
@ -434,6 +463,15 @@ class LimitOffsetPagination(BasePagination):
context = self.get_html_context() context = self.get_html_context()
return template.render(context) return template.render(context)
def get_count(self, queryset):
"""
Determine an object count, supporting either querysets or regular lists.
"""
try:
return queryset.count()
except (AttributeError, TypeError):
return len(queryset)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
@ -458,14 +496,28 @@ class LimitOffsetPagination(BasePagination):
) )
] ]
def get_count(self, queryset): def get_schema_operation_parameters(self, view):
""" parameters = [
Determine an object count, supporting either querysets or regular lists. {
""" 'name': self.limit_query_param,
try: 'required': False,
return queryset.count() 'in': 'query',
except (AttributeError, TypeError): 'description': force_text(self.limit_query_description),
return len(queryset) 'schema': {
'type': 'integer',
},
},
{
'name': self.offset_query_param,
'required': False,
'in': 'query',
'description': force_text(self.offset_query_description),
'schema': {
'type': 'integer',
},
},
]
return parameters
class CursorPagination(BasePagination): class CursorPagination(BasePagination):
@ -820,3 +872,29 @@ class CursorPagination(BasePagination):
) )
) )
return fields return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.cursor_query_param,
'required': False,
'in': 'query',
'description': force_text(self.cursor_query_description),
'schema': {
'type': 'integer',
},
}
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_size_query_description),
'schema': {
'type': 'integer',
},
}
)
return parameters

View File

@ -193,6 +193,10 @@ class EndpointEnumerator(object):
""" """
Given a URL conf regex, return a URI template string. Given a URL conf regex, return a URI template string.
""" """
# ???: Would it be feasible to adjust this such that we generate the
# path, plus the kwargs, plus the type from the convertor, such that we
# could feed that straight into the parameter schema object?
path = simplify_regex(path_regex) path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format # Strip Django 2.0 convertors as they are incompatible with uritemplate format
@ -232,35 +236,18 @@ class EndpointEnumerator(object):
return [method for method in methods if method not in ('OPTIONS', 'HEAD')] return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class SchemaGenerator(object): class BaseSchemaGenerator(object):
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
endpoint_inspector_cls = EndpointEnumerator endpoint_inspector_cls = EndpointEnumerator
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
# 'pk' isn't great as an externally exposed name for an identifier, # 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas. # so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'. # Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
if url and not url.endswith('/'): if url and not url.endswith('/'):
url += '/' url += '/'
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns self.patterns = patterns
@ -270,36 +257,15 @@ class SchemaGenerator(object):
self.url = url self.url = url
self.endpoints = None self.endpoints = None
def get_schema(self, request=None, public=False): def _initialise_endpoints(self):
"""
Generate a `coreapi.Document` representing the API schema.
"""
if self.endpoints is None: if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints() self.endpoints = inspector.get_api_endpoints()
links = self.get_links(None if public else request) def _get_paths_and_endpoints(self, request):
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
def get_links(self, request=None):
""" """
Return a dictionary containing all the links that should be Generate (path, method, view) given (path, method, callback) for paths.
included in the API schema.
""" """
links = LinkNode()
# Generate (path, method, view) given (path, method, callback).
paths = [] paths = []
view_endpoints = [] view_endpoints = []
for path, method, callback in self.endpoints: for path, method, callback in self.endpoints:
@ -308,22 +274,48 @@ class SchemaGenerator(object):
paths.append(path) paths.append(path)
view_endpoints.append((path, method, view)) view_endpoints.append((path, method, view))
# Only generate the path prefix for paths that will be included return paths, view_endpoints
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints: def create_view(self, callback, method, request=None):
if not self.has_view_permissions(path, method, view): """
continue Given a callback, return an actual view instance.
link = view.schema.get_link(path, method, base_url=self.url) """
subpath = path[len(prefix):] view = callback.cls(**getattr(callback, 'initkwargs', {}))
keys = self.get_keys(subpath, method, view) view.args = ()
insert_into(links, keys, link) view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
return links actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
# Methods used when we generate a view instance from the raw callback... if request is not None:
view.request = clone_request(request, method)
return view
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_schema(self, request=None, public=False):
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
def determine_path_prefix(self, paths): def determine_path_prefix(self, paths):
""" """
@ -356,29 +348,6 @@ class SchemaGenerator(object):
prefixes.append('/' + prefix + '/') prefixes.append('/' + prefix + '/')
return common_path(prefixes) return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def has_view_permissions(self, path, method, view): def has_view_permissions(self, path, method, view):
""" """
Return `True` if the incoming request has the correct view permissions. Return `True` if the incoming request has the correct view permissions.
@ -392,23 +361,77 @@ class SchemaGenerator(object):
return False return False
return True return True
def coerce_path(self, path, method, view):
class SchemaGenerator(BaseSchemaGenerator):
"""
Original CoreAPI version.
"""
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
def get_links(self, request=None):
""" """
Coerce {pk} path arguments into the name of the model field, Return a dictionary containing all the links that should be
where possible. This is cleaner for an external representation. included in the API schema.
(Ie. "this is an identifier", not "this is a database primary key")
""" """
if not self.coerce_path_pk or '{pk}' not in path: links = LinkNode()
return path
model = getattr(getattr(view, 'queryset', None), 'model', None) paths, view_endpoints = self._get_paths_and_endpoints(request)
if model:
field_name = get_pk_name(model) # Only generate the path prefix for paths that will be included
else: if not paths:
field_name = 'id' return None
return path.replace('{pk}', '{%s}' % field_name) prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
self._initialise_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
# Method for generating the link layout.... # Method for generating the link layout....
def get_keys(self, subpath, method, view): def get_keys(self, subpath, method, view):
""" """
Return a list of keys that should be used to layout a link within Return a list of keys that should be used to layout a link within
@ -452,3 +475,55 @@ class SchemaGenerator(object):
# Default action, eg "/users/", "/users/{pk}/" # Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action] return named_path_components + [action]
class OpenAPISchemaGenerator(BaseSchemaGenerator):
def get_info(self):
info = {
'title': self.title,
'version': 'TODO',
}
if self.description is not None:
info['description'] = self.description
return info
def get_paths(self, request=None):
result = {}
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
operation = view.schema.get_operation(path, method)
subpath = '/' + path[len(prefix):]
result.setdefault(subpath, {})
result[subpath][method.lower()] = operation
return result
def get_schema(self, request=None, public=False):
"""
Generate a OpenAPI schema.
"""
self._initialise_endpoints()
paths = self.get_paths(None if public else request)
if not paths:
return None
schema = {
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
}
return schema

View File

@ -179,20 +179,6 @@ class ViewInspector(object):
def view(self): def view(self):
self._view = None self._view = None
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
raise NotImplementedError(".get_link() must be overridden.")
class AutoSchema(ViewInspector): class AutoSchema(ViewInspector):
""" """
@ -213,6 +199,17 @@ class AutoSchema(ViewInspector):
self._manual_fields = manual_fields self._manual_fields = manual_fields
def get_link(self, path, method, base_url): def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
fields = self.get_path_fields(path, method) fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method) fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method) fields += self.get_pagination_fields(path, method)
@ -509,3 +506,246 @@ class DefaultSchema(ViewInspector):
inspector = inspector_class() inspector = inspector_class()
inspector.view = instance inspector.view = instance
return inspector return inspector
class OpenAPIAutoSchema(ViewInspector):
content_types = ['application/json']
def get_operation(self, path, method):
operation = {}
parameters = []
parameters += self._get_path_parameters(path, method)
parameters += self._get_pagination_parameters(path, method)
parameters += self._get_filter_parameters(path, method)
operation['parameters'] = parameters
request_body = self._get_request_body(path, method)
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
return operation
def _get_path_parameters(self, path, method):
"""
Return a list of parameters from templated path variables.
"""
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
parameters = []
for variable in uritemplate.variables(path):
description = ''
if model is not None: # TODO: test this.
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
parameter = {
"name": variable,
"in": "path",
"required": True,
"description": description,
'schema': {
'type': 'string', # TODO: integer, pattern, ...
},
}
parameters.append(parameter)
return parameters
def _get_filter_parameters(self, path, method):
if not self._allows_filters(path, method):
return []
parameters = []
for filter_backend in self.view.filter_backends:
parameters += filter_backend().get_schema_operation_parameters(self.view)
return parameters
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def _get_pagination_parameters(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_operation_parameters(view)
def _map_field(self, field):
# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self._map_serializer(field.child)
}
if isinstance(field, serializers.Serializer):
return {
'type': 'object',
'properties': self._map_serializer(field)
}
# Related fields.
if isinstance(field, serializers.ManyRelatedField):
return {
'type': 'array',
'items': self._map_field(field.child_relation)
}
if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
return {'type': 'integer'}
# ChoiceFields (single and multiple).
# Q:
# - Is 'type' required?
# - can we determine the TYPE of a choicefield?
if isinstance(field, serializers.MultipleChoiceField):
return {
'type': 'array',
'items': {
'enum': list(field.choices)
},
}
if isinstance(field, serializers.ChoiceField):
return {
'enum': list(field.choices),
}
# ListField.
if isinstance(field, serializers.ListField):
return {
'type': 'array',
}
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
serializers.DecimalField: 'number',
serializers.FloatField: 'number',
serializers.IntegerField: 'integer',
serializers.DateField: 'date',
serializers.DateTimeField: 'date-time',
serializers.JSONField: 'object',
serializers.DictField: 'object',
}
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
def _map_serializer(self, serializer):
# Assuming we have a valid serializer instance.
# TODO:
# - field is Nested or List serializer.
# - Handle read_only/write_only for request/response differences.
# - could do this with readOnly/writeOnly and then filter dict.
required = []
properties = {}
for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField):
continue
if field.required:
required.append(field.field_name)
schema = self._map_field(field)
if field.read_only:
schema['readOnly'] = True
if field.write_only:
schema['writeOnly'] = True
if field.allow_null:
schema['nullable'] = True
properties[field.field_name] = schema
return {
'required': required,
'properties': properties,
}
def _get_request_body(self, path, method):
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return {}
if not hasattr(view, 'get_serializer'):
return {}
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if not isinstance(serializer, serializers.Serializer):
return {}
content = self._map_serializer(serializer)
# No required fields for PATCH
if method == 'PATCH':
del content['required']
# No read_only fields for request.
for name, schema in content['properties'].items():
if 'readOnly' in schema:
del content['properties']['name']
return {
'content': {ct: content for ct in self.content_types}
}
def _get_responses(self, path, method):
# TODO: Handle multiple codes.
content = {}
view = self.view
if hasattr(view, 'get_serializer'):
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.Serializer):
content = self._map_serializer(serializer)
# No write_only fields for response.
for name, schema in content['properties'].items():
if 'writeOnly' in schema:
del content['properties']['name']
return {
'200': {
'content': {ct: content for ct in self.content_types}
}
}

View File

@ -56,7 +56,7 @@ DEFAULTS = {
'DEFAULT_FILTER_BACKENDS': (), 'DEFAULT_FILTER_BACKENDS': (),
# Schema # Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.inspectors.OpenAPIAutoSchema',
# Throttling # Throttling
'DEFAULT_THROTTLE_RATES': { 'DEFAULT_THROTTLE_RATES': {

View File

View File

@ -24,7 +24,8 @@ from rest_framework.utils import formatting
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework.viewsets import GenericViewSet, ModelViewSet
from .models import BasicModel, ForeignKeySource, ManyToManySource from . import views
from ..models import BasicModel, ForeignKeySource, ManyToManySource
factory = APIRequestFactory() factory = APIRequestFactory()
@ -148,7 +149,7 @@ urlpatterns = [
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(ROOT_URLCONF='tests.test_schemas') @override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestRouterGeneratedSchema(TestCase): class TestRouterGeneratedSchema(TestCase):
def test_anonymous_request(self): def test_anonymous_request(self):
client = APIClient() client = APIClient()
@ -382,30 +383,14 @@ class MethodLimitedViewSet(ExampleViewSet):
http_method_names = ['get', 'head', 'options'] http_method_names = ['get', 'head', 'options']
class ExampleListView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class ExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGenerator(TestCase): class TestSchemaGenerator(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
url(r'^example/?$', ExampleListView.as_view()), url(r'^example/?$', views.ExampleListView.as_view()),
url(r'^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()), url(r'^example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()), url(r'^example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
] ]
def test_schema_for_regular_views(self): def test_schema_for_regular_views(self):
@ -453,12 +438,13 @@ class TestSchemaGenerator(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@unittest.skipUnless(path, 'needs Django 2') @unittest.skipUnless(path, 'needs Django 2')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorDjango2(TestCase): class TestSchemaGeneratorDjango2(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
path('example/', ExampleListView.as_view()), path('example/', views.ExampleListView.as_view()),
path('example/<int:pk>/', ExampleDetailView.as_view()), path('example/<int:pk>/', views.ExampleDetailView.as_view()),
path('example/<int:pk>/sub/', ExampleDetailView.as_view()), path('example/<int:pk>/sub/', views.ExampleDetailView.as_view()),
] ]
def test_schema_for_regular_views(self): def test_schema_for_regular_views(self):
@ -505,12 +491,13 @@ class TestSchemaGeneratorDjango2(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorNotAtRoot(TestCase): class TestSchemaGeneratorNotAtRoot(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
url(r'^api/v1/example/?$', ExampleListView.as_view()), url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()), url(r'^api/v1/example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()), url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
] ]
def test_schema_for_regular_views(self): def test_schema_for_regular_views(self):
@ -558,6 +545,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
def setUp(self): def setUp(self):
router = DefaultRouter() router = DefaultRouter()
@ -622,13 +610,14 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithRestrictedViewSets(TestCase): class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
def setUp(self): def setUp(self):
router = DefaultRouter() router = DefaultRouter()
router.register('example1', Http404ExampleViewSet, basename='example1') router.register('example1', Http404ExampleViewSet, basename='example1')
router.register('example2', PermissionDeniedExampleViewSet, basename='example2') router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
self.patterns = [ self.patterns = [
url('^example/?$', ExampleListView.as_view()), url('^example/?$', views.ExampleListView.as_view()),
url(r'^', include(router.urls)) url(r'^', include(router.urls))
] ]
@ -668,6 +657,7 @@ class ForeignKeySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithForeignKey(TestCase): class TestSchemaGeneratorWithForeignKey(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
@ -713,6 +703,7 @@ class ManyToManySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithManyToMany(TestCase): class TestSchemaGeneratorWithManyToMany(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
@ -747,6 +738,7 @@ class TestSchemaGeneratorWithManyToMany(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class Test4605Regression(TestCase): class Test4605Regression(TestCase):
def test_4605_regression(self): def test_4605_regression(self):
generator = SchemaGenerator() generator = SchemaGenerator()
@ -762,6 +754,7 @@ class CustomViewInspector(AutoSchema):
pass pass
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestAutoSchema(TestCase): class TestAutoSchema(TestCase):
def test_apiview_schema_descriptor(self): def test_apiview_schema_descriptor(self):
@ -777,7 +770,7 @@ class TestAutoSchema(TestCase):
assert isinstance(view.schema, CustomViewInspector) assert isinstance(view.schema, CustomViewInspector)
def test_set_custom_inspector_class_via_settings(self): def test_set_custom_inspector_class_via_settings(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}):
view = APIView() view = APIView()
assert isinstance(view.schema, CustomViewInspector) assert isinstance(view.schema, CustomViewInspector)
@ -971,6 +964,7 @@ class TestAutoSchema(TestCase):
self.assertEqual(field_to_schema(case[0]), case[1]) self.assertEqual(field_to_schema(case[0]), case[1])
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_docstring_is_not_stripped_by_get_description(): def test_docstring_is_not_stripped_by_get_description():
class ExampleDocstringAPIView(APIView): class ExampleDocstringAPIView(APIView):
""" """
@ -1014,20 +1008,19 @@ class ExcludedAPIView(APIView):
pass pass
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class SchemaGenerationExclusionTests(TestCase): class SchemaGenerationExclusionTests(TestCase):
def setUp(self): def setUp(self):
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass
self.patterns = [ self.patterns = [
url('^excluded-cbv/$', ExcludedAPIView.as_view()), url('^excluded-cbv/$', ExcludedAPIView.as_view()),
url('^excluded-fbv/$', excluded_fbv), url('^excluded-fbv/$', excluded_fbv),
@ -1078,11 +1071,6 @@ class SchemaGenerationExclusionTests(TestCase):
assert should_include == expected assert should_include == expected
@api_view(["GET"])
def simple_fbv(request):
pass
class BasicModelSerializer(serializers.ModelSerializer): class BasicModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = BasicModel model = BasicModel
@ -1118,11 +1106,16 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestURLNamingCollisions(TestCase): class TestURLNamingCollisions(TestCase):
""" """
Ref: https://github.com/encode/django-rest-framework/issues/4704 Ref: https://github.com/encode/django-rest-framework/issues/4704
""" """
def test_manually_routing_nested_routes(self): def test_manually_routing_nested_routes(self):
@api_view(["GET"])
def simple_fbv(request):
pass
patterns = [ patterns = [
url(r'^test', simple_fbv), url(r'^test', simple_fbv),
url(r'^test/list/', simple_fbv), url(r'^test/list/', simple_fbv),
@ -1228,6 +1221,10 @@ class TestURLNamingCollisions(TestCase):
def test_url_under_same_key_not_replaced_another(self): def test_url_under_same_key_not_replaced_another(self):
@api_view(["GET"])
def simple_fbv(request):
pass
patterns = [ patterns = [
url(r'^test/list/', simple_fbv), url(r'^test/list/', simple_fbv),
url(r'^test/(?P<pk>\d+)/list/', simple_fbv), url(r'^test/(?P<pk>\d+)/list/', simple_fbv),
@ -1303,7 +1300,8 @@ def test_head_and_options_methods_are_excluded():
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestAutoSchemaAllowsFilters(object): @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestAutoSchemaAllowsFilters(TestCase):
class MockAPIView(APIView): class MockAPIView(APIView):
filter_backends = [filters.OrderingFilter] filter_backends = [filters.OrderingFilter]

View File

@ -0,0 +1,118 @@
from django.conf.urls import url
from django.test import RequestFactory, TestCase, override_settings
from rest_framework import filters, pagination
from rest_framework.request import Request
from rest_framework.schemas.generators import OpenAPISchemaGenerator
from rest_framework.schemas.inspectors import OpenAPIAutoSchema
from . import views
def create_request(path):
factory = RequestFactory()
request = Request(factory.get(path))
return request
def create_view(view_cls, method, request):
generator = OpenAPISchemaGenerator()
view = generator.create_view(view_cls.as_view(), method, request)
return view
class TestBasics(TestCase):
def dummy_view(request):
pass
def test_filters(self):
classes = [filters.SearchFilter, filters.OrderingFilter]
for c in classes:
f = c()
assert f.get_schema_operation_parameters(self.dummy_view)
def test_pagination(self):
classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination]
for c in classes:
f = c()
assert f.get_schema_operation_parameters(self.dummy_view)
class TestOperationIntrospection(TestCase):
def test_path_without_parameters(self):
path = '/example/'
method = 'GET'
view = create_view(
views.ExampleListView,
method,
create_request(path)
)
inspector = OpenAPIAutoSchema()
inspector.view = view
operation = inspector.get_operation(path, method)
assert operation == {
'parameters': [],
'responses': {'200': {'content': {'application/json': {}}}},
}
def test_path_with_id_parameter(self):
path = '/example/{id}/'
method = 'GET'
view = create_view(
views.ExampleDetailView,
method,
create_request(path)
)
inspector = OpenAPIAutoSchema()
inspector.view = view
parameters = inspector._get_path_parameters(path, method)
assert parameters == [{
'description': '',
'in': 'path',
'name': 'id',
'required': True,
'schema': {
'type': 'string',
},
}]
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.inspectors.OpenAPIAutoSchema'})
class TestGenerator(TestCase):
def test_override_settings(self):
assert isinstance(views.ExampleListView.schema, OpenAPIAutoSchema)
def test_paths_construction(self):
"""Construction of the `paths` key."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = OpenAPISchemaGenerator(patterns=patterns)
generator._initialise_endpoints()
paths = generator.get_paths()
assert '/example/' in paths
example_operations = paths['/example/']
assert len(example_operations) == 2
assert 'get' in example_operations
assert 'post' in example_operations
def test_schema_construction(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = OpenAPISchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert 'openapi' in schema
assert 'paths' in schema

19
tests/schemas/views.py Normal file
View File

@ -0,0 +1,19 @@
from rest_framework import permissions
from rest_framework.views import APIView
class ExampleListView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class ExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass

View File

@ -7,8 +7,8 @@ from django.test import TestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from django.utils import six from django.utils import six
from rest_framework.compat import coreapi from rest_framework.compat import coreapi, yaml
from rest_framework.utils import formatting, json from rest_framework.utils import json
from rest_framework.views import APIView from rest_framework.views import APIView
@ -31,58 +31,21 @@ class GenerateSchemaTests(TestCase):
self.out = six.StringIO() self.out = six.StringIO()
@pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.') @pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.')
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
def test_renders_default_schema_with_custom_title_url_and_description(self): def test_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info:
description: Sample description
title: SampleAPI
version: ''
openapi: 3.0.0
paths:
/:
get:
operationId: list
servers:
- url: http://api.sample.com/
"""
call_command('generateschema', call_command('generateschema',
'--title=SampleAPI', '--title=SampleAPI',
'--url=http://api.sample.com', '--url=http://api.sample.com',
'--description=Sample description', '--description=Sample description',
stdout=self.out) stdout=self.out)
# Check valid YAML was output.
self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) schema = yaml.load(self.out.getvalue())
assert schema['openapi'] == '3.0.2'
def test_renders_openapi_json_schema(self): def test_renders_openapi_json_schema(self):
expected_out = {
"openapi": "3.0.0",
"info": {
"version": "",
"title": "",
"description": ""
},
"servers": [
{
"url": ""
}
],
"paths": {
"/": {
"get": {
"operationId": "list"
}
}
}
}
call_command('generateschema', call_command('generateschema',
'--format=openapi-json', '--format=openapi-json',
stdout=self.out) stdout=self.out)
# Check valid JSON was output.
out_json = json.loads(self.out.getvalue()) out_json = json.loads(self.out.getvalue())
assert out_json['openapi'] == '3.0.2'
self.assertDictEqual(out_json, expected_out)
def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema',
'--format=corejson',
stdout=self.out)
self.assertIn(expected_out, self.out.getvalue())