From 8c721129894a77b7ed6b7eeebc049d47542a2e57 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 30 Sep 2016 11:23:36 +0100 Subject: [PATCH] Clean up schema generation --- rest_framework/schemas.py | 309 ++++++++++++++++++++++---------------- tests/test_schemas.py | 8 +- 2 files changed, 185 insertions(+), 132 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 1b899450f..fdd8a1b1e 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -31,106 +31,59 @@ def is_api_view(callback): return (cls is not None) and issubclass(cls, APIView) -class SchemaGenerator(object): - default_mapping = { - 'get': 'read', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } - known_actions = ( - 'create', 'read', 'retrieve', 'list', - 'update', 'partial_update', 'destroy' - ) +def insert_into(target, keys, value): + """ + Nested dictionary insertion. - def __init__(self, title=None, url=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' + >>> example = {} + >>> insert_into(example, ['a', 'b', 'c'], 123) + >>> example + {'a': {'b': {'c': 123}}} + """ + for key in keys[:-1]: + if key not in target: + target[key] = {} + target = target[key] + target[keys[-1]] = value - if patterns is None and urlconf is not None: - if isinstance(urlconf, six.string_types): - urls = import_module(urlconf) + +class EndpointInspector(object): + """ + A class to determine the available API endpoints that a project exposes. + """ + def __init__(self, patterns=None, urlconf=None): + if patterns is None: + if urlconf is None: + # Use the default Django URL conf + urls = import_module(settings.ROOT_URLCONF) + patterns = urls.urlpatterns else: - urls = urlconf - self.patterns = urls.urlpatterns - elif patterns is None and urlconf is None: - urls = import_module(settings.ROOT_URLCONF) - self.patterns = urls.urlpatterns - else: - self.patterns = patterns - - if url and not url.endswith('/'): - url += '/' - - self.title = title - self.url = url - self.endpoints = None - - def get_schema(self, request=None): - if self.endpoints is None: - self.endpoints = self.get_api_endpoints(self.patterns) - - links = [] - for path, method, category, action, callback in self.endpoints: - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) - view.args = () - view.kwargs = {} - view.format_kwarg = None - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' + # Load the given URLconf module + if isinstance(urlconf, six.string_types): + urls = import_module(urlconf) else: - view.action = actions.get(method.lower()) + urls = urlconf + patterns = urls.urlpatterns - if request is not None: - view.request = clone_request(request, method) - try: - view.check_permissions(view.request) - except exceptions.APIException: - continue - else: - view.request = None + self.patterns = patterns - link = self.get_link(path, method, callback, view) - links.append((category, action, link)) - - if not links: - return None - - # Generate the schema content structure, eg: - # {'users': {'list': Link()}} - content = {} - for category, action, link in links: - if category is None: - content[action] = link - elif category in content: - content[category][action] = link - else: - content[category] = {action: link} - - # Return the schema document. - return coreapi.Document(title=self.title, content=content, url=self.url) - - def get_api_endpoints(self, patterns, prefix=''): + def get_api_endpoints(self, patterns=None, prefix=''): """ Return a list of all available API endpoints by inspecting the URL conf. """ + if patterns is None: + patterns = self.patterns + api_endpoints = [] for pattern in patterns: path_regex = prefix + pattern.regex.pattern if isinstance(pattern, RegexURLPattern): - path = self.get_path(path_regex) + path = self.get_path_from_regex(path_regex) callback = pattern.callback if self.should_include_endpoint(path, callback): for method in self.get_allowed_methods(callback): - action = self.get_action(path, method, callback) - category = self.get_category(path, method, callback, action) - endpoint = (path, method, category, action, callback) + endpoint = (path, method, callback) api_endpoints.append(endpoint) elif isinstance(pattern, RegexURLResolver): @@ -142,7 +95,7 @@ class SchemaGenerator(object): return api_endpoints - def get_path(self, path_regex): + def get_path_from_regex(self, path_regex): """ Given a URL conf regex, return a URI template string. """ @@ -177,57 +130,94 @@ class SchemaGenerator(object): callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') ] - def get_action(self, path, method, callback): - """ - Return a descriptive action string for the endpoint, eg. 'list'. - """ - actions = getattr(callback, 'actions', self.default_mapping) - return actions[method.lower()] - def get_category(self, path, method, callback, action): - """ - Return a descriptive category string for the endpoint, eg. 'users'. +class SchemaGenerator(object): + endpoint_inspector_cls = EndpointInspector - Examples of category/action pairs that should be generated for various - endpoints: + def __init__(self, title=None, url=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' - /users/ [users][list], [users][create] - /users/{pk}/ [users][read], [users][update], [users][destroy] - /users/enabled/ [users][enabled] (custom action) - /users/{pk}/star/ [users][star] (custom action) - /users/{pk}/groups/ [groups][list], [groups][create] - /users/{pk}/groups/{pk}/ [groups][read], [groups][update], [groups][destroy] + if url and not url.endswith('/'): + url += '/' + + self.endpoint_inspector = self.endpoint_inspector_cls(patterns, urlconf) + self.title = title + self.url = url + self.endpoints = None + + def get_schema(self, request=None): """ - path_components = path.strip('/').split('/') - path_components = [ - component for component in path_components - if '{' not in component - ] - if action in self.known_actions: - # Default action, eg "/users/", "/users/{pk}/" - idx = -1 - else: - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - idx = -2 + Generate a `coreapi.Document` representing the API schema. + """ + if self.endpoints is None: + self.endpoints = self.endpoint_inspector.get_api_endpoints() + + links = {} + for path, method, callback in self.endpoints: + view = self.create_view(callback, method, request) + if not self.has_view_permissions(view): + continue + link = self.get_link(path, method, view) + keys = self.get_keys(path, method, view) + insert_into(links, keys, link) + + if not links: + return None + + return coreapi.Document(title=self.title, url=self.url, content=links) + + # Methods used when we generate a view instance from the raw callback... + + def create_view(self, callback, method, request=None): + """ + Given a callback, return an actual view instance. + """ + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).items(): + setattr(view, attr, val) + view.args = () + view.kwargs = {} + view.format_kwarg = None + view.request = None + + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + + if request is not None: + view.request = clone_request(request, method) + + return view + + def has_view_permissions(self, view): + """ + Return `True` if the incoming request has the correct view permissions. + """ + if view.request is None: + return True try: - return path_components[idx] - except IndexError: - return None + view.check_permissions(view.request) + except exceptions.APIException: + return False + return True # Methods for generating each individual `Link` instance... - def get_link(self, path, method, callback, view): + def get_link(self, path, method, view): """ Return a `coreapi.Link` instance for the given endpoint. """ - fields = self.get_path_fields(path, method, callback, view) - fields += self.get_serializer_fields(path, method, callback, view) - fields += self.get_pagination_fields(path, method, callback, view) - fields += self.get_filter_fields(path, method, callback, view) + fields = self.get_path_fields(path, method, view) + fields += self.get_serializer_fields(path, method, view) + fields += self.get_pagination_fields(path, method, view) + fields += self.get_filter_fields(path, method, view) if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method, callback, view) + encoding = self.get_encoding(path, method, view) else: encoding = None @@ -241,7 +231,7 @@ class SchemaGenerator(object): fields=fields ) - def get_encoding(self, path, method, callback, view): + def get_encoding(self, path, method, view): """ Return the 'encoding' parameter to use for a given endpoint. """ @@ -262,7 +252,7 @@ class SchemaGenerator(object): return None - def get_path_fields(self, path, method, callback, view): + def get_path_fields(self, path, method, view): """ Return a list of `coreapi.Field` instances corresponding to any templated path variables. @@ -275,7 +265,7 @@ class SchemaGenerator(object): return fields - def get_serializer_fields(self, path, method, callback, view): + def get_serializer_fields(self, path, method, view): """ Return a list of `coreapi.Field` instances corresponding to any request body input, as determined by the serializer class. @@ -311,11 +301,11 @@ class SchemaGenerator(object): return fields - def get_pagination_fields(self, path, method, callback, view): + def get_pagination_fields(self, path, method, view): if method != 'GET': return [] - if hasattr(callback, 'actions') and ('list' not in callback.actions.values()): + if getattr(view, 'action', 'list') != 'list': return [] if not getattr(view, 'pagination_class', None): @@ -324,11 +314,11 @@ class SchemaGenerator(object): paginator = view.pagination_class() return as_query_fields(paginator.get_fields(view)) - def get_filter_fields(self, path, method, callback, view): + def get_filter_fields(self, path, method, view): if method != 'GET': return [] - if hasattr(callback, 'actions') and ('list' not in callback.actions.values()): + if getattr(view, 'action', 'list') != 'list': return [] if not hasattr(view, 'filter_backends'): @@ -338,3 +328,66 @@ class SchemaGenerator(object): for filter_backend in view.filter_backends: fields += as_query_fields(filter_backend().get_fields(view)) return fields + + # Methods for generating the keys which are used to layout each link. + + default_mapping = { + 'get': 'read', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + coerce_actions = { + 'retrieve': 'read', + 'destroy': 'delete' + } + known_actions = set([ + 'create', 'read', 'list', 'update', 'partial_update', 'delete' + ]) + + def get_keys(self, path, method, view): + """ + Return a list of keys that should be used to layout a link within + the schema document. + + /users/ ("users", "list"), ("users", "create") + /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") + /users/enabled/ ("users", "enabled") # custom viewset list action + /users/{pk}/star/ ("users", "enabled") # custom viewset detail action + /users/{pk}/groups/ ("groups", "list"), ("groups", "create") + /users/{pk}/groups/{pk}/ ("groups", "read"), ("groups", "update"), ("groups", "delete") + """ + path_components = path.strip('/').split('/') + named_path_components = [ + component for component in path_components + if '{' not in component + ] + + if hasattr(view, 'action'): + # Viewsets have explicitly named actions. + action = view.action + # The default views use some naming that isn't well suited to what + # we'd actually like for the schema representation. + if action in self.coerce_actions: + action = self.coerce_actions[action] + else: + # Views have no associated action, so we determine one from the method. + method = method.lower() + if method == 'get': + is_detail = path_components and ('{' in path_components[-1]) + action = 'read' if is_detail else 'list' + else: + action = self.default_mapping[method] + + if action in self.known_actions: + # Default action, eg "/users/", "/users/{pk}/" + idx = -1 + else: + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + idx = -2 + + try: + return (named_path_components[idx], action) + except IndexError: + return (action,) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 05c388c08..5d2b58a86 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -99,7 +99,7 @@ class TestRouterGeneratedSchema(TestCase): url='/example/custom_list_action/', action='get' ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -138,7 +138,7 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('b', required=False, location='form') ] ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -179,7 +179,7 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('b', required=False, location='form') ] ), - 'destroy': coreapi.Link( + 'delete': coreapi.Link( url='/example/{pk}/', action='delete', fields=[ @@ -207,7 +207,7 @@ class TestSchemaGenerator(TestCase): action='post', fields=[] ), - 'read': coreapi.Link( + 'list': coreapi.Link( url='/example-view/', action='get', fields=[]