Moved docstring operation description to ViewInspector

This commit is contained in:
Yann Savary 2019-09-05 13:54:42 +02:00
parent 9e7fa1a71e
commit c147769c1b
3 changed files with 48 additions and 95 deletions

View File

@ -1,26 +1,18 @@
import re
import warnings import warnings
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from urllib import parse from urllib import parse
from django.db import models from django.db import models
from django.utils.encoding import force_str, smart_text from django.utils.encoding import force_str
from rest_framework import exceptions, serializers from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .generators import BaseSchemaGenerator from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
# Used in _get_description_section()
# TODO: ???: move up to base.
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
# Generator #
def common_path(paths): def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths] split_paths = [path.strip('/').split('/') for path in paths]
@ -397,44 +389,6 @@ class AutoSchema(ViewInspector):
description=description description=description
) )
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method): def get_path_fields(self, path, method):
""" """
Return a list of `coreapi.Field` instances corresponding to any Return a list of `coreapi.Field` instances corresponding to any

View File

@ -3,9 +3,13 @@ inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview. See schemas.__init__.py for package overview.
""" """
import re
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
from django.utils.encoding import smart_text
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
class ViewInspector: class ViewInspector:
@ -15,6 +19,9 @@ class ViewInspector:
Provide subclass for per-view schema generation Provide subclass for per-view schema generation
""" """
# Used in _get_description_section()
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def __init__(self): def __init__(self):
self.instance_schemas = WeakKeyDictionary() self.instance_schemas = WeakKeyDictionary()
@ -62,6 +69,45 @@ class ViewInspector:
def view(self): def view(self):
self._view = None self._view = None
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()),
view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if self.header_regex.match(line):
current_section, separator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
class DefaultSchema(ViewInspector): class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""

View File

@ -1,4 +1,3 @@
import re
import warnings import warnings
from urllib.parse import urljoin from urllib.parse import urljoin
@ -7,24 +6,16 @@ from django.core.validators import (
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
) )
from django.db import models from django.db import models
from django.utils.encoding import force_str, smart_text from django.utils.encoding import force_str
from rest_framework import exceptions, serializers from rest_framework import exceptions, serializers
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .generators import BaseSchemaGenerator from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
# Used in _get_description_section()
# TODO: ???: move up to base.
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
# Generator
class SchemaGenerator(BaseSchemaGenerator): class SchemaGenerator(BaseSchemaGenerator):
@ -94,44 +85,6 @@ class AutoSchema(ViewInspector):
'delete': 'Destroy', 'delete': 'Destroy',
} }
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_operation(self, path, method): def get_operation(self, path, method):
operation = {} operation = {}