OpenAPI: Ported docstring operation description from CoreAPI inspector. (#6898)

This commit is contained in:
Yann Savary 2019-11-06 21:52:02 +01:00 committed by Carlton Gibson
parent becb962160
commit 7c3477dcda
5 changed files with 100 additions and 61 deletions

View File

@ -1,26 +1,18 @@
import re
import warnings import warnings
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from urllib import parse from urllib import parse
from django.db import models from django.db import models
from django.utils.encoding import force_str, smart_text from django.utils.encoding import force_str
from rest_framework import exceptions, serializers from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .generators import BaseSchemaGenerator from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
# Used in _get_description_section()
# TODO: ???: move up to base.
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
# Generator #
def common_path(paths): def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths] split_paths = [path.strip('/').split('/') for path in paths]
@ -397,44 +389,6 @@ class AutoSchema(ViewInspector):
description=description description=description
) )
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method): def get_path_fields(self, path, method):
""" """
Return a list of `coreapi.Field` instances corresponding to any Return a list of `coreapi.Field` instances corresponding to any

View File

@ -3,9 +3,13 @@ inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview. See schemas.__init__.py for package overview.
""" """
import re
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
from django.utils.encoding import smart_text
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
class ViewInspector: class ViewInspector:
@ -15,6 +19,9 @@ class ViewInspector:
Provide subclass for per-view schema generation Provide subclass for per-view schema generation
""" """
# Used in _get_description_section()
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def __init__(self): def __init__(self):
self.instance_schemas = WeakKeyDictionary() self.instance_schemas = WeakKeyDictionary()
@ -62,6 +69,45 @@ class ViewInspector:
def view(self): def view(self):
self._view = None self._view = None
def get_description(self, path, method):
"""
Determine a path description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()),
view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if self.header_regex.match(line):
current_section, separator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
class DefaultSchema(ViewInspector): class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""

View File

@ -17,8 +17,6 @@ from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
# Generator
class SchemaGenerator(BaseSchemaGenerator): class SchemaGenerator(BaseSchemaGenerator):
@ -94,6 +92,7 @@ class AutoSchema(ViewInspector):
operation = {} operation = {}
operation['operationId'] = self._get_operation_id(path, method) operation['operationId'] = self._get_operation_id(path, method)
operation['description'] = self.get_description(path, method)
parameters = [] parameters = []
parameters += self._get_path_parameters(path, method) parameters += self._get_path_parameters(path, method)

View File

@ -77,7 +77,7 @@ class TestOperationIntrospection(TestCase):
method = 'GET' method = 'GET'
view = create_view( view = create_view(
views.ExampleListView, views.DocStringExampleListView,
method, method,
create_request(path) create_request(path)
) )
@ -86,7 +86,8 @@ class TestOperationIntrospection(TestCase):
operation = inspector.get_operation(path, method) operation = inspector.get_operation(path, method)
assert operation == { assert operation == {
'operationId': 'listExamples', 'operationId': 'listDocStringExamples',
'description': 'A description of my GET operation.',
'parameters': [], 'parameters': [],
'responses': { 'responses': {
'200': { '200': {
@ -108,15 +109,18 @@ class TestOperationIntrospection(TestCase):
method = 'GET' method = 'GET'
view = create_view( view = create_view(
views.ExampleDetailView, views.DocStringExampleDetailView,
method, method,
create_request(path) create_request(path)
) )
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
parameters = inspector._get_path_parameters(path, method) operation = inspector.get_operation(path, method)
assert parameters == [{ assert operation == {
'operationId': 'RetrieveDocStringExampleDetail',
'description': 'A description of my GET operation.',
'parameters': [{
'description': '', 'description': '',
'in': 'path', 'in': 'path',
'name': 'id', 'name': 'id',
@ -124,7 +128,19 @@ class TestOperationIntrospection(TestCase):
'schema': { 'schema': {
'type': 'string', 'type': 'string',
}, },
}] }],
'responses': {
'200': {
'description': '',
'content': {
'application/json': {
'schema': {
},
},
},
},
},
}
def test_request_body(self): def test_request_body(self):
path = '/' path = '/'

View File

@ -29,6 +29,30 @@ class ExampleDetailView(APIView):
pass pass
class DocStringExampleListView(APIView):
"""
get: A description of my GET operation.
post: A description of my POST operation.
"""
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class DocStringExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
"""
A description of my GET operation.
"""
pass
# Generics. # Generics.
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
date = serializers.DateField() date = serializers.DateField()