mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-09 08:00:52 +03:00
Fix schema categories for custom list actions (#4386)
This commit is contained in:
parent
b50d8950ee
commit
01b498ec51
|
@ -30,24 +30,6 @@ def is_api_view(callback):
|
||||||
return (cls is not None) and issubclass(cls, APIView)
|
return (cls is not None) and issubclass(cls, APIView)
|
||||||
|
|
||||||
|
|
||||||
def insert_into(target, keys, item):
|
|
||||||
"""
|
|
||||||
Insert `item` into the nested dictionary `target`.
|
|
||||||
|
|
||||||
For example:
|
|
||||||
|
|
||||||
target = {}
|
|
||||||
insert_into(target, ('users', 'list'), Link(...))
|
|
||||||
insert_into(target, ('users', 'detail'), Link(...))
|
|
||||||
assert target == {'users': {'list': Link(...), 'detail': Link(...)}}
|
|
||||||
"""
|
|
||||||
for key in keys[:1]:
|
|
||||||
if key not in target:
|
|
||||||
target[key] = {}
|
|
||||||
target = target[key]
|
|
||||||
target[keys[-1]] = item
|
|
||||||
|
|
||||||
|
|
||||||
class SchemaGenerator(object):
|
class SchemaGenerator(object):
|
||||||
default_mapping = {
|
default_mapping = {
|
||||||
'get': 'read',
|
'get': 'read',
|
||||||
|
@ -84,7 +66,7 @@ class SchemaGenerator(object):
|
||||||
self.endpoints = self.get_api_endpoints(self.patterns)
|
self.endpoints = self.get_api_endpoints(self.patterns)
|
||||||
|
|
||||||
links = []
|
links = []
|
||||||
for key, path, method, callback in self.endpoints:
|
for path, method, category, action, callback in self.endpoints:
|
||||||
view = callback.cls()
|
view = callback.cls()
|
||||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||||
setattr(view, attr, val)
|
setattr(view, attr, val)
|
||||||
|
@ -102,16 +84,21 @@ class SchemaGenerator(object):
|
||||||
view.request = None
|
view.request = None
|
||||||
|
|
||||||
link = self.get_link(path, method, callback, view)
|
link = self.get_link(path, method, callback, view)
|
||||||
links.append((key, link))
|
links.append((category, action, link))
|
||||||
|
|
||||||
if not link:
|
if not links:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Generate the schema content structure, from the endpoints.
|
# Generate the schema content structure, eg:
|
||||||
# ('users', 'list'), Link -> {'users': {'list': Link()}}
|
# {'users': {'list': Link()}}
|
||||||
content = {}
|
content = {}
|
||||||
for key, link in links:
|
for category, action, link in links:
|
||||||
insert_into(content, key, link)
|
if category is None:
|
||||||
|
content[action] = link
|
||||||
|
elif category in content:
|
||||||
|
content[category][action] = link
|
||||||
|
else:
|
||||||
|
content[category] = {action: link}
|
||||||
|
|
||||||
# Return the schema document.
|
# Return the schema document.
|
||||||
return coreapi.Document(title=self.title, content=content, url=self.url)
|
return coreapi.Document(title=self.title, content=content, url=self.url)
|
||||||
|
@ -129,8 +116,8 @@ class SchemaGenerator(object):
|
||||||
callback = pattern.callback
|
callback = pattern.callback
|
||||||
if self.should_include_endpoint(path, callback):
|
if self.should_include_endpoint(path, callback):
|
||||||
for method in self.get_allowed_methods(callback):
|
for method in self.get_allowed_methods(callback):
|
||||||
key = self.get_key(path, method, callback)
|
action = self.get_action(path, method, callback)
|
||||||
endpoint = (key, path, method, callback)
|
endpoint = (path, method, action, callback)
|
||||||
api_endpoints.append(endpoint)
|
api_endpoints.append(endpoint)
|
||||||
|
|
||||||
elif isinstance(pattern, RegexURLResolver):
|
elif isinstance(pattern, RegexURLResolver):
|
||||||
|
@ -140,7 +127,21 @@ class SchemaGenerator(object):
|
||||||
)
|
)
|
||||||
api_endpoints.extend(nested_endpoints)
|
api_endpoints.extend(nested_endpoints)
|
||||||
|
|
||||||
return api_endpoints
|
return self.add_categories(api_endpoints)
|
||||||
|
|
||||||
|
def add_categories(self, api_endpoints):
|
||||||
|
"""
|
||||||
|
(path, method, action, callback) -> (path, method, category, action, callback)
|
||||||
|
"""
|
||||||
|
# Determine the top level categories for the schema content,
|
||||||
|
# based on the URLs of the endpoints. Eg `set(['users', 'organisations'])`
|
||||||
|
paths = [endpoint[0] for endpoint in api_endpoints]
|
||||||
|
categories = self.get_categories(paths)
|
||||||
|
|
||||||
|
return [
|
||||||
|
(path, method, self.get_category(categories, path), action, callback)
|
||||||
|
for (path, method, action, callback) in api_endpoints
|
||||||
|
]
|
||||||
|
|
||||||
def get_path(self, path_regex):
|
def get_path(self, path_regex):
|
||||||
"""
|
"""
|
||||||
|
@ -177,23 +178,38 @@ class SchemaGenerator(object):
|
||||||
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
|
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_key(self, path, method, callback):
|
def get_action(self, path, method, callback):
|
||||||
"""
|
"""
|
||||||
Return a tuple of strings, indicating the identity to use for a
|
Return a description action string for the endpoint, eg. 'list'.
|
||||||
given endpoint. eg. ('users', 'list').
|
|
||||||
"""
|
"""
|
||||||
category = None
|
|
||||||
for item in path.strip('/').split('/'):
|
|
||||||
if '{' in item:
|
|
||||||
break
|
|
||||||
category = item
|
|
||||||
|
|
||||||
actions = getattr(callback, 'actions', self.default_mapping)
|
actions = getattr(callback, 'actions', self.default_mapping)
|
||||||
action = actions[method.lower()]
|
return actions[method.lower()]
|
||||||
|
|
||||||
if category:
|
def get_categories(self, paths):
|
||||||
return (category, action)
|
categories = set()
|
||||||
return (action,)
|
split_paths = set([
|
||||||
|
tuple(path.split("{")[0].strip('/').split('/'))
|
||||||
|
for path in paths
|
||||||
|
])
|
||||||
|
|
||||||
|
while split_paths:
|
||||||
|
for split_path in list(split_paths):
|
||||||
|
if len(split_path) == 0:
|
||||||
|
split_paths.remove(split_path)
|
||||||
|
elif len(split_path) == 1:
|
||||||
|
categories.add(split_path[0])
|
||||||
|
split_paths.remove(split_path)
|
||||||
|
elif split_path[0] in categories:
|
||||||
|
split_paths.remove(split_path)
|
||||||
|
|
||||||
|
return categories
|
||||||
|
|
||||||
|
def get_category(self, categories, path):
|
||||||
|
path_components = path.split("{")[0].strip('/').split('/')
|
||||||
|
for path_component in path_components:
|
||||||
|
if path_component in categories:
|
||||||
|
return path_component
|
||||||
|
return None
|
||||||
|
|
||||||
# Methods for generating each individual `Link` instance...
|
# Methods for generating each individual `Link` instance...
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from django.test import TestCase, override_settings
|
||||||
|
|
||||||
from rest_framework import filters, pagination, permissions, serializers
|
from rest_framework import filters, pagination, permissions, serializers
|
||||||
from rest_framework.compat import coreapi
|
from rest_framework.compat import coreapi
|
||||||
from rest_framework.decorators import detail_route
|
from rest_framework.decorators import detail_route, list_route
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.routers import DefaultRouter
|
from rest_framework.routers import DefaultRouter
|
||||||
from rest_framework.schemas import SchemaGenerator
|
from rest_framework.schemas import SchemaGenerator
|
||||||
|
@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
|
||||||
def custom_action(self, request, pk):
|
def custom_action(self, request, pk):
|
||||||
return super(ExampleSerializer, self).retrieve(self, request)
|
return super(ExampleSerializer, self).retrieve(self, request)
|
||||||
|
|
||||||
|
@list_route()
|
||||||
|
def custom_list_action(self, request):
|
||||||
|
return super(ExampleViewSet, self).list(self, request)
|
||||||
|
|
||||||
def get_serializer(self, *args, **kwargs):
|
def get_serializer(self, *args, **kwargs):
|
||||||
assert self.request
|
assert self.request
|
||||||
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
||||||
|
@ -88,6 +92,10 @@ class TestRouterGeneratedSchema(TestCase):
|
||||||
coreapi.Field('ordering', required=False, location='query')
|
coreapi.Field('ordering', required=False, location='query')
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
'custom_list_action': coreapi.Link(
|
||||||
|
url='/example/custom_list_action/',
|
||||||
|
action='get'
|
||||||
|
),
|
||||||
'retrieve': coreapi.Link(
|
'retrieve': coreapi.Link(
|
||||||
url='/example/{pk}/',
|
url='/example/{pk}/',
|
||||||
action='get',
|
action='get',
|
||||||
|
@ -144,6 +152,10 @@ class TestRouterGeneratedSchema(TestCase):
|
||||||
coreapi.Field('d', required=False, location='form'),
|
coreapi.Field('d', required=False, location='form'),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
'custom_list_action': coreapi.Link(
|
||||||
|
url='/example/custom_list_action/',
|
||||||
|
action='get'
|
||||||
|
),
|
||||||
'update': coreapi.Link(
|
'update': coreapi.Link(
|
||||||
url='/example/{pk}/',
|
url='/example/{pk}/',
|
||||||
action='put',
|
action='put',
|
||||||
|
|
Loading…
Reference in New Issue
Block a user