Clean up schema generation (#4527)

This commit is contained in:
Tom Christie 2016-09-30 13:29:01 +01:00 committed by GitHub
parent 49ce3d61b7
commit c3a9538ad9
3 changed files with 212 additions and 149 deletions

View File

@ -32,85 +32,66 @@ def is_api_view(callback):
return (cls is not None) and issubclass(cls, APIView)
class SchemaGenerator(object):
default_mapping = {
'get': 'read',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
known_actions = (
'create', 'read', 'retrieve', 'list',
'update', 'partial_update', 'destroy'
)
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
def __init__(self, title=None, url=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
{'a': {'b': {'c': 123}}}
"""
for key in keys[:-1]:
if key not in target:
target[key] = {}
target = target[key]
target[keys[-1]] = value
if patterns is None and urlconf is not None:
if isinstance(urlconf, six.string_types):
urls = import_module(urlconf)
def is_custom_action(action):
return action not in set([
'read', 'retrieve', 'list',
'create', 'update', 'partial_update', 'delete', 'destroy'
])
class EndpointInspector(object):
"""
A class to determine the available API endpoints that a project exposes.
"""
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
# Use the default Django URL conf
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns
else:
urls = urlconf
self.patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
self.patterns = urls.urlpatterns
else:
self.patterns = patterns
# Load the given URLconf module
if isinstance(urlconf, six.string_types):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
if url and not url.endswith('/'):
url += '/'
self.patterns = patterns
self.title = title
self.url = url
self.endpoints = None
def get_schema(self, request=None):
if self.endpoints is None:
self.endpoints = self.get_api_endpoints(self.patterns)
links = []
for path, method, category, action, callback in self.endpoints:
view = self.setup_view(callback, method, request)
if self.should_include_link(path, method, callback, view):
link = self.get_link(path, method, callback, view)
links.append((category, action, link))
if not links:
return None
# Generate the schema content structure, eg:
# {'users': {'list': Link()}}
content = {}
for category, action, link in links:
if category is None:
content[action] = link
elif category in content:
content[category][action] = link
else:
content[category] = {action: link}
# Return the schema document.
return coreapi.Document(title=self.title, content=content, url=self.url)
def get_api_endpoints(self, patterns, prefix=''):
def get_api_endpoints(self, patterns=None, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + pattern.regex.pattern
if isinstance(pattern, RegexURLPattern):
path = self.get_path(path_regex)
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
action = self.get_action(path, method, callback)
category = self.get_category(path, method, callback, action)
endpoint = (path, method, category, action, callback)
endpoint = (path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
@ -122,7 +103,7 @@ class SchemaGenerator(object):
return api_endpoints
def get_path(self, path_regex):
def get_path_from_regex(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
@ -157,47 +138,60 @@ class SchemaGenerator(object):
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
]
def get_action(self, path, method, callback):
"""
Return a descriptive action string for the endpoint, eg. 'list'.
"""
actions = getattr(callback, 'actions', self.default_mapping)
return actions[method.lower()]
def get_category(self, path, method, callback, action):
class SchemaGenerator(object):
# Map methods onto 'actions' that are the names used in the link layout.
default_mapping = {
'get': 'read',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
# Coerce the following viewset actions into different names.
coerce_actions = {
'retrieve': 'read',
'destroy': 'delete'
}
endpoint_inspector_cls = EndpointInspector
def __init__(self, title=None, url=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
if url and not url.endswith('/'):
url += '/'
self.endpoint_inspector = self.endpoint_inspector_cls(patterns, urlconf)
self.title = title
self.url = url
self.endpoints = None
def get_schema(self, request=None):
"""
Return a descriptive category string for the endpoint, eg. 'users'.
Examples of category/action pairs that should be generated for various
endpoints:
/users/ [users][list], [users][create]
/users/{pk}/ [users][read], [users][update], [users][destroy]
/users/enabled/ [users][enabled] (custom action)
/users/{pk}/star/ [users][star] (custom action)
/users/{pk}/groups/ [groups][list], [groups][create]
/users/{pk}/groups/{pk}/ [groups][read], [groups][update], [groups][destroy]
Generate a `coreapi.Document` representing the API schema.
"""
path_components = path.strip('/').split('/')
path_components = [
component for component in path_components
if '{' not in component
]
if action in self.known_actions:
# Default action, eg "/users/", "/users/{pk}/"
idx = -1
else:
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
idx = -2
if self.endpoints is None:
self.endpoints = self.endpoint_inspector.get_api_endpoints()
try:
return path_components[idx]
except IndexError:
links = {}
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
if not self.has_view_permissions(view):
continue
link = self.get_link(path, method, view)
keys = self.get_keys(path, method, view)
insert_into(links, keys, link)
if not links:
return None
def setup_view(self, callback, method, request):
return coreapi.Document(title=self.title, url=self.url, content=links)
# Methods used when we generate a view instance from the raw callback...
def create_view(self, callback, method, request=None):
"""
Setup a view instance.
Given a callback, return an actual view instance.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
@ -205,6 +199,7 @@ class SchemaGenerator(object):
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
actions = getattr(callback, 'actions', None)
if actions is not None:
@ -215,14 +210,13 @@ class SchemaGenerator(object):
if request is not None:
view.request = clone_request(request, method)
else:
view.request = None
return view
# Methods for generating each individual `Link` instance...
def should_include_link(self, path, method, callback, view):
def has_view_permissions(self, view):
"""
Return `True` if the incoming request has the correct view permissions.
"""
if view.request is None:
return True
@ -230,20 +224,35 @@ class SchemaGenerator(object):
view.check_permissions(view.request)
except exceptions.APIException:
return False
return True
def get_link(self, path, method, callback, view):
def is_list_endpoint(self, path, method, view):
"""
Return True if the given path/method appears to represent a list endpoint.
"""
if hasattr(view, 'action'):
return view.action == 'list'
if method.lower() != 'get':
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
# Methods for generating each individual `Link` instance...
def get_link(self, path, method, view):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view)
fields += self.get_filter_fields(path, method, callback, view)
fields = self.get_path_fields(path, method, view)
fields += self.get_serializer_fields(path, method, view)
fields += self.get_pagination_fields(path, method, view)
fields += self.get_filter_fields(path, method, view)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method, callback, view)
encoding = self.get_encoding(path, method, view)
else:
encoding = None
@ -257,7 +266,7 @@ class SchemaGenerator(object):
fields=fields
)
def get_encoding(self, path, method, callback, view):
def get_encoding(self, path, method, view):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
@ -278,7 +287,7 @@ class SchemaGenerator(object):
return None
def get_path_fields(self, path, method, callback, view):
def get_path_fields(self, path, method, view):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
@ -291,7 +300,7 @@ class SchemaGenerator(object):
return fields
def get_serializer_fields(self, path, method, callback, view):
def get_serializer_fields(self, path, method, view):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
@ -327,11 +336,8 @@ class SchemaGenerator(object):
return fields
def get_pagination_fields(self, path, method, callback, view):
if method != 'GET':
return []
if hasattr(callback, 'actions') and ('list' not in callback.actions.values()):
def get_pagination_fields(self, path, method, view):
if not self.is_list_endpoint(path, method, view):
return []
if not getattr(view, 'pagination_class', None):
@ -340,17 +346,54 @@ class SchemaGenerator(object):
paginator = view.pagination_class()
return as_query_fields(paginator.get_fields(view))
def get_filter_fields(self, path, method, callback, view):
if method != 'GET':
def get_filter_fields(self, path, method, view):
if not self.is_list_endpoint(path, method, view):
return []
if hasattr(callback, 'actions') and ('list' not in callback.actions.values()):
return []
if not hasattr(view, 'filter_backends'):
if not getattr(view, 'filter_backends', None):
return []
fields = []
for filter_backend in view.filter_backends:
fields += as_query_fields(filter_backend().get_fields(view))
return fields
# Method for generating the link layout....
def get_keys(self, path, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
if view.action in self.coerce_actions:
action = self.coerce_actions[view.action]
else:
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if self.is_list_endpoint(path, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in path.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
return named_path_components[:-1] + [action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]

View File

@ -130,6 +130,7 @@ class APIView(View):
view = super(APIView, cls).as_view(**initkwargs)
view.cls = cls
view.initkwargs = initkwargs
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.

View File

@ -6,7 +6,6 @@ from django.test import TestCase, override_settings
from rest_framework import filters, pagination, permissions, serializers
from rest_framework.compat import coreapi
from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator
from rest_framework.test import APIClient
@ -55,24 +54,11 @@ class ExampleViewSet(ModelViewSet):
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
class ExampleView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, request, *args, **kwargs):
return Response()
def post(self, request, *args, **kwargs):
return Response()
router = DefaultRouter(schema_title='Example API' if coreapi else None)
router.register('example', ExampleViewSet, base_name='example')
urlpatterns = [
url(r'^', include(router.urls))
]
urlpatterns2 = [
url(r'^example-view/$', ExampleView.as_view(), name='example-view')
]
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@ -99,7 +85,7 @@ class TestRouterGeneratedSchema(TestCase):
url='/example/custom_list_action/',
action='get'
),
'retrieve': coreapi.Link(
'read': coreapi.Link(
url='/example/{pk}/',
action='get',
fields=[
@ -138,7 +124,7 @@ class TestRouterGeneratedSchema(TestCase):
coreapi.Field('b', required=False, location='form')
]
),
'retrieve': coreapi.Link(
'read': coreapi.Link(
url='/example/{pk}/',
action='get',
fields=[
@ -179,7 +165,7 @@ class TestRouterGeneratedSchema(TestCase):
coreapi.Field('b', required=False, location='form')
]
),
'destroy': coreapi.Link(
'delete': coreapi.Link(
url='/example/{pk}/',
action='delete',
fields=[
@ -192,25 +178,58 @@ class TestRouterGeneratedSchema(TestCase):
self.assertEqual(response.data, expected)
class ExampleListView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class ExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
@unittest.skipUnless(coreapi, 'coreapi is not installed')
class TestSchemaGenerator(TestCase):
def test_view(self):
schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2)
schema = schema_generator.get_schema()
def setUp(self):
self.patterns = [
url('^example/?$', ExampleListView.as_view()),
url('^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
"""
Ensure that schema generation works for APIView classes.
"""
generator = SchemaGenerator(title='Example API', patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
url='',
title='Test View',
title='Example API',
content={
'example-view': {
'example': {
'create': coreapi.Link(
url='/example-view/',
url='/example/',
action='post',
fields=[]
),
'read': coreapi.Link(
url='/example-view/',
'list': coreapi.Link(
url='/example/',
action='get',
fields=[]
),
'read': coreapi.Link(
url='/example/{pk}/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
]
)
}
}