Fix schema categories for custom list actions (#4386)

This commit is contained in:
Tom Christie 2016-08-11 14:07:40 +01:00 committed by GitHub
parent b50d8950ee
commit 01b498ec51
2 changed files with 70 additions and 42 deletions

View File

@ -30,24 +30,6 @@ def is_api_view(callback):
return (cls is not None) and issubclass(cls, APIView)
def insert_into(target, keys, item):
"""
Insert `item` into the nested dictionary `target`.
For example:
target = {}
insert_into(target, ('users', 'list'), Link(...))
insert_into(target, ('users', 'detail'), Link(...))
assert target == {'users': {'list': Link(...), 'detail': Link(...)}}
"""
for key in keys[:1]:
if key not in target:
target[key] = {}
target = target[key]
target[keys[-1]] = item
class SchemaGenerator(object):
default_mapping = {
'get': 'read',
@ -84,7 +66,7 @@ class SchemaGenerator(object):
self.endpoints = self.get_api_endpoints(self.patterns)
links = []
for key, path, method, callback in self.endpoints:
for path, method, category, action, callback in self.endpoints:
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
@ -102,16 +84,21 @@ class SchemaGenerator(object):
view.request = None
link = self.get_link(path, method, callback, view)
links.append((key, link))
links.append((category, action, link))
if not link:
if not links:
return None
# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
# Generate the schema content structure, eg:
# {'users': {'list': Link()}}
content = {}
for key, link in links:
insert_into(content, key, link)
for category, action, link in links:
if category is None:
content[action] = link
elif category in content:
content[category][action] = link
else:
content[category] = {action: link}
# Return the schema document.
return coreapi.Document(title=self.title, content=content, url=self.url)
@ -129,8 +116,8 @@ class SchemaGenerator(object):
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
endpoint = (key, path, method, callback)
action = self.get_action(path, method, callback)
endpoint = (path, method, action, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
@ -140,7 +127,21 @@ class SchemaGenerator(object):
)
api_endpoints.extend(nested_endpoints)
return api_endpoints
return self.add_categories(api_endpoints)
def add_categories(self, api_endpoints):
"""
(path, method, action, callback) -> (path, method, category, action, callback)
"""
# Determine the top level categories for the schema content,
# based on the URLs of the endpoints. Eg `set(['users', 'organisations'])`
paths = [endpoint[0] for endpoint in api_endpoints]
categories = self.get_categories(paths)
return [
(path, method, self.get_category(categories, path), action, callback)
for (path, method, action, callback) in api_endpoints
]
def get_path(self, path_regex):
"""
@ -177,23 +178,38 @@ class SchemaGenerator(object):
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
]
def get_key(self, path, method, callback):
def get_action(self, path, method, callback):
"""
Return a tuple of strings, indicating the identity to use for a
given endpoint. eg. ('users', 'list').
Return a description action string for the endpoint, eg. 'list'.
"""
category = None
for item in path.strip('/').split('/'):
if '{' in item:
break
category = item
actions = getattr(callback, 'actions', self.default_mapping)
action = actions[method.lower()]
return actions[method.lower()]
if category:
return (category, action)
return (action,)
def get_categories(self, paths):
categories = set()
split_paths = set([
tuple(path.split("{")[0].strip('/').split('/'))
for path in paths
])
while split_paths:
for split_path in list(split_paths):
if len(split_path) == 0:
split_paths.remove(split_path)
elif len(split_path) == 1:
categories.add(split_path[0])
split_paths.remove(split_path)
elif split_path[0] in categories:
split_paths.remove(split_path)
return categories
def get_category(self, categories, path):
path_components = path.split("{")[0].strip('/').split('/')
for path_component in path_components:
if path_component in categories:
return path_component
return None
# Methods for generating each individual `Link` instance...

View File

@ -5,7 +5,7 @@ from django.test import TestCase, override_settings
from rest_framework import filters, pagination, permissions, serializers
from rest_framework.compat import coreapi
from rest_framework.decorators import detail_route
from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator
@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
def custom_action(self, request, pk):
return super(ExampleSerializer, self).retrieve(self, request)
@list_route()
def custom_list_action(self, request):
return super(ExampleViewSet, self).list(self, request)
def get_serializer(self, *args, **kwargs):
assert self.request
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
@ -88,6 +92,10 @@ class TestRouterGeneratedSchema(TestCase):
coreapi.Field('ordering', required=False, location='query')
]
),
'custom_list_action': coreapi.Link(
url='/example/custom_list_action/',
action='get'
),
'retrieve': coreapi.Link(
url='/example/{pk}/',
action='get',
@ -144,6 +152,10 @@ class TestRouterGeneratedSchema(TestCase):
coreapi.Field('d', required=False, location='form'),
]
),
'custom_list_action': coreapi.Link(
url='/example/custom_list_action/',
action='get'
),
'update': coreapi.Link(
url='/example/{pk}/',
action='put',