diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 28f5dc8a1..c3a811bfb 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -30,24 +30,6 @@ def is_api_view(callback): return (cls is not None) and issubclass(cls, APIView) -def insert_into(target, keys, item): - """ - Insert `item` into the nested dictionary `target`. - - For example: - - target = {} - insert_into(target, ('users', 'list'), Link(...)) - insert_into(target, ('users', 'detail'), Link(...)) - assert target == {'users': {'list': Link(...), 'detail': Link(...)}} - """ - for key in keys[:1]: - if key not in target: - target[key] = {} - target = target[key] - target[keys[-1]] = item - - class SchemaGenerator(object): default_mapping = { 'get': 'read', @@ -84,7 +66,7 @@ class SchemaGenerator(object): self.endpoints = self.get_api_endpoints(self.patterns) links = [] - for key, path, method, callback in self.endpoints: + for path, method, category, action, callback in self.endpoints: view = callback.cls() for attr, val in getattr(callback, 'initkwargs', {}).items(): setattr(view, attr, val) @@ -102,16 +84,21 @@ class SchemaGenerator(object): view.request = None link = self.get_link(path, method, callback, view) - links.append((key, link)) + links.append((category, action, link)) - if not link: + if not links: return None - # Generate the schema content structure, from the endpoints. - # ('users', 'list'), Link -> {'users': {'list': Link()}} + # Generate the schema content structure, eg: + # {'users': {'list': Link()}} content = {} - for key, link in links: - insert_into(content, key, link) + for category, action, link in links: + if category is None: + content[action] = link + elif category in content: + content[category][action] = link + else: + content[category] = {action: link} # Return the schema document. return coreapi.Document(title=self.title, content=content, url=self.url) @@ -129,8 +116,8 @@ class SchemaGenerator(object): callback = pattern.callback if self.should_include_endpoint(path, callback): for method in self.get_allowed_methods(callback): - key = self.get_key(path, method, callback) - endpoint = (key, path, method, callback) + action = self.get_action(path, method, callback) + endpoint = (path, method, action, callback) api_endpoints.append(endpoint) elif isinstance(pattern, RegexURLResolver): @@ -140,7 +127,21 @@ class SchemaGenerator(object): ) api_endpoints.extend(nested_endpoints) - return api_endpoints + return self.add_categories(api_endpoints) + + def add_categories(self, api_endpoints): + """ + (path, method, action, callback) -> (path, method, category, action, callback) + """ + # Determine the top level categories for the schema content, + # based on the URLs of the endpoints. Eg `set(['users', 'organisations'])` + paths = [endpoint[0] for endpoint in api_endpoints] + categories = self.get_categories(paths) + + return [ + (path, method, self.get_category(categories, path), action, callback) + for (path, method, action, callback) in api_endpoints + ] def get_path(self, path_regex): """ @@ -177,23 +178,38 @@ class SchemaGenerator(object): callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') ] - def get_key(self, path, method, callback): + def get_action(self, path, method, callback): """ - Return a tuple of strings, indicating the identity to use for a - given endpoint. eg. ('users', 'list'). + Return a description action string for the endpoint, eg. 'list'. """ - category = None - for item in path.strip('/').split('/'): - if '{' in item: - break - category = item - actions = getattr(callback, 'actions', self.default_mapping) - action = actions[method.lower()] + return actions[method.lower()] - if category: - return (category, action) - return (action,) + def get_categories(self, paths): + categories = set() + split_paths = set([ + tuple(path.split("{")[0].strip('/').split('/')) + for path in paths + ]) + + while split_paths: + for split_path in list(split_paths): + if len(split_path) == 0: + split_paths.remove(split_path) + elif len(split_path) == 1: + categories.add(split_path[0]) + split_paths.remove(split_path) + elif split_path[0] in categories: + split_paths.remove(split_path) + + return categories + + def get_category(self, categories, path): + path_components = path.split("{")[0].strip('/').split('/') + for path_component in path_components: + if path_component in categories: + return path_component + return None # Methods for generating each individual `Link` instance... diff --git a/tests/test_schemas.py b/tests/test_schemas.py index d8c0f2209..5e588483d 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -5,7 +5,7 @@ from django.test import TestCase, override_settings from rest_framework import filters, pagination, permissions, serializers from rest_framework.compat import coreapi -from rest_framework.decorators import detail_route +from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator @@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet): def custom_action(self, request, pk): return super(ExampleSerializer, self).retrieve(self, request) + @list_route() + def custom_list_action(self, request): + return super(ExampleViewSet, self).list(self, request) + def get_serializer(self, *args, **kwargs): assert self.request return super(ExampleViewSet, self).get_serializer(*args, **kwargs) @@ -88,6 +92,10 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('ordering', required=False, location='query') ] ), + 'custom_list_action': coreapi.Link( + url='/example/custom_list_action/', + action='get' + ), 'retrieve': coreapi.Link( url='/example/{pk}/', action='get', @@ -144,6 +152,10 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('d', required=False, location='form'), ] ), + 'custom_list_action': coreapi.Link( + url='/example/custom_list_action/', + action='get' + ), 'update': coreapi.Link( url='/example/{pk}/', action='put',