mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-03 13:14:30 +03:00
Fix schema categories for custom list actions (#4386)
This commit is contained in:
parent
b50d8950ee
commit
01b498ec51
|
@ -30,24 +30,6 @@ def is_api_view(callback):
|
|||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
def insert_into(target, keys, item):
|
||||
"""
|
||||
Insert `item` into the nested dictionary `target`.
|
||||
|
||||
For example:
|
||||
|
||||
target = {}
|
||||
insert_into(target, ('users', 'list'), Link(...))
|
||||
insert_into(target, ('users', 'detail'), Link(...))
|
||||
assert target == {'users': {'list': Link(...), 'detail': Link(...)}}
|
||||
"""
|
||||
for key in keys[:1]:
|
||||
if key not in target:
|
||||
target[key] = {}
|
||||
target = target[key]
|
||||
target[keys[-1]] = item
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
default_mapping = {
|
||||
'get': 'read',
|
||||
|
@ -84,7 +66,7 @@ class SchemaGenerator(object):
|
|||
self.endpoints = self.get_api_endpoints(self.patterns)
|
||||
|
||||
links = []
|
||||
for key, path, method, callback in self.endpoints:
|
||||
for path, method, category, action, callback in self.endpoints:
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
|
@ -102,16 +84,21 @@ class SchemaGenerator(object):
|
|||
view.request = None
|
||||
|
||||
link = self.get_link(path, method, callback, view)
|
||||
links.append((key, link))
|
||||
links.append((category, action, link))
|
||||
|
||||
if not link:
|
||||
if not links:
|
||||
return None
|
||||
|
||||
# Generate the schema content structure, from the endpoints.
|
||||
# ('users', 'list'), Link -> {'users': {'list': Link()}}
|
||||
# Generate the schema content structure, eg:
|
||||
# {'users': {'list': Link()}}
|
||||
content = {}
|
||||
for key, link in links:
|
||||
insert_into(content, key, link)
|
||||
for category, action, link in links:
|
||||
if category is None:
|
||||
content[action] = link
|
||||
elif category in content:
|
||||
content[category][action] = link
|
||||
else:
|
||||
content[category] = {action: link}
|
||||
|
||||
# Return the schema document.
|
||||
return coreapi.Document(title=self.title, content=content, url=self.url)
|
||||
|
@ -129,8 +116,8 @@ class SchemaGenerator(object):
|
|||
callback = pattern.callback
|
||||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
key = self.get_key(path, method, callback)
|
||||
endpoint = (key, path, method, callback)
|
||||
action = self.get_action(path, method, callback)
|
||||
endpoint = (path, method, action, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, RegexURLResolver):
|
||||
|
@ -140,7 +127,21 @@ class SchemaGenerator(object):
|
|||
)
|
||||
api_endpoints.extend(nested_endpoints)
|
||||
|
||||
return api_endpoints
|
||||
return self.add_categories(api_endpoints)
|
||||
|
||||
def add_categories(self, api_endpoints):
|
||||
"""
|
||||
(path, method, action, callback) -> (path, method, category, action, callback)
|
||||
"""
|
||||
# Determine the top level categories for the schema content,
|
||||
# based on the URLs of the endpoints. Eg `set(['users', 'organisations'])`
|
||||
paths = [endpoint[0] for endpoint in api_endpoints]
|
||||
categories = self.get_categories(paths)
|
||||
|
||||
return [
|
||||
(path, method, self.get_category(categories, path), action, callback)
|
||||
for (path, method, action, callback) in api_endpoints
|
||||
]
|
||||
|
||||
def get_path(self, path_regex):
|
||||
"""
|
||||
|
@ -177,23 +178,38 @@ class SchemaGenerator(object):
|
|||
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
|
||||
]
|
||||
|
||||
def get_key(self, path, method, callback):
|
||||
def get_action(self, path, method, callback):
|
||||
"""
|
||||
Return a tuple of strings, indicating the identity to use for a
|
||||
given endpoint. eg. ('users', 'list').
|
||||
Return a description action string for the endpoint, eg. 'list'.
|
||||
"""
|
||||
category = None
|
||||
for item in path.strip('/').split('/'):
|
||||
if '{' in item:
|
||||
break
|
||||
category = item
|
||||
|
||||
actions = getattr(callback, 'actions', self.default_mapping)
|
||||
action = actions[method.lower()]
|
||||
return actions[method.lower()]
|
||||
|
||||
if category:
|
||||
return (category, action)
|
||||
return (action,)
|
||||
def get_categories(self, paths):
|
||||
categories = set()
|
||||
split_paths = set([
|
||||
tuple(path.split("{")[0].strip('/').split('/'))
|
||||
for path in paths
|
||||
])
|
||||
|
||||
while split_paths:
|
||||
for split_path in list(split_paths):
|
||||
if len(split_path) == 0:
|
||||
split_paths.remove(split_path)
|
||||
elif len(split_path) == 1:
|
||||
categories.add(split_path[0])
|
||||
split_paths.remove(split_path)
|
||||
elif split_path[0] in categories:
|
||||
split_paths.remove(split_path)
|
||||
|
||||
return categories
|
||||
|
||||
def get_category(self, categories, path):
|
||||
path_components = path.split("{")[0].strip('/').split('/')
|
||||
for path_component in path_components:
|
||||
if path_component in categories:
|
||||
return path_component
|
||||
return None
|
||||
|
||||
# Methods for generating each individual `Link` instance...
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from django.test import TestCase, override_settings
|
|||
|
||||
from rest_framework import filters, pagination, permissions, serializers
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.decorators import detail_route
|
||||
from rest_framework.decorators import detail_route, list_route
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
|
@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
|
|||
def custom_action(self, request, pk):
|
||||
return super(ExampleSerializer, self).retrieve(self, request)
|
||||
|
||||
@list_route()
|
||||
def custom_list_action(self, request):
|
||||
return super(ExampleViewSet, self).list(self, request)
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
assert self.request
|
||||
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
||||
|
@ -88,6 +92,10 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
coreapi.Field('ordering', required=False, location='query')
|
||||
]
|
||||
),
|
||||
'custom_list_action': coreapi.Link(
|
||||
url='/example/custom_list_action/',
|
||||
action='get'
|
||||
),
|
||||
'retrieve': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
action='get',
|
||||
|
@ -144,6 +152,10 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
coreapi.Field('d', required=False, location='form'),
|
||||
]
|
||||
),
|
||||
'custom_list_action': coreapi.Link(
|
||||
url='/example/custom_list_action/',
|
||||
action='get'
|
||||
),
|
||||
'update': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
action='put',
|
||||
|
|
Loading…
Reference in New Issue
Block a user