diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 18693bed0..383785020 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -241,35 +241,18 @@ class EndpointEnumerator(object): return [method for method in methods if method not in ('OPTIONS', 'HEAD')] -class SchemaGenerator(object): - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } +class BaseSchemaGenerator(object): endpoint_inspector_cls = EndpointEnumerator - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - # 'pk' isn't great as an externally exposed name for an identifier, # so by default we prefer to use the actual model field name for schemas. # Set by 'SCHEMA_COERCE_PATH_PK'. coerce_path_pk = None def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' - if url and not url.endswith('/'): url += '/' - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.patterns = patterns @@ -279,55 +262,11 @@ class SchemaGenerator(object): self.url = url self.endpoints = None - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ - self._initialise_endpoints() - - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - distribute_links(links) - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - def _initialise_endpoints(self): if self.endpoints is None: inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) self.endpoints = inspector.get_api_endpoints() - def get_links(self, request=None): - """ - Return a dictionary containing all the links that should be - included in the API schema. - """ - links = LinkNode() - - paths, view_endpoints = self._get_paths_and_endpoints(request) - - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) - - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) - - return links - def _get_paths_and_endpoints(self, request): """ Generate (path, method, view) given (path, method, callback) for paths. @@ -342,7 +281,46 @@ class SchemaGenerator(object): return paths, view_endpoints - # Methods used when we generate a view instance from the raw callback... + def create_view(self, callback, method, request=None): + """ + Given a callback, return an actual view instance. + """ + view = callback.cls(**getattr(callback, 'initkwargs', {})) + view.args = () + view.kwargs = {} + view.format_kwarg = None + view.request = None + view.action_map = getattr(callback, 'actions', None) + + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + + if request is not None: + view.request = clone_request(request, method) + + return view + + def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ + if not self.coerce_path_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + + def get_schema(self, request=None, public=False): + raise NotImplementedError(".get_schema() must be implemented in subclasses.") def determine_path_prefix(self, paths): """ @@ -375,29 +353,6 @@ class SchemaGenerator(object): prefixes.append('/' + prefix + '/') return common_path(prefixes) - def create_view(self, callback, method, request=None): - """ - Given a callback, return an actual view instance. - """ - view = callback.cls(**getattr(callback, 'initkwargs', {})) - view.args = () - view.kwargs = {} - view.format_kwarg = None - view.request = None - view.action_map = getattr(callback, 'actions', None) - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - - return view - def has_view_permissions(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. @@ -411,23 +366,77 @@ class SchemaGenerator(object): return False return True - def coerce_path(self, path, method, view): + +class SchemaGenerator(BaseSchemaGenerator): + """ + Original CoreAPI version. + """ + # Map HTTP methods onto actions. + default_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + assert coreschema, '`coreschema` must be installed for schema support.' + + super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf) + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + + def get_links(self, request=None): """ - Coerce {pk} path arguments into the name of the model field, - where possible. This is cleaner for an external representation. - (Ie. "this is an identifier", not "this is a database primary key") + Return a dictionary containing all the links that should be + included in the API schema. """ - if not self.coerce_path_pk or '{pk}' not in path: - return path - model = getattr(getattr(view, 'queryset', None), 'model', None) - if model: - field_name = get_pk_name(model) - else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) + links = LinkNode() + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + link = view.schema.get_link(path, method, base_url=self.url) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) + insert_into(links, keys, link) + + return links + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + self._initialise_endpoints() + + links = self.get_links(None if public else request) + if not links: + return None + + url = self.url + if not url and request is not None: + url = request.build_absolute_uri() + + distribute_links(links) + return coreapi.Document( + title=self.title, description=self.description, + url=url, content=links + ) # Method for generating the link layout.... - def get_keys(self, subpath, method, view): """ Return a list of keys that should be used to layout a link within @@ -473,7 +482,7 @@ class SchemaGenerator(object): return named_path_components + [action] -class OpenAPISchemaGenerator(SchemaGenerator): +class OpenAPISchemaGenerator(BaseSchemaGenerator): def get_paths(self, request=None): result = OrderedDict()