Added SchemaGenerator class

This commit is contained in:
Tom Christie 2016-06-15 16:43:50 +01:00
parent 474a23e254
commit 482289695d
4 changed files with 187 additions and 71 deletions

View File

@ -18,16 +18,15 @@ from __future__ import unicode_literals
import itertools
from collections import OrderedDict, namedtuple
import uritemplate
from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import NoReverseMatch
from rest_framework import exceptions, renderers, views
from rest_framework.compat import coreapi
from rest_framework.request import override_method
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns
@ -263,63 +262,6 @@ class SimpleRouter(BaseRouter):
return ret
def get_links(self, request=None):
content = {}
for prefix, viewset, basename in self.registry:
lookup_field = getattr(viewset, 'lookup_field', 'pk')
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
lookup_placeholder = '{' + lookup_url_kwarg + '}'
routes = self.get_routes(viewset)
for route in routes:
url = '/' + route.url.format(
prefix=prefix,
lookup=lookup_placeholder,
trailing_slash=self.trailing_slash
).lstrip('^').rstrip('$')
mapping = self.get_method_map(viewset, route.mapping)
if not mapping:
continue
for method, action in mapping.items():
link = self.get_link(viewset, url, method, request)
if link is None:
continue # User does not have permissions.
if prefix not in content:
content[prefix] = {}
content[prefix][action] = link
return content
def get_link(self, viewset, url, method, request=None):
view_instance = viewset()
if request is not None:
with override_method(view_instance, request, method.upper()) as request:
try:
view_instance.check_permissions(request)
except exceptions.APIException:
return None
fields = []
for variable in uritemplate.variables(url):
field = coreapi.Field(name=variable, location='path', required=True)
fields.append(field)
if method in ('put', 'patch', 'post'):
cls = view_instance.get_serializer_class()
serializer = cls()
for field in serializer.fields.values():
if field.read_only:
continue
required = field.required and method != 'patch'
field = coreapi.Field(name=field.source, location='form', required=required)
fields.append(field)
return coreapi.Link(url=url, action=method, fields=fields)
class DefaultRouter(SimpleRouter):
"""
@ -334,7 +276,7 @@ class DefaultRouter(SimpleRouter):
self.schema_title = kwargs.pop('schema_title', None)
super(DefaultRouter, self).__init__(*args, **kwargs)
def get_api_root_view(self):
def get_api_root_view(self, schema_urls=None):
"""
Return a view to use as the API root.
"""
@ -345,10 +287,10 @@ class DefaultRouter(SimpleRouter):
view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
if self.schema_title:
if schema_urls and self.schema_title:
assert coreapi, '`coreapi` must be installed for schema support.'
view_renderers += [renderers.CoreJSONRenderer]
router = self
schema_generator = SchemaGenerator(patterns=schema_urls)
class APIRoot(views.APIView):
_ignore_model_permissions = True
@ -356,10 +298,9 @@ class DefaultRouter(SimpleRouter):
def get(self, request, *args, **kwargs):
if request.accepted_renderer.format == 'corejson':
content = router.get_links(request)
if not content:
schema = schema_generator.get_schema(request)
if schema is None:
raise exceptions.PermissionDenied()
schema = coreapi.Document(title=router.schema_title, content=content)
return Response(schema)
ret = OrderedDict()
@ -388,15 +329,13 @@ class DefaultRouter(SimpleRouter):
Generate the list of URL patterns, including a default root view
for the API, and appending `.json` style format suffixes.
"""
urls = []
urls = super(DefaultRouter, self).get_urls()
if self.include_root_view:
root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
view = self.get_api_root_view(schema_urls=urls)
root_url = url(r'^$', view, name=self.root_view_name)
urls.append(root_url)
default_urls = super(DefaultRouter, self).get_urls()
urls.extend(default_urls)
if self.include_format_suffixes:
urls = format_suffix_patterns(urls)

176
rest_framework/schemas.py Normal file
View File

@ -0,0 +1,176 @@
from importlib import import_module
import coreapi
import uritemplate
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
from django.utils import six
from rest_framework import exceptions
from rest_framework.request import clone_request
from rest_framework.views import APIView
class SchemaGenerator(object):
default_mapping = {
'get': 'read',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
def __init__(self, schema_title=None, patterns=None, urlconf=None):
if patterns is None and urlconf is not None:
if isinstance(urlconf, six.string_types):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns
self.schema_title = schema_title
self.endpoints = self.get_api_endpoints(patterns)
def get_schema(self, request=None):
if request is None:
endpoints = self.endpoints
else:
# Filter the list of endpoints to only include those that
# the user has permission on.
endpoints = []
for key, link, callback in self.endpoints:
method = link.action.upper()
view = callback.cls()
view.request = clone_request(request, method)
try:
view.check_permissions(view.request)
except exceptions.APIException:
pass
else:
endpoints.append((key, link, callback))
if not endpoints:
return None
# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {}
for key, link, callback in endpoints:
insert_into = content
for item in key[:1]:
if item not in insert_into:
insert_into[item] = {}
insert_into = insert_into[item]
insert_into[key[-1]] = link
# Return the schema document.
return coreapi.Document(title=self.schema_title, content=content)
def get_api_endpoints(self, patterns, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
api_endpoints = []
for pattern in patterns:
path_regex = prefix + pattern.regex.pattern
if isinstance(pattern, RegexURLPattern):
path = self.get_path(path_regex)
callback = pattern.callback
if self.include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback)
endpoint = (key, link, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
return api_endpoints
def get_path(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
path = simplify_regex(path_regex)
path = path.replace('<', '{').replace('>', '}')
return path
def include_endpoint(self, path, callback):
"""
Return True if the given endpoint should be included.
"""
cls = getattr(callback, 'cls', None)
if (cls is None) or not issubclass(cls, APIView):
return False
if path.endswith('.{format}') or path.endswith('.{format}/'):
return False
if path == '/':
return False
return True
def get_allowed_methods(self, callback):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
return [method.upper() for method in callback.actions.keys()]
return [
method for method in
callback.cls().allowed_methods if method != 'OPTIONS'
]
def get_key(self, path, method, callback):
"""
Return a tuple of strings, indicating the identity to use for a
given endpoint. eg. ('users', 'list').
"""
category = None
for item in path.strip('/').split('/'):
if '{' in item:
break
category = item
actions = getattr(callback, 'actions', self.default_mapping)
action = actions[method.lower()]
if category:
return (category, action)
return (action,)
def get_link(self, path, method, callback):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
fields = []
for variable in uritemplate.variables(path):
field = coreapi.Field(name=variable, location='path', required=True)
fields.append(field)
if method in ('PUT', 'PATCH', 'POST'):
serializer_class = view.get_serializer_class()
serializer = serializer_class()
for field in serializer.fields.values():
if field.read_only:
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(name=field.source, location='form', required=required)
fields.append(field)
return coreapi.Link(url=path, action=method.lower(), fields=fields)

View File

@ -98,6 +98,7 @@ class ViewSetMixin(object):
# resolved URL.
view.cls = cls
view.suffix = initkwargs.get('suffix', None)
view.actions = actions
return csrf_exempt(view)
def initialize_request(self, request, *args, **kwargs):

View File

@ -257,7 +257,7 @@ class TestNameableRoot(TestCase):
def test_router_has_custom_name(self):
expected = 'nameable-root'
self.assertEqual(expected, self.urls[0].name)
self.assertEqual(expected, self.urls[-1].name)
class TestActionKeywordArgs(TestCase):