mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 22:04:48 +03:00
Added SchemaGenerator class
This commit is contained in:
parent
474a23e254
commit
482289695d
|
@ -18,16 +18,15 @@ from __future__ import unicode_literals
|
|||
import itertools
|
||||
from collections import OrderedDict, namedtuple
|
||||
|
||||
import uritemplate
|
||||
from django.conf.urls import url
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
|
||||
from rest_framework import exceptions, renderers, views
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.request import override_method
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
|
@ -263,63 +262,6 @@ class SimpleRouter(BaseRouter):
|
|||
|
||||
return ret
|
||||
|
||||
def get_links(self, request=None):
|
||||
content = {}
|
||||
|
||||
for prefix, viewset, basename in self.registry:
|
||||
lookup_field = getattr(viewset, 'lookup_field', 'pk')
|
||||
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
|
||||
lookup_placeholder = '{' + lookup_url_kwarg + '}'
|
||||
|
||||
routes = self.get_routes(viewset)
|
||||
|
||||
for route in routes:
|
||||
url = '/' + route.url.format(
|
||||
prefix=prefix,
|
||||
lookup=lookup_placeholder,
|
||||
trailing_slash=self.trailing_slash
|
||||
).lstrip('^').rstrip('$')
|
||||
|
||||
mapping = self.get_method_map(viewset, route.mapping)
|
||||
if not mapping:
|
||||
continue
|
||||
|
||||
for method, action in mapping.items():
|
||||
link = self.get_link(viewset, url, method, request)
|
||||
if link is None:
|
||||
continue # User does not have permissions.
|
||||
if prefix not in content:
|
||||
content[prefix] = {}
|
||||
content[prefix][action] = link
|
||||
return content
|
||||
|
||||
def get_link(self, viewset, url, method, request=None):
|
||||
view_instance = viewset()
|
||||
if request is not None:
|
||||
with override_method(view_instance, request, method.upper()) as request:
|
||||
try:
|
||||
view_instance.check_permissions(request)
|
||||
except exceptions.APIException:
|
||||
return None
|
||||
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(url):
|
||||
field = coreapi.Field(name=variable, location='path', required=True)
|
||||
fields.append(field)
|
||||
|
||||
if method in ('put', 'patch', 'post'):
|
||||
cls = view_instance.get_serializer_class()
|
||||
serializer = cls()
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only:
|
||||
continue
|
||||
required = field.required and method != 'patch'
|
||||
field = coreapi.Field(name=field.source, location='form', required=required)
|
||||
fields.append(field)
|
||||
|
||||
return coreapi.Link(url=url, action=method, fields=fields)
|
||||
|
||||
|
||||
class DefaultRouter(SimpleRouter):
|
||||
"""
|
||||
|
@ -334,7 +276,7 @@ class DefaultRouter(SimpleRouter):
|
|||
self.schema_title = kwargs.pop('schema_title', None)
|
||||
super(DefaultRouter, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_api_root_view(self):
|
||||
def get_api_root_view(self, schema_urls=None):
|
||||
"""
|
||||
Return a view to use as the API root.
|
||||
"""
|
||||
|
@ -345,10 +287,10 @@ class DefaultRouter(SimpleRouter):
|
|||
|
||||
view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
|
||||
|
||||
if self.schema_title:
|
||||
if schema_urls and self.schema_title:
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
view_renderers += [renderers.CoreJSONRenderer]
|
||||
router = self
|
||||
schema_generator = SchemaGenerator(patterns=schema_urls)
|
||||
|
||||
class APIRoot(views.APIView):
|
||||
_ignore_model_permissions = True
|
||||
|
@ -356,10 +298,9 @@ class DefaultRouter(SimpleRouter):
|
|||
|
||||
def get(self, request, *args, **kwargs):
|
||||
if request.accepted_renderer.format == 'corejson':
|
||||
content = router.get_links(request)
|
||||
if not content:
|
||||
schema = schema_generator.get_schema(request)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
schema = coreapi.Document(title=router.schema_title, content=content)
|
||||
return Response(schema)
|
||||
|
||||
ret = OrderedDict()
|
||||
|
@ -388,15 +329,13 @@ class DefaultRouter(SimpleRouter):
|
|||
Generate the list of URL patterns, including a default root view
|
||||
for the API, and appending `.json` style format suffixes.
|
||||
"""
|
||||
urls = []
|
||||
urls = super(DefaultRouter, self).get_urls()
|
||||
|
||||
if self.include_root_view:
|
||||
root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
|
||||
view = self.get_api_root_view(schema_urls=urls)
|
||||
root_url = url(r'^$', view, name=self.root_view_name)
|
||||
urls.append(root_url)
|
||||
|
||||
default_urls = super(DefaultRouter, self).get_urls()
|
||||
urls.extend(default_urls)
|
||||
|
||||
if self.include_format_suffixes:
|
||||
urls = format_suffix_patterns(urls)
|
||||
|
||||
|
|
176
rest_framework/schemas.py
Normal file
176
rest_framework/schemas.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
from importlib import import_module
|
||||
|
||||
import coreapi
|
||||
import uritemplate
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
|
||||
from django.utils import six
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
default_mapping = {
|
||||
'get': 'read',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
def __init__(self, schema_title=None, patterns=None, urlconf=None):
|
||||
if patterns is None and urlconf is not None:
|
||||
if isinstance(urlconf, six.string_types):
|
||||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
patterns = urls.urlpatterns
|
||||
elif patterns is None and urlconf is None:
|
||||
urls = import_module(settings.ROOT_URLCONF)
|
||||
patterns = urls.urlpatterns
|
||||
|
||||
self.schema_title = schema_title
|
||||
self.endpoints = self.get_api_endpoints(patterns)
|
||||
|
||||
def get_schema(self, request=None):
|
||||
if request is None:
|
||||
endpoints = self.endpoints
|
||||
else:
|
||||
# Filter the list of endpoints to only include those that
|
||||
# the user has permission on.
|
||||
endpoints = []
|
||||
for key, link, callback in self.endpoints:
|
||||
method = link.action.upper()
|
||||
view = callback.cls()
|
||||
view.request = clone_request(request, method)
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except exceptions.APIException:
|
||||
pass
|
||||
else:
|
||||
endpoints.append((key, link, callback))
|
||||
|
||||
if not endpoints:
|
||||
return None
|
||||
|
||||
# Generate the schema content structure, from the endpoints.
|
||||
# ('users', 'list'), Link -> {'users': {'list': Link()}}
|
||||
content = {}
|
||||
for key, link, callback in endpoints:
|
||||
insert_into = content
|
||||
for item in key[:1]:
|
||||
if item not in insert_into:
|
||||
insert_into[item] = {}
|
||||
insert_into = insert_into[item]
|
||||
insert_into[key[-1]] = link
|
||||
|
||||
# Return the schema document.
|
||||
return coreapi.Document(title=self.schema_title, content=content)
|
||||
|
||||
def get_api_endpoints(self, patterns, prefix=''):
|
||||
"""
|
||||
Return a list of all available API endpoints by inspecting the URL conf.
|
||||
"""
|
||||
api_endpoints = []
|
||||
|
||||
for pattern in patterns:
|
||||
path_regex = prefix + pattern.regex.pattern
|
||||
|
||||
if isinstance(pattern, RegexURLPattern):
|
||||
path = self.get_path(path_regex)
|
||||
callback = pattern.callback
|
||||
if self.include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
key = self.get_key(path, method, callback)
|
||||
link = self.get_link(path, method, callback)
|
||||
endpoint = (key, link, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, RegexURLResolver):
|
||||
nested_endpoints = self.get_api_endpoints(
|
||||
patterns=pattern.url_patterns,
|
||||
prefix=path_regex
|
||||
)
|
||||
api_endpoints.extend(nested_endpoints)
|
||||
|
||||
return api_endpoints
|
||||
|
||||
def get_path(self, path_regex):
|
||||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
path = simplify_regex(path_regex)
|
||||
path = path.replace('<', '{').replace('>', '}')
|
||||
return path
|
||||
|
||||
def include_endpoint(self, path, callback):
|
||||
"""
|
||||
Return True if the given endpoint should be included.
|
||||
"""
|
||||
cls = getattr(callback, 'cls', None)
|
||||
if (cls is None) or not issubclass(cls, APIView):
|
||||
return False
|
||||
|
||||
if path.endswith('.{format}') or path.endswith('.{format}/'):
|
||||
return False
|
||||
|
||||
if path == '/':
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_allowed_methods(self, callback):
|
||||
"""
|
||||
Return a list of the valid HTTP methods for this endpoint.
|
||||
"""
|
||||
if hasattr(callback, 'actions'):
|
||||
return [method.upper() for method in callback.actions.keys()]
|
||||
|
||||
return [
|
||||
method for method in
|
||||
callback.cls().allowed_methods if method != 'OPTIONS'
|
||||
]
|
||||
|
||||
def get_key(self, path, method, callback):
|
||||
"""
|
||||
Return a tuple of strings, indicating the identity to use for a
|
||||
given endpoint. eg. ('users', 'list').
|
||||
"""
|
||||
category = None
|
||||
for item in path.strip('/').split('/'):
|
||||
if '{' in item:
|
||||
break
|
||||
category = item
|
||||
|
||||
actions = getattr(callback, 'actions', self.default_mapping)
|
||||
action = actions[method.lower()]
|
||||
|
||||
if category:
|
||||
return (category, action)
|
||||
return (action,)
|
||||
|
||||
def get_link(self, path, method, callback):
|
||||
"""
|
||||
Return a `coreapi.Link` instance for the given endpoint.
|
||||
"""
|
||||
view = callback.cls()
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
field = coreapi.Field(name=variable, location='path', required=True)
|
||||
fields.append(field)
|
||||
|
||||
if method in ('PUT', 'PATCH', 'POST'):
|
||||
serializer_class = view.get_serializer_class()
|
||||
serializer = serializer_class()
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only:
|
||||
continue
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(name=field.source, location='form', required=required)
|
||||
fields.append(field)
|
||||
|
||||
return coreapi.Link(url=path, action=method.lower(), fields=fields)
|
|
@ -98,6 +98,7 @@ class ViewSetMixin(object):
|
|||
# resolved URL.
|
||||
view.cls = cls
|
||||
view.suffix = initkwargs.get('suffix', None)
|
||||
view.actions = actions
|
||||
return csrf_exempt(view)
|
||||
|
||||
def initialize_request(self, request, *args, **kwargs):
|
||||
|
|
|
@ -257,7 +257,7 @@ class TestNameableRoot(TestCase):
|
|||
|
||||
def test_router_has_custom_name(self):
|
||||
expected = 'nameable-root'
|
||||
self.assertEqual(expected, self.urls[0].name)
|
||||
self.assertEqual(expected, self.urls[-1].name)
|
||||
|
||||
|
||||
class TestActionKeywordArgs(TestCase):
|
||||
|
|
Loading…
Reference in New Issue
Block a user