From c3a9538ad90b9236ec91ec19e149075087b30765 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 30 Sep 2016 13:29:01 +0100 Subject: [PATCH] Clean up schema generation (#4527) --- rest_framework/schemas.py | 291 ++++++++++++++++++++++---------------- rest_framework/views.py | 1 + tests/test_schemas.py | 69 +++++---- 3 files changed, 212 insertions(+), 149 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index f2ec1f9e1..39dd8d910 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -32,85 +32,66 @@ def is_api_view(callback): return (cls is not None) and issubclass(cls, APIView) -class SchemaGenerator(object): - default_mapping = { - 'get': 'read', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } - known_actions = ( - 'create', 'read', 'retrieve', 'list', - 'update', 'partial_update', 'destroy' - ) +def insert_into(target, keys, value): + """ + Nested dictionary insertion. - def __init__(self, title=None, url=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' + >>> example = {} + >>> insert_into(example, ['a', 'b', 'c'], 123) + >>> example + {'a': {'b': {'c': 123}}} + """ + for key in keys[:-1]: + if key not in target: + target[key] = {} + target = target[key] + target[keys[-1]] = value - if patterns is None and urlconf is not None: - if isinstance(urlconf, six.string_types): - urls = import_module(urlconf) + +def is_custom_action(action): + return action not in set([ + 'read', 'retrieve', 'list', + 'create', 'update', 'partial_update', 'delete', 'destroy' + ]) + + +class EndpointInspector(object): + """ + A class to determine the available API endpoints that a project exposes. + """ + def __init__(self, patterns=None, urlconf=None): + if patterns is None: + if urlconf is None: + # Use the default Django URL conf + urls = import_module(settings.ROOT_URLCONF) + patterns = urls.urlpatterns else: - urls = urlconf - self.patterns = urls.urlpatterns - elif patterns is None and urlconf is None: - urls = import_module(settings.ROOT_URLCONF) - self.patterns = urls.urlpatterns - else: - self.patterns = patterns + # Load the given URLconf module + if isinstance(urlconf, six.string_types): + urls = import_module(urlconf) + else: + urls = urlconf + patterns = urls.urlpatterns - if url and not url.endswith('/'): - url += '/' + self.patterns = patterns - self.title = title - self.url = url - self.endpoints = None - - def get_schema(self, request=None): - if self.endpoints is None: - self.endpoints = self.get_api_endpoints(self.patterns) - - links = [] - for path, method, category, action, callback in self.endpoints: - view = self.setup_view(callback, method, request) - if self.should_include_link(path, method, callback, view): - link = self.get_link(path, method, callback, view) - links.append((category, action, link)) - - if not links: - return None - - # Generate the schema content structure, eg: - # {'users': {'list': Link()}} - content = {} - for category, action, link in links: - if category is None: - content[action] = link - elif category in content: - content[category][action] = link - else: - content[category] = {action: link} - - # Return the schema document. - return coreapi.Document(title=self.title, content=content, url=self.url) - - def get_api_endpoints(self, patterns, prefix=''): + def get_api_endpoints(self, patterns=None, prefix=''): """ Return a list of all available API endpoints by inspecting the URL conf. """ + if patterns is None: + patterns = self.patterns + api_endpoints = [] for pattern in patterns: path_regex = prefix + pattern.regex.pattern if isinstance(pattern, RegexURLPattern): - path = self.get_path(path_regex) + path = self.get_path_from_regex(path_regex) callback = pattern.callback if self.should_include_endpoint(path, callback): for method in self.get_allowed_methods(callback): - action = self.get_action(path, method, callback) - category = self.get_category(path, method, callback, action) - endpoint = (path, method, category, action, callback) + endpoint = (path, method, callback) api_endpoints.append(endpoint) elif isinstance(pattern, RegexURLResolver): @@ -122,7 +103,7 @@ class SchemaGenerator(object): return api_endpoints - def get_path(self, path_regex): + def get_path_from_regex(self, path_regex): """ Given a URL conf regex, return a URI template string. """ @@ -157,47 +138,60 @@ class SchemaGenerator(object): callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') ] - def get_action(self, path, method, callback): - """ - Return a descriptive action string for the endpoint, eg. 'list'. - """ - actions = getattr(callback, 'actions', self.default_mapping) - return actions[method.lower()] - def get_category(self, path, method, callback, action): +class SchemaGenerator(object): + # Map methods onto 'actions' that are the names used in the link layout. + default_mapping = { + 'get': 'read', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + # Coerce the following viewset actions into different names. + coerce_actions = { + 'retrieve': 'read', + 'destroy': 'delete' + } + endpoint_inspector_cls = EndpointInspector + + def __init__(self, title=None, url=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + + if url and not url.endswith('/'): + url += '/' + + self.endpoint_inspector = self.endpoint_inspector_cls(patterns, urlconf) + self.title = title + self.url = url + self.endpoints = None + + def get_schema(self, request=None): """ - Return a descriptive category string for the endpoint, eg. 'users'. - - Examples of category/action pairs that should be generated for various - endpoints: - - /users/ [users][list], [users][create] - /users/{pk}/ [users][read], [users][update], [users][destroy] - /users/enabled/ [users][enabled] (custom action) - /users/{pk}/star/ [users][star] (custom action) - /users/{pk}/groups/ [groups][list], [groups][create] - /users/{pk}/groups/{pk}/ [groups][read], [groups][update], [groups][destroy] + Generate a `coreapi.Document` representing the API schema. """ - path_components = path.strip('/').split('/') - path_components = [ - component for component in path_components - if '{' not in component - ] - if action in self.known_actions: - # Default action, eg "/users/", "/users/{pk}/" - idx = -1 - else: - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - idx = -2 + if self.endpoints is None: + self.endpoints = self.endpoint_inspector.get_api_endpoints() - try: - return path_components[idx] - except IndexError: + links = {} + for path, method, callback in self.endpoints: + view = self.create_view(callback, method, request) + if not self.has_view_permissions(view): + continue + link = self.get_link(path, method, view) + keys = self.get_keys(path, method, view) + insert_into(links, keys, link) + + if not links: return None - def setup_view(self, callback, method, request): + return coreapi.Document(title=self.title, url=self.url, content=links) + + # Methods used when we generate a view instance from the raw callback... + + def create_view(self, callback, method, request=None): """ - Setup a view instance. + Given a callback, return an actual view instance. """ view = callback.cls() for attr, val in getattr(callback, 'initkwargs', {}).items(): @@ -205,6 +199,7 @@ class SchemaGenerator(object): view.args = () view.kwargs = {} view.format_kwarg = None + view.request = None actions = getattr(callback, 'actions', None) if actions is not None: @@ -215,14 +210,13 @@ class SchemaGenerator(object): if request is not None: view.request = clone_request(request, method) - else: - view.request = None return view - # Methods for generating each individual `Link` instance... - - def should_include_link(self, path, method, callback, view): + def has_view_permissions(self, view): + """ + Return `True` if the incoming request has the correct view permissions. + """ if view.request is None: return True @@ -230,20 +224,35 @@ class SchemaGenerator(object): view.check_permissions(view.request) except exceptions.APIException: return False - return True - def get_link(self, path, method, callback, view): + def is_list_endpoint(self, path, method, view): + """ + Return True if the given path/method appears to represent a list endpoint. + """ + if hasattr(view, 'action'): + return view.action == 'list' + + if method.lower() != 'get': + return False + path_components = path.strip('/').split('/') + if path_components and '{' in path_components[-1]: + return False + return True + + # Methods for generating each individual `Link` instance... + + def get_link(self, path, method, view): """ Return a `coreapi.Link` instance for the given endpoint. """ - fields = self.get_path_fields(path, method, callback, view) - fields += self.get_serializer_fields(path, method, callback, view) - fields += self.get_pagination_fields(path, method, callback, view) - fields += self.get_filter_fields(path, method, callback, view) + fields = self.get_path_fields(path, method, view) + fields += self.get_serializer_fields(path, method, view) + fields += self.get_pagination_fields(path, method, view) + fields += self.get_filter_fields(path, method, view) if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method, callback, view) + encoding = self.get_encoding(path, method, view) else: encoding = None @@ -257,7 +266,7 @@ class SchemaGenerator(object): fields=fields ) - def get_encoding(self, path, method, callback, view): + def get_encoding(self, path, method, view): """ Return the 'encoding' parameter to use for a given endpoint. """ @@ -278,7 +287,7 @@ class SchemaGenerator(object): return None - def get_path_fields(self, path, method, callback, view): + def get_path_fields(self, path, method, view): """ Return a list of `coreapi.Field` instances corresponding to any templated path variables. @@ -291,7 +300,7 @@ class SchemaGenerator(object): return fields - def get_serializer_fields(self, path, method, callback, view): + def get_serializer_fields(self, path, method, view): """ Return a list of `coreapi.Field` instances corresponding to any request body input, as determined by the serializer class. @@ -327,11 +336,8 @@ class SchemaGenerator(object): return fields - def get_pagination_fields(self, path, method, callback, view): - if method != 'GET': - return [] - - if hasattr(callback, 'actions') and ('list' not in callback.actions.values()): + def get_pagination_fields(self, path, method, view): + if not self.is_list_endpoint(path, method, view): return [] if not getattr(view, 'pagination_class', None): @@ -340,17 +346,54 @@ class SchemaGenerator(object): paginator = view.pagination_class() return as_query_fields(paginator.get_fields(view)) - def get_filter_fields(self, path, method, callback, view): - if method != 'GET': + def get_filter_fields(self, path, method, view): + if not self.is_list_endpoint(path, method, view): return [] - if hasattr(callback, 'actions') and ('list' not in callback.actions.values()): - return [] - - if not hasattr(view, 'filter_backends'): + if not getattr(view, 'filter_backends', None): return [] fields = [] for filter_backend in view.filter_backends: fields += as_query_fields(filter_backend().get_fields(view)) return fields + + # Method for generating the link layout.... + + def get_keys(self, path, method, view): + """ + Return a list of keys that should be used to layout a link within + the schema document. + + /users/ ("users", "list"), ("users", "create") + /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") + /users/enabled/ ("users", "enabled") # custom viewset list action + /users/{pk}/star/ ("users", "star") # custom viewset detail action + /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") + /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") + """ + if hasattr(view, 'action'): + # Viewsets have explicitly named actions. + if view.action in self.coerce_actions: + action = self.coerce_actions[view.action] + else: + action = view.action + else: + # Views have no associated action, so we determine one from the method. + if self.is_list_endpoint(path, method, view): + action = 'list' + else: + action = self.default_mapping[method.lower()] + + named_path_components = [ + component for component + in path.strip('/').split('/') + if '{' not in component + ] + + if is_custom_action(action): + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + return named_path_components[:-1] + [action] + + # Default action, eg "/users/", "/users/{pk}/" + return named_path_components + [action] diff --git a/rest_framework/views.py b/rest_framework/views.py index 15d8c6cde..23d48962c 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -130,6 +130,7 @@ class APIView(View): view = super(APIView, cls).as_view(**initkwargs) view.cls = cls + view.initkwargs = initkwargs # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index dc01d8cd8..1d98a4618 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -6,7 +6,6 @@ from django.test import TestCase, override_settings from rest_framework import filters, pagination, permissions, serializers from rest_framework.compat import coreapi from rest_framework.decorators import detail_route, list_route -from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator from rest_framework.test import APIClient @@ -55,24 +54,11 @@ class ExampleViewSet(ModelViewSet): return super(ExampleViewSet, self).get_serializer(*args, **kwargs) -class ExampleView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, request, *args, **kwargs): - return Response() - - def post(self, request, *args, **kwargs): - return Response() - - router = DefaultRouter(schema_title='Example API' if coreapi else None) router.register('example', ExampleViewSet, base_name='example') urlpatterns = [ url(r'^', include(router.urls)) ] -urlpatterns2 = [ - url(r'^example-view/$', ExampleView.as_view(), name='example-view') -] @unittest.skipUnless(coreapi, 'coreapi is not installed') @@ -99,7 +85,7 @@ class TestRouterGeneratedSchema(TestCase): url='/example/custom_list_action/', action='get' ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -138,7 +124,7 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('b', required=False, location='form') ] ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -179,7 +165,7 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('b', required=False, location='form') ] ), - 'destroy': coreapi.Link( + 'delete': coreapi.Link( url='/example/{pk}/', action='delete', fields=[ @@ -192,25 +178,58 @@ class TestRouterGeneratedSchema(TestCase): self.assertEqual(response.data, expected) +class ExampleListView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + def post(self, request, *args, **kwargs): + pass + + +class ExampleDetailView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGenerator(TestCase): - def test_view(self): - schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2) - schema = schema_generator.get_schema() + def setUp(self): + self.patterns = [ + url('^example/?$', ExampleListView.as_view()), + url('^example/(?P\d+)/?$', ExampleDetailView.as_view()), + ] + + def test_schema_for_regular_views(self): + """ + Ensure that schema generation works for APIView classes. + """ + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() expected = coreapi.Document( url='', - title='Test View', + title='Example API', content={ - 'example-view': { + 'example': { 'create': coreapi.Link( - url='/example-view/', + url='/example/', action='post', fields=[] ), - 'read': coreapi.Link( - url='/example-view/', + 'list': coreapi.Link( + url='/example/', action='get', fields=[] + ), + 'read': coreapi.Link( + url='/example/{pk}/', + action='get', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] ) } }