mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 22:04:48 +03:00
Added SchemaGenerator class
This commit is contained in:
parent
474a23e254
commit
482289695d
|
@ -18,16 +18,15 @@ from __future__ import unicode_literals
|
||||||
import itertools
|
import itertools
|
||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
|
|
||||||
import uritemplate
|
|
||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.core.urlresolvers import NoReverseMatch
|
from django.core.urlresolvers import NoReverseMatch
|
||||||
|
|
||||||
from rest_framework import exceptions, renderers, views
|
from rest_framework import exceptions, renderers, views
|
||||||
from rest_framework.compat import coreapi
|
from rest_framework.compat import coreapi
|
||||||
from rest_framework.request import override_method
|
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.reverse import reverse
|
from rest_framework.reverse import reverse
|
||||||
|
from rest_framework.schemas import SchemaGenerator
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
from rest_framework.urlpatterns import format_suffix_patterns
|
from rest_framework.urlpatterns import format_suffix_patterns
|
||||||
|
|
||||||
|
@ -263,63 +262,6 @@ class SimpleRouter(BaseRouter):
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_links(self, request=None):
|
|
||||||
content = {}
|
|
||||||
|
|
||||||
for prefix, viewset, basename in self.registry:
|
|
||||||
lookup_field = getattr(viewset, 'lookup_field', 'pk')
|
|
||||||
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
|
|
||||||
lookup_placeholder = '{' + lookup_url_kwarg + '}'
|
|
||||||
|
|
||||||
routes = self.get_routes(viewset)
|
|
||||||
|
|
||||||
for route in routes:
|
|
||||||
url = '/' + route.url.format(
|
|
||||||
prefix=prefix,
|
|
||||||
lookup=lookup_placeholder,
|
|
||||||
trailing_slash=self.trailing_slash
|
|
||||||
).lstrip('^').rstrip('$')
|
|
||||||
|
|
||||||
mapping = self.get_method_map(viewset, route.mapping)
|
|
||||||
if not mapping:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for method, action in mapping.items():
|
|
||||||
link = self.get_link(viewset, url, method, request)
|
|
||||||
if link is None:
|
|
||||||
continue # User does not have permissions.
|
|
||||||
if prefix not in content:
|
|
||||||
content[prefix] = {}
|
|
||||||
content[prefix][action] = link
|
|
||||||
return content
|
|
||||||
|
|
||||||
def get_link(self, viewset, url, method, request=None):
|
|
||||||
view_instance = viewset()
|
|
||||||
if request is not None:
|
|
||||||
with override_method(view_instance, request, method.upper()) as request:
|
|
||||||
try:
|
|
||||||
view_instance.check_permissions(request)
|
|
||||||
except exceptions.APIException:
|
|
||||||
return None
|
|
||||||
|
|
||||||
fields = []
|
|
||||||
|
|
||||||
for variable in uritemplate.variables(url):
|
|
||||||
field = coreapi.Field(name=variable, location='path', required=True)
|
|
||||||
fields.append(field)
|
|
||||||
|
|
||||||
if method in ('put', 'patch', 'post'):
|
|
||||||
cls = view_instance.get_serializer_class()
|
|
||||||
serializer = cls()
|
|
||||||
for field in serializer.fields.values():
|
|
||||||
if field.read_only:
|
|
||||||
continue
|
|
||||||
required = field.required and method != 'patch'
|
|
||||||
field = coreapi.Field(name=field.source, location='form', required=required)
|
|
||||||
fields.append(field)
|
|
||||||
|
|
||||||
return coreapi.Link(url=url, action=method, fields=fields)
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultRouter(SimpleRouter):
|
class DefaultRouter(SimpleRouter):
|
||||||
"""
|
"""
|
||||||
|
@ -334,7 +276,7 @@ class DefaultRouter(SimpleRouter):
|
||||||
self.schema_title = kwargs.pop('schema_title', None)
|
self.schema_title = kwargs.pop('schema_title', None)
|
||||||
super(DefaultRouter, self).__init__(*args, **kwargs)
|
super(DefaultRouter, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def get_api_root_view(self):
|
def get_api_root_view(self, schema_urls=None):
|
||||||
"""
|
"""
|
||||||
Return a view to use as the API root.
|
Return a view to use as the API root.
|
||||||
"""
|
"""
|
||||||
|
@ -345,10 +287,10 @@ class DefaultRouter(SimpleRouter):
|
||||||
|
|
||||||
view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
|
view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
|
||||||
|
|
||||||
if self.schema_title:
|
if schema_urls and self.schema_title:
|
||||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||||
view_renderers += [renderers.CoreJSONRenderer]
|
view_renderers += [renderers.CoreJSONRenderer]
|
||||||
router = self
|
schema_generator = SchemaGenerator(patterns=schema_urls)
|
||||||
|
|
||||||
class APIRoot(views.APIView):
|
class APIRoot(views.APIView):
|
||||||
_ignore_model_permissions = True
|
_ignore_model_permissions = True
|
||||||
|
@ -356,10 +298,9 @@ class DefaultRouter(SimpleRouter):
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
if request.accepted_renderer.format == 'corejson':
|
if request.accepted_renderer.format == 'corejson':
|
||||||
content = router.get_links(request)
|
schema = schema_generator.get_schema(request)
|
||||||
if not content:
|
if schema is None:
|
||||||
raise exceptions.PermissionDenied()
|
raise exceptions.PermissionDenied()
|
||||||
schema = coreapi.Document(title=router.schema_title, content=content)
|
|
||||||
return Response(schema)
|
return Response(schema)
|
||||||
|
|
||||||
ret = OrderedDict()
|
ret = OrderedDict()
|
||||||
|
@ -388,15 +329,13 @@ class DefaultRouter(SimpleRouter):
|
||||||
Generate the list of URL patterns, including a default root view
|
Generate the list of URL patterns, including a default root view
|
||||||
for the API, and appending `.json` style format suffixes.
|
for the API, and appending `.json` style format suffixes.
|
||||||
"""
|
"""
|
||||||
urls = []
|
urls = super(DefaultRouter, self).get_urls()
|
||||||
|
|
||||||
if self.include_root_view:
|
if self.include_root_view:
|
||||||
root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
|
view = self.get_api_root_view(schema_urls=urls)
|
||||||
|
root_url = url(r'^$', view, name=self.root_view_name)
|
||||||
urls.append(root_url)
|
urls.append(root_url)
|
||||||
|
|
||||||
default_urls = super(DefaultRouter, self).get_urls()
|
|
||||||
urls.extend(default_urls)
|
|
||||||
|
|
||||||
if self.include_format_suffixes:
|
if self.include_format_suffixes:
|
||||||
urls = format_suffix_patterns(urls)
|
urls = format_suffix_patterns(urls)
|
||||||
|
|
||||||
|
|
176
rest_framework/schemas.py
Normal file
176
rest_framework/schemas.py
Normal file
|
@ -0,0 +1,176 @@
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
import coreapi
|
||||||
|
import uritemplate
|
||||||
|
from django.conf import settings
|
||||||
|
from django.contrib.admindocs.views import simplify_regex
|
||||||
|
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
|
||||||
|
from django.utils import six
|
||||||
|
|
||||||
|
from rest_framework import exceptions
|
||||||
|
from rest_framework.request import clone_request
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaGenerator(object):
|
||||||
|
default_mapping = {
|
||||||
|
'get': 'read',
|
||||||
|
'post': 'create',
|
||||||
|
'put': 'update',
|
||||||
|
'patch': 'partial_update',
|
||||||
|
'delete': 'destroy',
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, schema_title=None, patterns=None, urlconf=None):
|
||||||
|
if patterns is None and urlconf is not None:
|
||||||
|
if isinstance(urlconf, six.string_types):
|
||||||
|
urls = import_module(urlconf)
|
||||||
|
else:
|
||||||
|
urls = urlconf
|
||||||
|
patterns = urls.urlpatterns
|
||||||
|
elif patterns is None and urlconf is None:
|
||||||
|
urls = import_module(settings.ROOT_URLCONF)
|
||||||
|
patterns = urls.urlpatterns
|
||||||
|
|
||||||
|
self.schema_title = schema_title
|
||||||
|
self.endpoints = self.get_api_endpoints(patterns)
|
||||||
|
|
||||||
|
def get_schema(self, request=None):
|
||||||
|
if request is None:
|
||||||
|
endpoints = self.endpoints
|
||||||
|
else:
|
||||||
|
# Filter the list of endpoints to only include those that
|
||||||
|
# the user has permission on.
|
||||||
|
endpoints = []
|
||||||
|
for key, link, callback in self.endpoints:
|
||||||
|
method = link.action.upper()
|
||||||
|
view = callback.cls()
|
||||||
|
view.request = clone_request(request, method)
|
||||||
|
try:
|
||||||
|
view.check_permissions(view.request)
|
||||||
|
except exceptions.APIException:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
endpoints.append((key, link, callback))
|
||||||
|
|
||||||
|
if not endpoints:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Generate the schema content structure, from the endpoints.
|
||||||
|
# ('users', 'list'), Link -> {'users': {'list': Link()}}
|
||||||
|
content = {}
|
||||||
|
for key, link, callback in endpoints:
|
||||||
|
insert_into = content
|
||||||
|
for item in key[:1]:
|
||||||
|
if item not in insert_into:
|
||||||
|
insert_into[item] = {}
|
||||||
|
insert_into = insert_into[item]
|
||||||
|
insert_into[key[-1]] = link
|
||||||
|
|
||||||
|
# Return the schema document.
|
||||||
|
return coreapi.Document(title=self.schema_title, content=content)
|
||||||
|
|
||||||
|
def get_api_endpoints(self, patterns, prefix=''):
|
||||||
|
"""
|
||||||
|
Return a list of all available API endpoints by inspecting the URL conf.
|
||||||
|
"""
|
||||||
|
api_endpoints = []
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
path_regex = prefix + pattern.regex.pattern
|
||||||
|
|
||||||
|
if isinstance(pattern, RegexURLPattern):
|
||||||
|
path = self.get_path(path_regex)
|
||||||
|
callback = pattern.callback
|
||||||
|
if self.include_endpoint(path, callback):
|
||||||
|
for method in self.get_allowed_methods(callback):
|
||||||
|
key = self.get_key(path, method, callback)
|
||||||
|
link = self.get_link(path, method, callback)
|
||||||
|
endpoint = (key, link, callback)
|
||||||
|
api_endpoints.append(endpoint)
|
||||||
|
|
||||||
|
elif isinstance(pattern, RegexURLResolver):
|
||||||
|
nested_endpoints = self.get_api_endpoints(
|
||||||
|
patterns=pattern.url_patterns,
|
||||||
|
prefix=path_regex
|
||||||
|
)
|
||||||
|
api_endpoints.extend(nested_endpoints)
|
||||||
|
|
||||||
|
return api_endpoints
|
||||||
|
|
||||||
|
def get_path(self, path_regex):
|
||||||
|
"""
|
||||||
|
Given a URL conf regex, return a URI template string.
|
||||||
|
"""
|
||||||
|
path = simplify_regex(path_regex)
|
||||||
|
path = path.replace('<', '{').replace('>', '}')
|
||||||
|
return path
|
||||||
|
|
||||||
|
def include_endpoint(self, path, callback):
|
||||||
|
"""
|
||||||
|
Return True if the given endpoint should be included.
|
||||||
|
"""
|
||||||
|
cls = getattr(callback, 'cls', None)
|
||||||
|
if (cls is None) or not issubclass(cls, APIView):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if path.endswith('.{format}') or path.endswith('.{format}/'):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if path == '/':
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_allowed_methods(self, callback):
|
||||||
|
"""
|
||||||
|
Return a list of the valid HTTP methods for this endpoint.
|
||||||
|
"""
|
||||||
|
if hasattr(callback, 'actions'):
|
||||||
|
return [method.upper() for method in callback.actions.keys()]
|
||||||
|
|
||||||
|
return [
|
||||||
|
method for method in
|
||||||
|
callback.cls().allowed_methods if method != 'OPTIONS'
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_key(self, path, method, callback):
|
||||||
|
"""
|
||||||
|
Return a tuple of strings, indicating the identity to use for a
|
||||||
|
given endpoint. eg. ('users', 'list').
|
||||||
|
"""
|
||||||
|
category = None
|
||||||
|
for item in path.strip('/').split('/'):
|
||||||
|
if '{' in item:
|
||||||
|
break
|
||||||
|
category = item
|
||||||
|
|
||||||
|
actions = getattr(callback, 'actions', self.default_mapping)
|
||||||
|
action = actions[method.lower()]
|
||||||
|
|
||||||
|
if category:
|
||||||
|
return (category, action)
|
||||||
|
return (action,)
|
||||||
|
|
||||||
|
def get_link(self, path, method, callback):
|
||||||
|
"""
|
||||||
|
Return a `coreapi.Link` instance for the given endpoint.
|
||||||
|
"""
|
||||||
|
view = callback.cls()
|
||||||
|
fields = []
|
||||||
|
|
||||||
|
for variable in uritemplate.variables(path):
|
||||||
|
field = coreapi.Field(name=variable, location='path', required=True)
|
||||||
|
fields.append(field)
|
||||||
|
|
||||||
|
if method in ('PUT', 'PATCH', 'POST'):
|
||||||
|
serializer_class = view.get_serializer_class()
|
||||||
|
serializer = serializer_class()
|
||||||
|
for field in serializer.fields.values():
|
||||||
|
if field.read_only:
|
||||||
|
continue
|
||||||
|
required = field.required and method != 'PATCH'
|
||||||
|
field = coreapi.Field(name=field.source, location='form', required=required)
|
||||||
|
fields.append(field)
|
||||||
|
|
||||||
|
return coreapi.Link(url=path, action=method.lower(), fields=fields)
|
|
@ -98,6 +98,7 @@ class ViewSetMixin(object):
|
||||||
# resolved URL.
|
# resolved URL.
|
||||||
view.cls = cls
|
view.cls = cls
|
||||||
view.suffix = initkwargs.get('suffix', None)
|
view.suffix = initkwargs.get('suffix', None)
|
||||||
|
view.actions = actions
|
||||||
return csrf_exempt(view)
|
return csrf_exempt(view)
|
||||||
|
|
||||||
def initialize_request(self, request, *args, **kwargs):
|
def initialize_request(self, request, *args, **kwargs):
|
||||||
|
|
|
@ -257,7 +257,7 @@ class TestNameableRoot(TestCase):
|
||||||
|
|
||||||
def test_router_has_custom_name(self):
|
def test_router_has_custom_name(self):
|
||||||
expected = 'nameable-root'
|
expected = 'nameable-root'
|
||||||
self.assertEqual(expected, self.urls[0].name)
|
self.assertEqual(expected, self.urls[-1].name)
|
||||||
|
|
||||||
|
|
||||||
class TestActionKeywordArgs(TestCase):
|
class TestActionKeywordArgs(TestCase):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user