diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py index 926e8db39..55e27ea8f 100644 --- a/rest_framework/management/commands/generateschema.py +++ b/rest_framework/management/commands/generateschema.py @@ -1,7 +1,7 @@ from django.core.management.base import BaseCommand from rest_framework.compat import yaml -from rest_framework.schemas.generators import OpenAPISchemaGenerator +from rest_framework.schemas.openapi import SchemaGenerator from rest_framework.utils import json @@ -15,7 +15,7 @@ class Command(BaseCommand): parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str) def handle(self, *args, **options): - generator = OpenAPISchemaGenerator( + generator = SchemaGenerator( url=options['url'], title=options['title'], description=options['description'] diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index 322a2cedd..b1f37987a 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -22,9 +22,8 @@ Other access should target the submodules directly """ from rest_framework.settings import api_settings -from .generators import SchemaGenerator from .inspectors import DefaultSchema # noqa -from .coreapi import AutoSchema, ManualSchema # noqa +from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa def get_schema_view( diff --git a/rest_framework/schemas/coreapi.py b/rest_framework/schemas/coreapi.py index ed58b589a..895ea0efd 100644 --- a/rest_framework/schemas/coreapi.py +++ b/rest_framework/schemas/coreapi.py @@ -1,6 +1,6 @@ import re import warnings -from collections import OrderedDict +from collections import Counter, OrderedDict from django.db import models from django.utils.encoding import force_text, smart_text @@ -11,6 +11,7 @@ from rest_framework.compat import coreapi, coreschema, uritemplate from rest_framework.settings import api_settings from rest_framework.utils import formatting +from .generators import BaseSchemaGenerator from .inspectors import ViewInspector from .utils import get_pk_description, is_list_view @@ -18,6 +19,198 @@ from .utils import get_pk_description, is_list_view # TODO: ???: move up to base. header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') +# Generator # +# TODO: Pull some of this into base. + + +def is_custom_action(action): + return action not in { + 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' + } + + +def distribute_links(obj): + for key, value in obj.items(): + distribute_links(value) + + for preferred_key, link in obj.links: + key = obj.get_available_key(preferred_key) + obj[key] = link + + +INSERT_INTO_COLLISION_FMT = """ +Schema Naming Collision. + +coreapi.Link for URL path {value_url} cannot be inserted into schema. +Position conflicts with coreapi.Link for URL path {target_url}. + +Attempted to insert link with keys: {keys}. + +Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()` +to customise schema structure. +""" + + +class LinkNode(OrderedDict): + def __init__(self): + self.links = [] + self.methods_counter = Counter() + super(LinkNode, self).__init__() + + def get_available_key(self, preferred_key): + if preferred_key not in self: + return preferred_key + + while True: + current_val = self.methods_counter[preferred_key] + self.methods_counter[preferred_key] += 1 + + key = '{}_{}'.format(preferred_key, current_val) + if key not in self: + return key + + +def insert_into(target, keys, value): + """ + Nested dictionary insertion. + + >>> example = {} + >>> insert_into(example, ['a', 'b', 'c'], 123) + >>> example + LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}}))) + """ + for key in keys[:-1]: + if key not in target: + target[key] = LinkNode() + target = target[key] + + try: + target.links.append((keys[-1], value)) + except TypeError: + msg = INSERT_INTO_COLLISION_FMT.format( + value_url=value.url, + target_url=target.url, + keys=keys + ) + raise ValueError(msg) + + +class SchemaGenerator(BaseSchemaGenerator): + """ + Original CoreAPI version. + """ + # Map HTTP methods onto actions. + default_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + assert coreschema, '`coreschema` must be installed for schema support.' + + super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf) + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + + def get_links(self, request=None): + """ + Return a dictionary containing all the links that should be + included in the API schema. + """ + links = LinkNode() + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + link = view.schema.get_link(path, method, base_url=self.url) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) + insert_into(links, keys, link) + + return links + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + self._initialise_endpoints() + + links = self.get_links(None if public else request) + if not links: + return None + + url = self.url + if not url and request is not None: + url = request.build_absolute_uri() + + distribute_links(links) + return coreapi.Document( + title=self.title, description=self.description, + url=url, content=links + ) + + # Method for generating the link layout.... + def get_keys(self, subpath, method, view): + """ + Return a list of keys that should be used to layout a link within + the schema document. + + /users/ ("users", "list"), ("users", "create") + /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") + /users/enabled/ ("users", "enabled") # custom viewset list action + /users/{pk}/star/ ("users", "star") # custom viewset detail action + /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") + /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") + """ + if hasattr(view, 'action'): + # Viewsets have explicitly named actions. + action = view.action + else: + # Views have no associated action, so we determine one from the method. + if is_list_view(subpath, method, view): + action = 'list' + else: + action = self.default_mapping[method.lower()] + + named_path_components = [ + component for component + in subpath.strip('/').split('/') + if '{' not in component + ] + + if is_custom_action(action): + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + if len(view.action_map) > 1: + action = self.default_mapping[method.lower()] + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + return named_path_components + [action] + else: + return named_path_components[:-1] + [action] + + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + + # Default action, eg "/users/", "/users/{pk}/" + return named_path_components + [action] + +# View Inspectors # + def field_to_schema(field): title = force_text(field.label) if field.label else '' diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 57cf91d16..3455ad1e0 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -4,7 +4,6 @@ generators.py # Top-down schema generation See schemas.__init__.py for package overview. """ import re -from collections import Counter, OrderedDict from importlib import import_module from django.conf import settings @@ -14,15 +13,11 @@ from django.http import Http404 from django.utils import six from rest_framework import exceptions -from rest_framework.compat import ( - URLPattern, URLResolver, coreapi, coreschema, get_original_route -) +from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.request import clone_request from rest_framework.settings import api_settings from rest_framework.utils.model_meta import _get_pk -from .utils import is_list_view - def common_path(paths): split_paths = [path.strip('/').split('/') for path in paths] @@ -51,78 +46,6 @@ def is_api_view(callback): return (cls is not None) and issubclass(cls, APIView) -INSERT_INTO_COLLISION_FMT = """ -Schema Naming Collision. - -coreapi.Link for URL path {value_url} cannot be inserted into schema. -Position conflicts with coreapi.Link for URL path {target_url}. - -Attempted to insert link with keys: {keys}. - -Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()` -to customise schema structure. -""" - - -class LinkNode(OrderedDict): - def __init__(self): - self.links = [] - self.methods_counter = Counter() - super(LinkNode, self).__init__() - - def get_available_key(self, preferred_key): - if preferred_key not in self: - return preferred_key - - while True: - current_val = self.methods_counter[preferred_key] - self.methods_counter[preferred_key] += 1 - - key = '{}_{}'.format(preferred_key, current_val) - if key not in self: - return key - - -def insert_into(target, keys, value): - """ - Nested dictionary insertion. - - >>> example = {} - >>> insert_into(example, ['a', 'b', 'c'], 123) - >>> example - LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}}))) - """ - for key in keys[:-1]: - if key not in target: - target[key] = LinkNode() - target = target[key] - - try: - target.links.append((keys[-1], value)) - except TypeError: - msg = INSERT_INTO_COLLISION_FMT.format( - value_url=value.url, - target_url=target.url, - keys=keys - ) - raise ValueError(msg) - - -def distribute_links(obj): - for key, value in obj.items(): - distribute_links(value) - - for preferred_key, link in obj.links: - key = obj.get_available_key(preferred_key) - obj[key] = link - - -def is_custom_action(action): - return action not in { - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' - } - - def endpoint_ordering(endpoint): path, method, callback = endpoint method_priority = { @@ -360,170 +283,3 @@ class BaseSchemaGenerator(object): except (exceptions.APIException, Http404, PermissionDenied): return False return True - - -class SchemaGenerator(BaseSchemaGenerator): - """ - Original CoreAPI version. - """ - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } - - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - - def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' - - super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf) - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES - - def get_links(self, request=None): - """ - Return a dictionary containing all the links that should be - included in the API schema. - """ - links = LinkNode() - - paths, view_endpoints = self._get_paths_and_endpoints(request) - - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) - - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) - - return links - - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ - self._initialise_endpoints() - - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - distribute_links(links) - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - - # Method for generating the link layout.... - def get_keys(self, subpath, method, view): - """ - Return a list of keys that should be used to layout a link within - the schema document. - - /users/ ("users", "list"), ("users", "create") - /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") - /users/enabled/ ("users", "enabled") # custom viewset list action - /users/{pk}/star/ ("users", "star") # custom viewset detail action - /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") - /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") - """ - if hasattr(view, 'action'): - # Viewsets have explicitly named actions. - action = view.action - else: - # Views have no associated action, so we determine one from the method. - if is_list_view(subpath, method, view): - action = 'list' - else: - action = self.default_mapping[method.lower()] - - named_path_components = [ - component for component - in subpath.strip('/').split('/') - if '{' not in component - ] - - if is_custom_action(action): - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - if len(view.action_map) > 1: - action = self.default_mapping[method.lower()] - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - return named_path_components + [action] - else: - return named_path_components[:-1] + [action] - - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - - # Default action, eg "/users/", "/users/{pk}/" - return named_path_components + [action] - - -class OpenAPISchemaGenerator(BaseSchemaGenerator): - - def get_info(self): - info = { - 'title': self.title, - 'version': 'TODO', - } - - if self.description is not None: - info['description'] = self.description - - return info - - def get_paths(self, request=None): - result = {} - - paths, view_endpoints = self._get_paths_and_endpoints(request) - - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) - - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - operation = view.schema.get_operation(path, method) - subpath = '/' + path[len(prefix):] - result.setdefault(subpath, {}) - result[subpath][method.lower()] = operation - - return result - - def get_schema(self, request=None, public=False): - """ - Generate a OpenAPI schema. - """ - self._initialise_endpoints() - - paths = self.get_paths(None if public else request) - if not paths: - return None - - schema = { - 'openapi': '3.0.2', - 'info': self.get_info(), - 'paths': paths, - } - - return schema diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 10a2ce5d0..94de3bba2 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -6,9 +6,66 @@ from django.utils.encoding import force_text from rest_framework import exceptions, serializers from rest_framework.compat import uritemplate +from .generators import BaseSchemaGenerator from .inspectors import ViewInspector from .utils import get_pk_description, is_list_view +# Generator + + +class SchemaGenerator(BaseSchemaGenerator): + + def get_info(self): + info = { + 'title': self.title, + 'version': 'TODO', + } + + if self.description is not None: + info['description'] = self.description + + return info + + def get_paths(self, request=None): + result = {} + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + operation = view.schema.get_operation(path, method) + subpath = '/' + path[len(prefix):] + result.setdefault(subpath, {}) + result[subpath][method.lower()] = operation + + return result + + def get_schema(self, request=None, public=False): + """ + Generate a OpenAPI schema. + """ + self._initialise_endpoints() + + paths = self.get_paths(None if public else request) + if not paths: + return None + + schema = { + 'openapi': '3.0.2', + 'info': self.get_info(), + 'paths': paths, + } + + return schema + +# View Inspectors + class AutoSchema(ViewInspector): diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 69b3bb6c9..6d98d919e 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -5,8 +5,7 @@ from django.test import RequestFactory, TestCase, override_settings from rest_framework import filters, generics, pagination, serializers from rest_framework.compat import uritemplate from rest_framework.request import Request -from rest_framework.schemas.generators import OpenAPISchemaGenerator -from rest_framework.schemas.openapi import AutoSchema +from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator from . import views @@ -18,7 +17,7 @@ def create_request(path): def create_view(view_cls, method, request): - generator = OpenAPISchemaGenerator() + generator = SchemaGenerator() view = generator.create_view(view_cls.as_view(), method, request) return view @@ -144,7 +143,7 @@ class TestGenerator(TestCase): patterns = [ url(r'^example/?$', views.ExampleListView.as_view()), ] - generator = OpenAPISchemaGenerator(patterns=patterns) + generator = SchemaGenerator(patterns=patterns) generator._initialise_endpoints() paths = generator.get_paths() @@ -160,7 +159,7 @@ class TestGenerator(TestCase): patterns = [ url(r'^example/?$', views.ExampleListView.as_view()), ] - generator = OpenAPISchemaGenerator(patterns=patterns) + generator = SchemaGenerator(patterns=patterns) request = create_request('/') schema = generator.get_schema(request=request)