Restructure SchemaGenerator for easier subclassing

Allow adding new default list actions so that bulk actions can be
included in the schema with minimal changes.
This commit is contained in:
Ivan Anishchuk 2017-07-14 02:30:59 +08:00
parent 2a1fd3b45a
commit 9bcc8591f5
No known key found for this signature in database
GPG Key ID: F5B311EE98C75AC1

View File

@ -124,28 +124,6 @@ def insert_into(target, keys, value):
target[keys[-1]] = value target[keys[-1]] = value
def is_custom_action(action):
return action not in set([
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
])
def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
if method.lower() != 'get':
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
def endpoint_ordering(endpoint): def endpoint_ordering(endpoint):
path, method, callback = endpoint path, method, callback = endpoint
method_priority = { method_priority = {
@ -265,6 +243,9 @@ class SchemaGenerator(object):
'patch': 'partial_update', 'patch': 'partial_update',
'delete': 'destroy', 'delete': 'destroy',
} }
default_list_mapping = {
'get': 'list',
}
endpoint_inspector_cls = EndpointInspector endpoint_inspector_cls = EndpointInspector
# Map the method names we use for viewset actions onto external schema names. # Map the method names we use for viewset actions onto external schema names.
@ -293,6 +274,10 @@ class SchemaGenerator(object):
self.description = description self.description = description
self.url = url self.url = url
self.endpoints = None self.endpoints = None
self.default_actions = set(
list(self.default_mapping.values()) +
list(self.default_list_mapping.values())
)
def get_schema(self, request=None, public=False): def get_schema(self, request=None, public=False):
""" """
@ -602,7 +587,9 @@ class SchemaGenerator(object):
return fields return fields
def get_pagination_fields(self, path, method, view): def get_pagination_fields(self, path, method, view):
if not is_list_view(path, method, view): if not self.is_list_view(path, method, view):
return []
if method.lower() != 'get':
return [] return []
pagination = getattr(view, 'pagination_class', None) pagination = getattr(view, 'pagination_class', None)
@ -613,7 +600,7 @@ class SchemaGenerator(object):
return paginator.get_schema_fields(view) return paginator.get_schema_fields(view)
def get_filter_fields(self, path, method, view): def get_filter_fields(self, path, method, view):
if not is_list_view(path, method, view): if not self.is_list_view(path, method, view):
return [] return []
if not getattr(view, 'filter_backends', None): if not getattr(view, 'filter_backends', None):
@ -643,8 +630,8 @@ class SchemaGenerator(object):
action = view.action action = view.action
else: else:
# Views have no associated action, so we determine one from the method. # Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view): if self.is_list_view(subpath, method, view):
action = 'list' action = self.default_list_mapping[method.lower()]
else: else:
action = self.default_mapping[method.lower()] action = self.default_mapping[method.lower()]
@ -654,7 +641,7 @@ class SchemaGenerator(object):
if '{' not in component if '{' not in component
] ]
if is_custom_action(action): if self.is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/" # Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1: if len(view.action_map) > 1:
action = self.default_mapping[method.lower()] action = self.default_mapping[method.lower()]
@ -670,6 +657,24 @@ class SchemaGenerator(object):
# Default action, eg "/users/", "/users/{pk}/" # Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action] return named_path_components + [action]
def is_custom_action(self, action):
return action not in self.default_actions
def is_list_view(self, path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action in self.default_list_mapping.values()
if method.lower() not in self.default_list_mapping:
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
class SchemaView(APIView): class SchemaView(APIView):
_ignore_model_permissions = True _ignore_model_permissions = True