mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 13:54:47 +03:00
Clean up schema generation
This commit is contained in:
parent
a8b46fa013
commit
8c72112989
|
@ -31,106 +31,59 @@ def is_api_view(callback):
|
|||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
default_mapping = {
|
||||
'get': 'read',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
known_actions = (
|
||||
'create', 'read', 'retrieve', 'list',
|
||||
'update', 'partial_update', 'destroy'
|
||||
)
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
def __init__(self, title=None, url=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
{'a': {'b': {'c': 123}}}
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = {}
|
||||
target = target[key]
|
||||
target[keys[-1]] = value
|
||||
|
||||
if patterns is None and urlconf is not None:
|
||||
|
||||
class EndpointInspector(object):
|
||||
"""
|
||||
A class to determine the available API endpoints that a project exposes.
|
||||
"""
|
||||
def __init__(self, patterns=None, urlconf=None):
|
||||
if patterns is None:
|
||||
if urlconf is None:
|
||||
# Use the default Django URL conf
|
||||
urls = import_module(settings.ROOT_URLCONF)
|
||||
patterns = urls.urlpatterns
|
||||
else:
|
||||
# Load the given URLconf module
|
||||
if isinstance(urlconf, six.string_types):
|
||||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
self.patterns = urls.urlpatterns
|
||||
elif patterns is None and urlconf is None:
|
||||
urls = import_module(settings.ROOT_URLCONF)
|
||||
self.patterns = urls.urlpatterns
|
||||
else:
|
||||
patterns = urls.urlpatterns
|
||||
|
||||
self.patterns = patterns
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.title = title
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None):
|
||||
if self.endpoints is None:
|
||||
self.endpoints = self.get_api_endpoints(self.patterns)
|
||||
|
||||
links = []
|
||||
for path, method, category, action, callback in self.endpoints:
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except exceptions.APIException:
|
||||
continue
|
||||
else:
|
||||
view.request = None
|
||||
|
||||
link = self.get_link(path, method, callback, view)
|
||||
links.append((category, action, link))
|
||||
|
||||
if not links:
|
||||
return None
|
||||
|
||||
# Generate the schema content structure, eg:
|
||||
# {'users': {'list': Link()}}
|
||||
content = {}
|
||||
for category, action, link in links:
|
||||
if category is None:
|
||||
content[action] = link
|
||||
elif category in content:
|
||||
content[category][action] = link
|
||||
else:
|
||||
content[category] = {action: link}
|
||||
|
||||
# Return the schema document.
|
||||
return coreapi.Document(title=self.title, content=content, url=self.url)
|
||||
|
||||
def get_api_endpoints(self, patterns, prefix=''):
|
||||
def get_api_endpoints(self, patterns=None, prefix=''):
|
||||
"""
|
||||
Return a list of all available API endpoints by inspecting the URL conf.
|
||||
"""
|
||||
if patterns is None:
|
||||
patterns = self.patterns
|
||||
|
||||
api_endpoints = []
|
||||
|
||||
for pattern in patterns:
|
||||
path_regex = prefix + pattern.regex.pattern
|
||||
if isinstance(pattern, RegexURLPattern):
|
||||
path = self.get_path(path_regex)
|
||||
path = self.get_path_from_regex(path_regex)
|
||||
callback = pattern.callback
|
||||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
action = self.get_action(path, method, callback)
|
||||
category = self.get_category(path, method, callback, action)
|
||||
endpoint = (path, method, category, action, callback)
|
||||
endpoint = (path, method, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, RegexURLResolver):
|
||||
|
@ -142,7 +95,7 @@ class SchemaGenerator(object):
|
|||
|
||||
return api_endpoints
|
||||
|
||||
def get_path(self, path_regex):
|
||||
def get_path_from_regex(self, path_regex):
|
||||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
|
@ -177,57 +130,94 @@ class SchemaGenerator(object):
|
|||
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
|
||||
]
|
||||
|
||||
def get_action(self, path, method, callback):
|
||||
"""
|
||||
Return a descriptive action string for the endpoint, eg. 'list'.
|
||||
"""
|
||||
actions = getattr(callback, 'actions', self.default_mapping)
|
||||
return actions[method.lower()]
|
||||
|
||||
def get_category(self, path, method, callback, action):
|
||||
"""
|
||||
Return a descriptive category string for the endpoint, eg. 'users'.
|
||||
class SchemaGenerator(object):
|
||||
endpoint_inspector_cls = EndpointInspector
|
||||
|
||||
Examples of category/action pairs that should be generated for various
|
||||
endpoints:
|
||||
def __init__(self, title=None, url=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
|
||||
/users/ [users][list], [users][create]
|
||||
/users/{pk}/ [users][read], [users][update], [users][destroy]
|
||||
/users/enabled/ [users][enabled] (custom action)
|
||||
/users/{pk}/star/ [users][star] (custom action)
|
||||
/users/{pk}/groups/ [groups][list], [groups][create]
|
||||
/users/{pk}/groups/{pk}/ [groups][read], [groups][update], [groups][destroy]
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.endpoint_inspector = self.endpoint_inspector_cls(patterns, urlconf)
|
||||
self.title = title
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None):
|
||||
"""
|
||||
path_components = path.strip('/').split('/')
|
||||
path_components = [
|
||||
component for component in path_components
|
||||
if '{' not in component
|
||||
]
|
||||
if action in self.known_actions:
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
idx = -1
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
if self.endpoints is None:
|
||||
self.endpoints = self.endpoint_inspector.get_api_endpoints()
|
||||
|
||||
links = {}
|
||||
for path, method, callback in self.endpoints:
|
||||
view = self.create_view(callback, method, request)
|
||||
if not self.has_view_permissions(view):
|
||||
continue
|
||||
link = self.get_link(path, method, view)
|
||||
keys = self.get_keys(path, method, view)
|
||||
insert_into(links, keys, link)
|
||||
|
||||
if not links:
|
||||
return None
|
||||
|
||||
return coreapi.Document(title=self.title, url=self.url, content=links)
|
||||
|
||||
# Methods used when we generate a view instance from the raw callback...
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
idx = -2
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def has_view_permissions(self, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
"""
|
||||
if view.request is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
return path_components[idx]
|
||||
except IndexError:
|
||||
return None
|
||||
view.check_permissions(view.request)
|
||||
except exceptions.APIException:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Methods for generating each individual `Link` instance...
|
||||
|
||||
def get_link(self, path, method, callback, view):
|
||||
def get_link(self, path, method, view):
|
||||
"""
|
||||
Return a `coreapi.Link` instance for the given endpoint.
|
||||
"""
|
||||
fields = self.get_path_fields(path, method, callback, view)
|
||||
fields += self.get_serializer_fields(path, method, callback, view)
|
||||
fields += self.get_pagination_fields(path, method, callback, view)
|
||||
fields += self.get_filter_fields(path, method, callback, view)
|
||||
fields = self.get_path_fields(path, method, view)
|
||||
fields += self.get_serializer_fields(path, method, view)
|
||||
fields += self.get_pagination_fields(path, method, view)
|
||||
fields += self.get_filter_fields(path, method, view)
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method, callback, view)
|
||||
encoding = self.get_encoding(path, method, view)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
|
@ -241,7 +231,7 @@ class SchemaGenerator(object):
|
|||
fields=fields
|
||||
)
|
||||
|
||||
def get_encoding(self, path, method, callback, view):
|
||||
def get_encoding(self, path, method, view):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
|
@ -262,7 +252,7 @@ class SchemaGenerator(object):
|
|||
|
||||
return None
|
||||
|
||||
def get_path_fields(self, path, method, callback, view):
|
||||
def get_path_fields(self, path, method, view):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
|
@ -275,7 +265,7 @@ class SchemaGenerator(object):
|
|||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method, callback, view):
|
||||
def get_serializer_fields(self, path, method, view):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
|
@ -311,11 +301,11 @@ class SchemaGenerator(object):
|
|||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method, callback, view):
|
||||
def get_pagination_fields(self, path, method, view):
|
||||
if method != 'GET':
|
||||
return []
|
||||
|
||||
if hasattr(callback, 'actions') and ('list' not in callback.actions.values()):
|
||||
if getattr(view, 'action', 'list') != 'list':
|
||||
return []
|
||||
|
||||
if not getattr(view, 'pagination_class', None):
|
||||
|
@ -324,11 +314,11 @@ class SchemaGenerator(object):
|
|||
paginator = view.pagination_class()
|
||||
return as_query_fields(paginator.get_fields(view))
|
||||
|
||||
def get_filter_fields(self, path, method, callback, view):
|
||||
def get_filter_fields(self, path, method, view):
|
||||
if method != 'GET':
|
||||
return []
|
||||
|
||||
if hasattr(callback, 'actions') and ('list' not in callback.actions.values()):
|
||||
if getattr(view, 'action', 'list') != 'list':
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'filter_backends'):
|
||||
|
@ -338,3 +328,66 @@ class SchemaGenerator(object):
|
|||
for filter_backend in view.filter_backends:
|
||||
fields += as_query_fields(filter_backend().get_fields(view))
|
||||
return fields
|
||||
|
||||
# Methods for generating the keys which are used to layout each link.
|
||||
|
||||
default_mapping = {
|
||||
'get': 'read',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
coerce_actions = {
|
||||
'retrieve': 'read',
|
||||
'destroy': 'delete'
|
||||
}
|
||||
known_actions = set([
|
||||
'create', 'read', 'list', 'update', 'partial_update', 'delete'
|
||||
])
|
||||
|
||||
def get_keys(self, path, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "enabled") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("groups", "list"), ("groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("groups", "read"), ("groups", "update"), ("groups", "delete")
|
||||
"""
|
||||
path_components = path.strip('/').split('/')
|
||||
named_path_components = [
|
||||
component for component in path_components
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
# The default views use some naming that isn't well suited to what
|
||||
# we'd actually like for the schema representation.
|
||||
if action in self.coerce_actions:
|
||||
action = self.coerce_actions[action]
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
method = method.lower()
|
||||
if method == 'get':
|
||||
is_detail = path_components and ('{' in path_components[-1])
|
||||
action = 'read' if is_detail else 'list'
|
||||
else:
|
||||
action = self.default_mapping[method]
|
||||
|
||||
if action in self.known_actions:
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
idx = -1
|
||||
else:
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
idx = -2
|
||||
|
||||
try:
|
||||
return (named_path_components[idx], action)
|
||||
except IndexError:
|
||||
return (action,)
|
||||
|
|
|
@ -99,7 +99,7 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
url='/example/custom_list_action/',
|
||||
action='get'
|
||||
),
|
||||
'retrieve': coreapi.Link(
|
||||
'read': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
action='get',
|
||||
fields=[
|
||||
|
@ -138,7 +138,7 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
coreapi.Field('b', required=False, location='form')
|
||||
]
|
||||
),
|
||||
'retrieve': coreapi.Link(
|
||||
'read': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
action='get',
|
||||
fields=[
|
||||
|
@ -179,7 +179,7 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
coreapi.Field('b', required=False, location='form')
|
||||
]
|
||||
),
|
||||
'destroy': coreapi.Link(
|
||||
'delete': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
action='delete',
|
||||
fields=[
|
||||
|
@ -207,7 +207,7 @@ class TestSchemaGenerator(TestCase):
|
|||
action='post',
|
||||
fields=[]
|
||||
),
|
||||
'read': coreapi.Link(
|
||||
'list': coreapi.Link(
|
||||
url='/example-view/',
|
||||
action='get',
|
||||
fields=[]
|
||||
|
|
Loading…
Reference in New Issue
Block a user