Added SchemaGenerator class

This commit is contained in:
Tom Christie 2016-06-15 16:43:50 +01:00
parent 474a23e254
commit 482289695d
4 changed files with 187 additions and 71 deletions

View File

@ -18,16 +18,15 @@ from __future__ import unicode_literals
import itertools import itertools
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
import uritemplate
from django.conf.urls import url from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import NoReverseMatch from django.core.urlresolvers import NoReverseMatch
from rest_framework import exceptions, renderers, views from rest_framework import exceptions, renderers, views
from rest_framework.compat import coreapi from rest_framework.compat import coreapi
from rest_framework.request import override_method
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
@ -263,63 +262,6 @@ class SimpleRouter(BaseRouter):
return ret return ret
def get_links(self, request=None):
content = {}
for prefix, viewset, basename in self.registry:
lookup_field = getattr(viewset, 'lookup_field', 'pk')
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
lookup_placeholder = '{' + lookup_url_kwarg + '}'
routes = self.get_routes(viewset)
for route in routes:
url = '/' + route.url.format(
prefix=prefix,
lookup=lookup_placeholder,
trailing_slash=self.trailing_slash
).lstrip('^').rstrip('$')
mapping = self.get_method_map(viewset, route.mapping)
if not mapping:
continue
for method, action in mapping.items():
link = self.get_link(viewset, url, method, request)
if link is None:
continue # User does not have permissions.
if prefix not in content:
content[prefix] = {}
content[prefix][action] = link
return content
def get_link(self, viewset, url, method, request=None):
view_instance = viewset()
if request is not None:
with override_method(view_instance, request, method.upper()) as request:
try:
view_instance.check_permissions(request)
except exceptions.APIException:
return None
fields = []
for variable in uritemplate.variables(url):
field = coreapi.Field(name=variable, location='path', required=True)
fields.append(field)
if method in ('put', 'patch', 'post'):
cls = view_instance.get_serializer_class()
serializer = cls()
for field in serializer.fields.values():
if field.read_only:
continue
required = field.required and method != 'patch'
field = coreapi.Field(name=field.source, location='form', required=required)
fields.append(field)
return coreapi.Link(url=url, action=method, fields=fields)
class DefaultRouter(SimpleRouter): class DefaultRouter(SimpleRouter):
""" """
@ -334,7 +276,7 @@ class DefaultRouter(SimpleRouter):
self.schema_title = kwargs.pop('schema_title', None) self.schema_title = kwargs.pop('schema_title', None)
super(DefaultRouter, self).__init__(*args, **kwargs) super(DefaultRouter, self).__init__(*args, **kwargs)
def get_api_root_view(self): def get_api_root_view(self, schema_urls=None):
""" """
Return a view to use as the API root. Return a view to use as the API root.
""" """
@ -345,10 +287,10 @@ class DefaultRouter(SimpleRouter):
view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
if self.schema_title: if schema_urls and self.schema_title:
assert coreapi, '`coreapi` must be installed for schema support.' assert coreapi, '`coreapi` must be installed for schema support.'
view_renderers += [renderers.CoreJSONRenderer] view_renderers += [renderers.CoreJSONRenderer]
router = self schema_generator = SchemaGenerator(patterns=schema_urls)
class APIRoot(views.APIView): class APIRoot(views.APIView):
_ignore_model_permissions = True _ignore_model_permissions = True
@ -356,10 +298,9 @@ class DefaultRouter(SimpleRouter):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
if request.accepted_renderer.format == 'corejson': if request.accepted_renderer.format == 'corejson':
content = router.get_links(request) schema = schema_generator.get_schema(request)
if not content: if schema is None:
raise exceptions.PermissionDenied() raise exceptions.PermissionDenied()
schema = coreapi.Document(title=router.schema_title, content=content)
return Response(schema) return Response(schema)
ret = OrderedDict() ret = OrderedDict()
@ -388,15 +329,13 @@ class DefaultRouter(SimpleRouter):
Generate the list of URL patterns, including a default root view Generate the list of URL patterns, including a default root view
for the API, and appending `.json` style format suffixes. for the API, and appending `.json` style format suffixes.
""" """
urls = [] urls = super(DefaultRouter, self).get_urls()
if self.include_root_view: if self.include_root_view:
root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name) view = self.get_api_root_view(schema_urls=urls)
root_url = url(r'^$', view, name=self.root_view_name)
urls.append(root_url) urls.append(root_url)
default_urls = super(DefaultRouter, self).get_urls()
urls.extend(default_urls)
if self.include_format_suffixes: if self.include_format_suffixes:
urls = format_suffix_patterns(urls) urls = format_suffix_patterns(urls)

176
rest_framework/schemas.py Normal file
View File

@ -0,0 +1,176 @@
from importlib import import_module
import coreapi
import uritemplate
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
from django.utils import six
from rest_framework import exceptions
from rest_framework.request import clone_request
from rest_framework.views import APIView
class SchemaGenerator(object):
default_mapping = {
'get': 'read',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
def __init__(self, schema_title=None, patterns=None, urlconf=None):
if patterns is None and urlconf is not None:
if isinstance(urlconf, six.string_types):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns
self.schema_title = schema_title
self.endpoints = self.get_api_endpoints(patterns)
def get_schema(self, request=None):
if request is None:
endpoints = self.endpoints
else:
# Filter the list of endpoints to only include those that
# the user has permission on.
endpoints = []
for key, link, callback in self.endpoints:
method = link.action.upper()
view = callback.cls()
view.request = clone_request(request, method)
try:
view.check_permissions(view.request)
except exceptions.APIException:
pass
else:
endpoints.append((key, link, callback))
if not endpoints:
return None
# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {}
for key, link, callback in endpoints:
insert_into = content
for item in key[:1]:
if item not in insert_into:
insert_into[item] = {}
insert_into = insert_into[item]
insert_into[key[-1]] = link
# Return the schema document.
return coreapi.Document(title=self.schema_title, content=content)
def get_api_endpoints(self, patterns, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
api_endpoints = []
for pattern in patterns:
path_regex = prefix + pattern.regex.pattern
if isinstance(pattern, RegexURLPattern):
path = self.get_path(path_regex)
callback = pattern.callback
if self.include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback)
endpoint = (key, link, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
return api_endpoints
def get_path(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
path = simplify_regex(path_regex)
path = path.replace('<', '{').replace('>', '}')
return path
def include_endpoint(self, path, callback):
"""
Return True if the given endpoint should be included.
"""
cls = getattr(callback, 'cls', None)
if (cls is None) or not issubclass(cls, APIView):
return False
if path.endswith('.{format}') or path.endswith('.{format}/'):
return False
if path == '/':
return False
return True
def get_allowed_methods(self, callback):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
return [method.upper() for method in callback.actions.keys()]
return [
method for method in
callback.cls().allowed_methods if method != 'OPTIONS'
]
def get_key(self, path, method, callback):
"""
Return a tuple of strings, indicating the identity to use for a
given endpoint. eg. ('users', 'list').
"""
category = None
for item in path.strip('/').split('/'):
if '{' in item:
break
category = item
actions = getattr(callback, 'actions', self.default_mapping)
action = actions[method.lower()]
if category:
return (category, action)
return (action,)
def get_link(self, path, method, callback):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
fields = []
for variable in uritemplate.variables(path):
field = coreapi.Field(name=variable, location='path', required=True)
fields.append(field)
if method in ('PUT', 'PATCH', 'POST'):
serializer_class = view.get_serializer_class()
serializer = serializer_class()
for field in serializer.fields.values():
if field.read_only:
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(name=field.source, location='form', required=required)
fields.append(field)
return coreapi.Link(url=path, action=method.lower(), fields=fields)

View File

@ -98,6 +98,7 @@ class ViewSetMixin(object):
# resolved URL. # resolved URL.
view.cls = cls view.cls = cls
view.suffix = initkwargs.get('suffix', None) view.suffix = initkwargs.get('suffix', None)
view.actions = actions
return csrf_exempt(view) return csrf_exempt(view)
def initialize_request(self, request, *args, **kwargs): def initialize_request(self, request, *args, **kwargs):

View File

@ -257,7 +257,7 @@ class TestNameableRoot(TestCase):
def test_router_has_custom_name(self): def test_router_has_custom_name(self):
expected = 'nameable-root' expected = 'nameable-root'
self.assertEqual(expected, self.urls[0].name) self.assertEqual(expected, self.urls[-1].name)
class TestActionKeywordArgs(TestCase): class TestActionKeywordArgs(TestCase):