mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-04 12:30:11 +03:00
Split generators, inspectors, views.
This commit is contained in:
parent
9fa8a05b34
commit
18575c9f5f
|
@ -1,135 +1,8 @@
|
|||
import re
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.db import models
|
||||
from django.http import Http404
|
||||
from django.utils import six
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework import exceptions, renderers, serializers
|
||||
from rest_framework.compat import (
|
||||
RegexURLPattern, RegexURLResolver, coreapi, coreschema, uritemplate,
|
||||
urlparse
|
||||
)
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
from rest_framework.utils.model_meta import _get_pk
|
||||
from rest_framework.views import APIView
|
||||
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_text(field.label) if field.label else ''
|
||||
description = force_text(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties=OrderedDict([
|
||||
(key, field_to_schema(value))
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
]),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.String(),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices.keys())),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices.keys()),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
return coreschema.String(title=title, description=description)
|
||||
|
||||
|
||||
def common_path(paths):
|
||||
split_paths = [path.strip('/').split('/') for path in paths]
|
||||
s1 = min(split_paths)
|
||||
s2 = max(split_paths)
|
||||
common = s1
|
||||
for i, c in enumerate(s1):
|
||||
if c != s2[i]:
|
||||
common = s1[:i]
|
||||
break
|
||||
return '/' + '/'.join(common)
|
||||
|
||||
|
||||
def get_pk_name(model):
|
||||
meta = model._meta.concrete_model._meta
|
||||
return _get_pk(meta).name
|
||||
|
||||
|
||||
def is_api_view(callback):
|
||||
"""
|
||||
Return `True` if the given view callback is a REST framework view/viewset.
|
||||
"""
|
||||
cls = getattr(callback, 'cls', None)
|
||||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
{'a': {'b': {'c': 123}}}
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = {}
|
||||
target = target[key]
|
||||
target[keys[-1]] = value
|
||||
|
||||
|
||||
def is_custom_action(action):
|
||||
return action not in set([
|
||||
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
|
||||
])
|
||||
# The API we expose
|
||||
# from .views import get_schema_view
|
||||
|
||||
|
||||
# Shared function. TODO: move to utils.
|
||||
def is_list_view(path, method, view):
|
||||
"""
|
||||
Return True if the given path/method appears to represent a list view.
|
||||
|
@ -144,677 +17,3 @@ def is_list_view(path, method, view):
|
|||
if path_components and '{' in path_components[-1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def endpoint_ordering(endpoint):
|
||||
path, method, callback = endpoint
|
||||
method_priority = {
|
||||
'GET': 0,
|
||||
'POST': 1,
|
||||
'PUT': 2,
|
||||
'PATCH': 3,
|
||||
'DELETE': 4
|
||||
}.get(method, 5)
|
||||
return (path, method_priority)
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
||||
|
||||
|
||||
class EndpointInspector(object):
|
||||
"""
|
||||
A class to determine the available API endpoints that a project exposes.
|
||||
"""
|
||||
def __init__(self, patterns=None, urlconf=None):
|
||||
if patterns is None:
|
||||
if urlconf is None:
|
||||
# Use the default Django URL conf
|
||||
urlconf = settings.ROOT_URLCONF
|
||||
|
||||
# Load the given URLconf module
|
||||
if isinstance(urlconf, six.string_types):
|
||||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
patterns = urls.urlpatterns
|
||||
|
||||
self.patterns = patterns
|
||||
|
||||
def get_api_endpoints(self, patterns=None, prefix=''):
|
||||
"""
|
||||
Return a list of all available API endpoints by inspecting the URL conf.
|
||||
"""
|
||||
if patterns is None:
|
||||
patterns = self.patterns
|
||||
|
||||
api_endpoints = []
|
||||
|
||||
for pattern in patterns:
|
||||
path_regex = prefix + pattern.regex.pattern
|
||||
if isinstance(pattern, RegexURLPattern):
|
||||
path = self.get_path_from_regex(path_regex)
|
||||
callback = pattern.callback
|
||||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
endpoint = (path, method, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, RegexURLResolver):
|
||||
nested_endpoints = self.get_api_endpoints(
|
||||
patterns=pattern.url_patterns,
|
||||
prefix=path_regex
|
||||
)
|
||||
api_endpoints.extend(nested_endpoints)
|
||||
|
||||
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
|
||||
|
||||
return api_endpoints
|
||||
|
||||
def get_path_from_regex(self, path_regex):
|
||||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
path = simplify_regex(path_regex)
|
||||
path = path.replace('<', '{').replace('>', '}')
|
||||
return path
|
||||
|
||||
def should_include_endpoint(self, path, callback):
|
||||
"""
|
||||
Return `True` if the given endpoint should be included.
|
||||
"""
|
||||
if not is_api_view(callback):
|
||||
return False # Ignore anything except REST framework views.
|
||||
|
||||
if path.endswith('.{format}') or path.endswith('.{format}/'):
|
||||
return False # Ignore .json style URLs.
|
||||
|
||||
return True
|
||||
|
||||
def get_allowed_methods(self, callback):
|
||||
"""
|
||||
Return a list of the valid HTTP methods for this endpoint.
|
||||
"""
|
||||
if hasattr(callback, 'actions'):
|
||||
actions = set(callback.actions.keys())
|
||||
http_method_names = set(callback.cls.http_method_names)
|
||||
return [method.upper() for method in actions & http_method_names]
|
||||
|
||||
return [
|
||||
method for method in
|
||||
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
|
||||
]
|
||||
|
||||
|
||||
class ViewInspector(object):
|
||||
"""
|
||||
Descriptor class on APIView.
|
||||
|
||||
Provide subclass for per-view schema generation
|
||||
"""
|
||||
def __get__(self, instance, owner):
|
||||
"""
|
||||
Enables `ViewInspector` as a Python _Descriptor_.
|
||||
|
||||
This is how `view.schema` knows about `view`.
|
||||
|
||||
`__get__` is called when the descriptor is accessed on the owner.
|
||||
(That will be when view.schema is called in our case.)
|
||||
|
||||
`owner` is always the owner class. (An APIView, or subclass for us.)
|
||||
`instance` is the view instance or `None` if accessed from the class,
|
||||
rather than an instance.
|
||||
|
||||
See: https://docs.python.org/3/howto/descriptor.html for info on
|
||||
descriptor usage.
|
||||
"""
|
||||
self.view = instance
|
||||
return self
|
||||
|
||||
@property
|
||||
def view(self):
|
||||
"""View property."""
|
||||
assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)"
|
||||
return self._view
|
||||
|
||||
@view.setter
|
||||
def view(self, value):
|
||||
self._view = value
|
||||
|
||||
@view.deleter
|
||||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
raise NotImplementedError(".get_link() must be overridden.")
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
Default inspector for APIView
|
||||
|
||||
Responsible for per-view instrospection and schema generation.
|
||||
"""
|
||||
def __init__(self, manual_fields=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `manual_fields`: list of `coreapi.Field` instances that
|
||||
will be added to auto-generated fields, overwriting on `Field.name`
|
||||
"""
|
||||
|
||||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
fields += self.get_filter_fields(path, method)
|
||||
|
||||
if self._manual_fields is not None:
|
||||
by_name = {f.name: f for f in fields}
|
||||
for f in self._manual_fields:
|
||||
by_name[f.name] = f
|
||||
fields = list(by_name.values())
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method)
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=urlparse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_description(self, path, method):
|
||||
"""
|
||||
Determine a link description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return formatting.dedent(smart_text(method_docstring))
|
||||
|
||||
description = view.get_view_description()
|
||||
lines = [line.strip() for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if header_regex.match(line):
|
||||
current_section, seperator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
|
||||
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
header = getattr(view, 'action', method.lower())
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in coerce_method_names:
|
||||
if coerce_method_names[header] in sections:
|
||||
return sections[coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
def get_path_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
view = self.view
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_text(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
serializer = view.get_serializer()
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def get_filter_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
if not getattr(view, 'filter_backends', None):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(view)
|
||||
return fields
|
||||
|
||||
def get_encoding(self, path, method):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = set((
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
))
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
# Note: With `AutoSchema` defined we attach it to APIView.
|
||||
# * We do this here to avoid the dependency cycle from SchemaView needing
|
||||
# APIView (below).
|
||||
# * This requires importing _something_ from `rest_framework.schemas` or
|
||||
# `rest_framework.documentation` before `APIView.schema will be available.
|
||||
# * ???: When would `APIView.schema` be needed and that NOT be the case?
|
||||
# * The alternative is to import AutoSchema to `views`, make `schemas` a
|
||||
# package, and move SchemaView to `schema.views`, importing APIView there.
|
||||
APIView.schema = AutoSchema()
|
||||
|
||||
|
||||
class ManualSchema(ViewInspector):
|
||||
"""
|
||||
Overrides get_link to return manually specified schema.
|
||||
"""
|
||||
def __init__(self, link):
|
||||
assert isinstance(link, coreapi.Link)
|
||||
self._link = link
|
||||
|
||||
def get_link(self, *args):
|
||||
return self._link
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
endpoint_inspector_cls = EndpointInspector
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
self.urlconf = urlconf
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
links = OrderedDict()
|
||||
|
||||
# Generate (path, method, view) given (path, method, callback).
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
view = self.create_view(callback, method, request)
|
||||
if getattr(view, 'exclude_from_schema', False):
|
||||
continue
|
||||
path = self.coerce_path(path, method, view)
|
||||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
return links
|
||||
|
||||
# Methods used when we generate a view instance from the raw callback...
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
Given a list of all paths, return the common prefix which should be
|
||||
discounted when generating a schema structure.
|
||||
|
||||
This will be the longest common string that does not include that last
|
||||
component of the URL, or the last component before a path parameter.
|
||||
|
||||
For example:
|
||||
|
||||
/api/v1/users/
|
||||
/api/v1/users/{pk}/
|
||||
|
||||
The path prefix is '/api/v1/'
|
||||
"""
|
||||
prefixes = []
|
||||
for path in paths:
|
||||
components = path.strip('/').split('/')
|
||||
initial_components = []
|
||||
for component in components:
|
||||
if '{' in component:
|
||||
break
|
||||
initial_components.append(component)
|
||||
prefix = '/'.join(initial_components[:-1])
|
||||
if not prefix:
|
||||
# We can just break early in the case that there's at least
|
||||
# one URL that doesn't have a path prefix.
|
||||
return '/'
|
||||
prefixes.append('/' + prefix + '/')
|
||||
return common_path(prefixes)
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
"""
|
||||
if view.request is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except (exceptions.APIException, Http404, PermissionDenied):
|
||||
return False
|
||||
return True
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
# Method for generating the link layout....
|
||||
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "star") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
if is_list_view(subpath, method, view):
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.default_mapping[method.lower()]
|
||||
|
||||
named_path_components = [
|
||||
component for component
|
||||
in subpath.strip('/').split('/')
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if is_custom_action(action):
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
if len(view.action_map) > 1:
|
||||
action = self.default_mapping[method.lower()]
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
return named_path_components + [action]
|
||||
else:
|
||||
return named_path_components[:-1] + [action]
|
||||
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
||||
|
||||
class SchemaView(APIView):
|
||||
_ignore_model_permissions = True
|
||||
exclude_from_schema = True
|
||||
renderer_classes = None
|
||||
schema_generator = None
|
||||
public = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SchemaView, self).__init__(*args, **kwargs)
|
||||
if self.renderer_classes is None:
|
||||
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
|
||||
self.renderer_classes = [
|
||||
renderers.CoreJSONRenderer,
|
||||
renderers.BrowsableAPIRenderer,
|
||||
]
|
||||
else:
|
||||
self.renderer_classes = [renderers.CoreJSONRenderer]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = self.schema_generator.get_schema(request, self.public)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
|
||||
def get_schema_view(
|
||||
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
|
||||
public=False, patterns=None, generator_class=SchemaGenerator):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
generator = generator_class(
|
||||
title=title, url=url, description=description,
|
||||
urlconf=urlconf, patterns=patterns,
|
||||
)
|
||||
return SchemaView.as_view(
|
||||
renderer_classes=renderer_classes,
|
||||
schema_generator=generator,
|
||||
public=public,
|
||||
)
|
||||
|
|
386
rest_framework/schemas/generators.py
Normal file
386
rest_framework/schemas/generators.py
Normal file
|
@ -0,0 +1,386 @@
|
|||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
from django.utils import six
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.compat import (
|
||||
RegexURLPattern, RegexURLResolver, coreapi, coreschema
|
||||
)
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.model_meta import _get_pk
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from . import is_list_view
|
||||
|
||||
|
||||
def common_path(paths):
|
||||
split_paths = [path.strip('/').split('/') for path in paths]
|
||||
s1 = min(split_paths)
|
||||
s2 = max(split_paths)
|
||||
common = s1
|
||||
for i, c in enumerate(s1):
|
||||
if c != s2[i]:
|
||||
common = s1[:i]
|
||||
break
|
||||
return '/' + '/'.join(common)
|
||||
|
||||
|
||||
def get_pk_name(model):
|
||||
meta = model._meta.concrete_model._meta
|
||||
return _get_pk(meta).name
|
||||
|
||||
|
||||
def is_api_view(callback):
|
||||
"""
|
||||
Return `True` if the given view callback is a REST framework view/viewset.
|
||||
"""
|
||||
cls = getattr(callback, 'cls', None)
|
||||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
{'a': {'b': {'c': 123}}}
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = {}
|
||||
target = target[key]
|
||||
target[keys[-1]] = value
|
||||
|
||||
|
||||
def is_custom_action(action):
|
||||
return action not in set([
|
||||
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
|
||||
])
|
||||
|
||||
|
||||
def endpoint_ordering(endpoint):
|
||||
path, method, callback = endpoint
|
||||
method_priority = {
|
||||
'GET': 0,
|
||||
'POST': 1,
|
||||
'PUT': 2,
|
||||
'PATCH': 3,
|
||||
'DELETE': 4
|
||||
}.get(method, 5)
|
||||
return (path, method_priority)
|
||||
|
||||
|
||||
class EndpointInspector(object):
|
||||
"""
|
||||
A class to determine the available API endpoints that a project exposes.
|
||||
"""
|
||||
def __init__(self, patterns=None, urlconf=None):
|
||||
if patterns is None:
|
||||
if urlconf is None:
|
||||
# Use the default Django URL conf
|
||||
urlconf = settings.ROOT_URLCONF
|
||||
|
||||
# Load the given URLconf module
|
||||
if isinstance(urlconf, six.string_types):
|
||||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
patterns = urls.urlpatterns
|
||||
|
||||
self.patterns = patterns
|
||||
|
||||
def get_api_endpoints(self, patterns=None, prefix=''):
|
||||
"""
|
||||
Return a list of all available API endpoints by inspecting the URL conf.
|
||||
"""
|
||||
if patterns is None:
|
||||
patterns = self.patterns
|
||||
|
||||
api_endpoints = []
|
||||
|
||||
for pattern in patterns:
|
||||
path_regex = prefix + pattern.regex.pattern
|
||||
if isinstance(pattern, RegexURLPattern):
|
||||
path = self.get_path_from_regex(path_regex)
|
||||
callback = pattern.callback
|
||||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
endpoint = (path, method, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, RegexURLResolver):
|
||||
nested_endpoints = self.get_api_endpoints(
|
||||
patterns=pattern.url_patterns,
|
||||
prefix=path_regex
|
||||
)
|
||||
api_endpoints.extend(nested_endpoints)
|
||||
|
||||
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
|
||||
|
||||
return api_endpoints
|
||||
|
||||
def get_path_from_regex(self, path_regex):
|
||||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
path = simplify_regex(path_regex)
|
||||
path = path.replace('<', '{').replace('>', '}')
|
||||
return path
|
||||
|
||||
def should_include_endpoint(self, path, callback):
|
||||
"""
|
||||
Return `True` if the given endpoint should be included.
|
||||
"""
|
||||
if not is_api_view(callback):
|
||||
return False # Ignore anything except REST framework views.
|
||||
|
||||
if path.endswith('.{format}') or path.endswith('.{format}/'):
|
||||
return False # Ignore .json style URLs.
|
||||
|
||||
return True
|
||||
|
||||
def get_allowed_methods(self, callback):
|
||||
"""
|
||||
Return a list of the valid HTTP methods for this endpoint.
|
||||
"""
|
||||
if hasattr(callback, 'actions'):
|
||||
actions = set(callback.actions.keys())
|
||||
http_method_names = set(callback.cls.http_method_names)
|
||||
return [method.upper() for method in actions & http_method_names]
|
||||
|
||||
return [
|
||||
method for method in
|
||||
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
|
||||
]
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
endpoint_inspector_cls = EndpointInspector
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
self.urlconf = urlconf
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
links = OrderedDict()
|
||||
|
||||
# Generate (path, method, view) given (path, method, callback).
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
view = self.create_view(callback, method, request)
|
||||
if getattr(view, 'exclude_from_schema', False):
|
||||
continue
|
||||
path = self.coerce_path(path, method, view)
|
||||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
return links
|
||||
|
||||
# Methods used when we generate a view instance from the raw callback...
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
Given a list of all paths, return the common prefix which should be
|
||||
discounted when generating a schema structure.
|
||||
|
||||
This will be the longest common string that does not include that last
|
||||
component of the URL, or the last component before a path parameter.
|
||||
|
||||
For example:
|
||||
|
||||
/api/v1/users/
|
||||
/api/v1/users/{pk}/
|
||||
|
||||
The path prefix is '/api/v1/'
|
||||
"""
|
||||
prefixes = []
|
||||
for path in paths:
|
||||
components = path.strip('/').split('/')
|
||||
initial_components = []
|
||||
for component in components:
|
||||
if '{' in component:
|
||||
break
|
||||
initial_components.append(component)
|
||||
prefix = '/'.join(initial_components[:-1])
|
||||
if not prefix:
|
||||
# We can just break early in the case that there's at least
|
||||
# one URL that doesn't have a path prefix.
|
||||
return '/'
|
||||
prefixes.append('/' + prefix + '/')
|
||||
return common_path(prefixes)
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
"""
|
||||
if view.request is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except (exceptions.APIException, Http404, PermissionDenied):
|
||||
return False
|
||||
return True
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
# Method for generating the link layout....
|
||||
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "star") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
if is_list_view(subpath, method, view):
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.default_mapping[method.lower()]
|
||||
|
||||
named_path_components = [
|
||||
component for component
|
||||
in subpath.strip('/').split('/')
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if is_custom_action(action):
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
if len(view.action_map) > 1:
|
||||
action = self.default_mapping[method.lower()]
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
return named_path_components + [action]
|
||||
else:
|
||||
return named_path_components[:-1] + [action]
|
||||
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
374
rest_framework/schemas/inspectors.py
Normal file
374
rest_framework/schemas/inspectors.py
Normal file
|
@ -0,0 +1,374 @@
|
|||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework import serializers
|
||||
from rest_framework.compat import coreapi, coreschema, uritemplate, urlparse
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
|
||||
from . import is_list_view
|
||||
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_text(field.label) if field.label else ''
|
||||
description = force_text(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties=OrderedDict([
|
||||
(key, field_to_schema(value))
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
]),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.String(),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices.keys())),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices.keys()),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
return coreschema.String(title=title, description=description)
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
||||
|
||||
|
||||
class ViewInspector(object):
|
||||
"""
|
||||
Descriptor class on APIView.
|
||||
|
||||
Provide subclass for per-view schema generation
|
||||
"""
|
||||
def __get__(self, instance, owner):
|
||||
"""
|
||||
Enables `ViewInspector` as a Python _Descriptor_.
|
||||
|
||||
This is how `view.schema` knows about `view`.
|
||||
|
||||
`__get__` is called when the descriptor is accessed on the owner.
|
||||
(That will be when view.schema is called in our case.)
|
||||
|
||||
`owner` is always the owner class. (An APIView, or subclass for us.)
|
||||
`instance` is the view instance or `None` if accessed from the class,
|
||||
rather than an instance.
|
||||
|
||||
See: https://docs.python.org/3/howto/descriptor.html for info on
|
||||
descriptor usage.
|
||||
"""
|
||||
self.view = instance
|
||||
return self
|
||||
|
||||
@property
|
||||
def view(self):
|
||||
"""View property."""
|
||||
assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)"
|
||||
return self._view
|
||||
|
||||
@view.setter
|
||||
def view(self, value):
|
||||
self._view = value
|
||||
|
||||
@view.deleter
|
||||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
raise NotImplementedError(".get_link() must be overridden.")
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
Default inspector for APIView
|
||||
|
||||
Responsible for per-view instrospection and schema generation.
|
||||
"""
|
||||
def __init__(self, manual_fields=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `manual_fields`: list of `coreapi.Field` instances that
|
||||
will be added to auto-generated fields, overwriting on `Field.name`
|
||||
"""
|
||||
|
||||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
fields += self.get_filter_fields(path, method)
|
||||
|
||||
if self._manual_fields is not None:
|
||||
by_name = {f.name: f for f in fields}
|
||||
for f in self._manual_fields:
|
||||
by_name[f.name] = f
|
||||
fields = list(by_name.values())
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method)
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=urlparse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_description(self, path, method):
|
||||
"""
|
||||
Determine a link description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return formatting.dedent(smart_text(method_docstring))
|
||||
|
||||
description = view.get_view_description()
|
||||
lines = [line.strip() for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if header_regex.match(line):
|
||||
current_section, seperator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
|
||||
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
header = getattr(view, 'action', method.lower())
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in coerce_method_names:
|
||||
if coerce_method_names[header] in sections:
|
||||
return sections[coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
def get_path_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
view = self.view
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_text(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
serializer = view.get_serializer()
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def get_filter_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
if not getattr(view, 'filter_backends', None):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(view)
|
||||
return fields
|
||||
|
||||
def get_encoding(self, path, method):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = set((
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
))
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ManualSchema(ViewInspector):
|
||||
"""
|
||||
Overrides get_link to return manually specified schema.
|
||||
"""
|
||||
def __init__(self, link):
|
||||
assert isinstance(link, coreapi.Link)
|
||||
self._link = link
|
||||
|
||||
def get_link(self, *args):
|
||||
return self._link
|
47
rest_framework/schemas/views.py
Normal file
47
rest_framework/schemas/views.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
from rest_framework import exceptions, renderers
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas.generators import SchemaGenerator
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class SchemaView(APIView):
|
||||
_ignore_model_permissions = True
|
||||
exclude_from_schema = True
|
||||
renderer_classes = None
|
||||
schema_generator = None
|
||||
public = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SchemaView, self).__init__(*args, **kwargs)
|
||||
if self.renderer_classes is None:
|
||||
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
|
||||
self.renderer_classes = [
|
||||
renderers.CoreJSONRenderer,
|
||||
renderers.BrowsableAPIRenderer,
|
||||
]
|
||||
else:
|
||||
self.renderer_classes = [renderers.CoreJSONRenderer]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = self.schema_generator.get_schema(request, self.public)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
|
||||
def get_schema_view(
|
||||
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
|
||||
public=False, patterns=None, generator_class=SchemaGenerator):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
generator = generator_class(
|
||||
title=title, url=url, description=description,
|
||||
urlconf=urlconf, patterns=patterns,
|
||||
)
|
||||
return SchemaView.as_view(
|
||||
renderer_classes=renderer_classes,
|
||||
schema_generator=generator,
|
||||
public=public,
|
||||
)
|
Loading…
Reference in New Issue
Block a user