diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 7e5bf4f84..3bd1fc850 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -18,16 +18,15 @@ from __future__ import unicode_literals import itertools from collections import OrderedDict, namedtuple -import uritemplate from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import NoReverseMatch from rest_framework import exceptions, renderers, views from rest_framework.compat import coreapi -from rest_framework.request import override_method from rest_framework.response import Response from rest_framework.reverse import reverse +from rest_framework.schemas import SchemaGenerator from rest_framework.settings import api_settings from rest_framework.urlpatterns import format_suffix_patterns @@ -263,63 +262,6 @@ class SimpleRouter(BaseRouter): return ret - def get_links(self, request=None): - content = {} - - for prefix, viewset, basename in self.registry: - lookup_field = getattr(viewset, 'lookup_field', 'pk') - lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field - lookup_placeholder = '{' + lookup_url_kwarg + '}' - - routes = self.get_routes(viewset) - - for route in routes: - url = '/' + route.url.format( - prefix=prefix, - lookup=lookup_placeholder, - trailing_slash=self.trailing_slash - ).lstrip('^').rstrip('$') - - mapping = self.get_method_map(viewset, route.mapping) - if not mapping: - continue - - for method, action in mapping.items(): - link = self.get_link(viewset, url, method, request) - if link is None: - continue # User does not have permissions. - if prefix not in content: - content[prefix] = {} - content[prefix][action] = link - return content - - def get_link(self, viewset, url, method, request=None): - view_instance = viewset() - if request is not None: - with override_method(view_instance, request, method.upper()) as request: - try: - view_instance.check_permissions(request) - except exceptions.APIException: - return None - - fields = [] - - for variable in uritemplate.variables(url): - field = coreapi.Field(name=variable, location='path', required=True) - fields.append(field) - - if method in ('put', 'patch', 'post'): - cls = view_instance.get_serializer_class() - serializer = cls() - for field in serializer.fields.values(): - if field.read_only: - continue - required = field.required and method != 'patch' - field = coreapi.Field(name=field.source, location='form', required=required) - fields.append(field) - - return coreapi.Link(url=url, action=method, fields=fields) - class DefaultRouter(SimpleRouter): """ @@ -334,7 +276,7 @@ class DefaultRouter(SimpleRouter): self.schema_title = kwargs.pop('schema_title', None) super(DefaultRouter, self).__init__(*args, **kwargs) - def get_api_root_view(self): + def get_api_root_view(self, schema_urls=None): """ Return a view to use as the API root. """ @@ -345,10 +287,10 @@ class DefaultRouter(SimpleRouter): view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) - if self.schema_title: + if schema_urls and self.schema_title: assert coreapi, '`coreapi` must be installed for schema support.' view_renderers += [renderers.CoreJSONRenderer] - router = self + schema_generator = SchemaGenerator(patterns=schema_urls) class APIRoot(views.APIView): _ignore_model_permissions = True @@ -356,10 +298,9 @@ class DefaultRouter(SimpleRouter): def get(self, request, *args, **kwargs): if request.accepted_renderer.format == 'corejson': - content = router.get_links(request) - if not content: + schema = schema_generator.get_schema(request) + if schema is None: raise exceptions.PermissionDenied() - schema = coreapi.Document(title=router.schema_title, content=content) return Response(schema) ret = OrderedDict() @@ -388,15 +329,13 @@ class DefaultRouter(SimpleRouter): Generate the list of URL patterns, including a default root view for the API, and appending `.json` style format suffixes. """ - urls = [] + urls = super(DefaultRouter, self).get_urls() if self.include_root_view: - root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name) + view = self.get_api_root_view(schema_urls=urls) + root_url = url(r'^$', view, name=self.root_view_name) urls.append(root_url) - default_urls = super(DefaultRouter, self).get_urls() - urls.extend(default_urls) - if self.include_format_suffixes: urls = format_suffix_patterns(urls) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py new file mode 100644 index 000000000..c47ff3eb7 --- /dev/null +++ b/rest_framework/schemas.py @@ -0,0 +1,176 @@ +from importlib import import_module + +import coreapi +import uritemplate +from django.conf import settings +from django.contrib.admindocs.views import simplify_regex +from django.core.urlresolvers import RegexURLPattern, RegexURLResolver +from django.utils import six + +from rest_framework import exceptions +from rest_framework.request import clone_request +from rest_framework.views import APIView + + +class SchemaGenerator(object): + default_mapping = { + 'get': 'read', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + + def __init__(self, schema_title=None, patterns=None, urlconf=None): + if patterns is None and urlconf is not None: + if isinstance(urlconf, six.string_types): + urls = import_module(urlconf) + else: + urls = urlconf + patterns = urls.urlpatterns + elif patterns is None and urlconf is None: + urls = import_module(settings.ROOT_URLCONF) + patterns = urls.urlpatterns + + self.schema_title = schema_title + self.endpoints = self.get_api_endpoints(patterns) + + def get_schema(self, request=None): + if request is None: + endpoints = self.endpoints + else: + # Filter the list of endpoints to only include those that + # the user has permission on. + endpoints = [] + for key, link, callback in self.endpoints: + method = link.action.upper() + view = callback.cls() + view.request = clone_request(request, method) + try: + view.check_permissions(view.request) + except exceptions.APIException: + pass + else: + endpoints.append((key, link, callback)) + + if not endpoints: + return None + + # Generate the schema content structure, from the endpoints. + # ('users', 'list'), Link -> {'users': {'list': Link()}} + content = {} + for key, link, callback in endpoints: + insert_into = content + for item in key[:1]: + if item not in insert_into: + insert_into[item] = {} + insert_into = insert_into[item] + insert_into[key[-1]] = link + + # Return the schema document. + return coreapi.Document(title=self.schema_title, content=content) + + def get_api_endpoints(self, patterns, prefix=''): + """ + Return a list of all available API endpoints by inspecting the URL conf. + """ + api_endpoints = [] + + for pattern in patterns: + path_regex = prefix + pattern.regex.pattern + + if isinstance(pattern, RegexURLPattern): + path = self.get_path(path_regex) + callback = pattern.callback + if self.include_endpoint(path, callback): + for method in self.get_allowed_methods(callback): + key = self.get_key(path, method, callback) + link = self.get_link(path, method, callback) + endpoint = (key, link, callback) + api_endpoints.append(endpoint) + + elif isinstance(pattern, RegexURLResolver): + nested_endpoints = self.get_api_endpoints( + patterns=pattern.url_patterns, + prefix=path_regex + ) + api_endpoints.extend(nested_endpoints) + + return api_endpoints + + def get_path(self, path_regex): + """ + Given a URL conf regex, return a URI template string. + """ + path = simplify_regex(path_regex) + path = path.replace('<', '{').replace('>', '}') + return path + + def include_endpoint(self, path, callback): + """ + Return True if the given endpoint should be included. + """ + cls = getattr(callback, 'cls', None) + if (cls is None) or not issubclass(cls, APIView): + return False + + if path.endswith('.{format}') or path.endswith('.{format}/'): + return False + + if path == '/': + return False + + return True + + def get_allowed_methods(self, callback): + """ + Return a list of the valid HTTP methods for this endpoint. + """ + if hasattr(callback, 'actions'): + return [method.upper() for method in callback.actions.keys()] + + return [ + method for method in + callback.cls().allowed_methods if method != 'OPTIONS' + ] + + def get_key(self, path, method, callback): + """ + Return a tuple of strings, indicating the identity to use for a + given endpoint. eg. ('users', 'list'). + """ + category = None + for item in path.strip('/').split('/'): + if '{' in item: + break + category = item + + actions = getattr(callback, 'actions', self.default_mapping) + action = actions[method.lower()] + + if category: + return (category, action) + return (action,) + + def get_link(self, path, method, callback): + """ + Return a `coreapi.Link` instance for the given endpoint. + """ + view = callback.cls() + fields = [] + + for variable in uritemplate.variables(path): + field = coreapi.Field(name=variable, location='path', required=True) + fields.append(field) + + if method in ('PUT', 'PATCH', 'POST'): + serializer_class = view.get_serializer_class() + serializer = serializer_class() + for field in serializer.fields.values(): + if field.read_only: + continue + required = field.required and method != 'PATCH' + field = coreapi.Field(name=field.source, location='form', required=required) + fields.append(field) + + return coreapi.Link(url=path, action=method.lower(), fields=fields) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 05434b72e..7687448c4 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -98,6 +98,7 @@ class ViewSetMixin(object): # resolved URL. view.cls = cls view.suffix = initkwargs.get('suffix', None) + view.actions = actions return csrf_exempt(view) def initialize_request(self, request, *args, **kwargs): diff --git a/tests/test_routers.py b/tests/test_routers.py index acab660d8..f45039f80 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -257,7 +257,7 @@ class TestNameableRoot(TestCase): def test_router_has_custom_name(self): expected = 'nameable-root' - self.assertEqual(expected, self.urls[0].name) + self.assertEqual(expected, self.urls[-1].name) class TestActionKeywordArgs(TestCase):