From 18575c9f5f707828c54838c043ae7d4c53c9907c Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Wed, 6 Sep 2017 16:45:40 +0200 Subject: [PATCH] Split generators, inspectors, views. --- rest_framework/schemas/__init__.py | 807 +-------------------------- rest_framework/schemas/generators.py | 386 +++++++++++++ rest_framework/schemas/inspectors.py | 374 +++++++++++++ rest_framework/schemas/views.py | 47 ++ 4 files changed, 810 insertions(+), 804 deletions(-) create mode 100644 rest_framework/schemas/generators.py create mode 100644 rest_framework/schemas/inspectors.py create mode 100644 rest_framework/schemas/views.py diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index 0130451d3..308f958bb 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -1,135 +1,8 @@ -import re -from collections import OrderedDict -from importlib import import_module - -from django.conf import settings -from django.contrib.admindocs.views import simplify_regex -from django.core.exceptions import PermissionDenied -from django.db import models -from django.http import Http404 -from django.utils import six -from django.utils.encoding import force_text, smart_text -from django.utils.translation import ugettext_lazy as _ - -from rest_framework import exceptions, renderers, serializers -from rest_framework.compat import ( - RegexURLPattern, RegexURLResolver, coreapi, coreschema, uritemplate, - urlparse -) -from rest_framework.request import clone_request -from rest_framework.response import Response -from rest_framework.settings import api_settings -from rest_framework.utils import formatting -from rest_framework.utils.model_meta import _get_pk -from rest_framework.views import APIView - -header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') - - -def field_to_schema(field): - title = force_text(field.label) if field.label else '' - description = force_text(field.help_text) if field.help_text else '' - - if isinstance(field, (serializers.ListSerializer, serializers.ListField)): - child_schema = field_to_schema(field.child) - return coreschema.Array( - items=child_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.Serializer): - return coreschema.Object( - properties=OrderedDict([ - (key, field_to_schema(value)) - for key, value - in field.fields.items() - ]), - title=title, - description=description - ) - elif isinstance(field, serializers.ManyRelatedField): - return coreschema.Array( - items=coreschema.String(), - title=title, - description=description - ) - elif isinstance(field, serializers.RelatedField): - return coreschema.String(title=title, description=description) - elif isinstance(field, serializers.MultipleChoiceField): - return coreschema.Array( - items=coreschema.Enum(enum=list(field.choices.keys())), - title=title, - description=description - ) - elif isinstance(field, serializers.ChoiceField): - return coreschema.Enum( - enum=list(field.choices.keys()), - title=title, - description=description - ) - elif isinstance(field, serializers.BooleanField): - return coreschema.Boolean(title=title, description=description) - elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): - return coreschema.Number(title=title, description=description) - elif isinstance(field, serializers.IntegerField): - return coreschema.Integer(title=title, description=description) - - if field.style.get('base_template') == 'textarea.html': - return coreschema.String( - title=title, - description=description, - format='textarea' - ) - return coreschema.String(title=title, description=description) - - -def common_path(paths): - split_paths = [path.strip('/').split('/') for path in paths] - s1 = min(split_paths) - s2 = max(split_paths) - common = s1 - for i, c in enumerate(s1): - if c != s2[i]: - common = s1[:i] - break - return '/' + '/'.join(common) - - -def get_pk_name(model): - meta = model._meta.concrete_model._meta - return _get_pk(meta).name - - -def is_api_view(callback): - """ - Return `True` if the given view callback is a REST framework view/viewset. - """ - cls = getattr(callback, 'cls', None) - return (cls is not None) and issubclass(cls, APIView) - - -def insert_into(target, keys, value): - """ - Nested dictionary insertion. - - >>> example = {} - >>> insert_into(example, ['a', 'b', 'c'], 123) - >>> example - {'a': {'b': {'c': 123}}} - """ - for key in keys[:-1]: - if key not in target: - target[key] = {} - target = target[key] - target[keys[-1]] = value - - -def is_custom_action(action): - return action not in set([ - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' - ]) +# The API we expose +# from .views import get_schema_view +# Shared function. TODO: move to utils. def is_list_view(path, method, view): """ Return True if the given path/method appears to represent a list view. @@ -144,677 +17,3 @@ def is_list_view(path, method, view): if path_components and '{' in path_components[-1]: return False return True - - -def endpoint_ordering(endpoint): - path, method, callback = endpoint - method_priority = { - 'GET': 0, - 'POST': 1, - 'PUT': 2, - 'PATCH': 3, - 'DELETE': 4 - }.get(method, 5) - return (path, method_priority) - - -def get_pk_description(model, model_field): - if isinstance(model_field, models.AutoField): - value_type = _('unique integer value') - elif isinstance(model_field, models.UUIDField): - value_type = _('UUID string') - else: - value_type = _('unique value') - - return _('A {value_type} identifying this {name}.').format( - value_type=value_type, - name=model._meta.verbose_name, - ) - - -class EndpointInspector(object): - """ - A class to determine the available API endpoints that a project exposes. - """ - def __init__(self, patterns=None, urlconf=None): - if patterns is None: - if urlconf is None: - # Use the default Django URL conf - urlconf = settings.ROOT_URLCONF - - # Load the given URLconf module - if isinstance(urlconf, six.string_types): - urls = import_module(urlconf) - else: - urls = urlconf - patterns = urls.urlpatterns - - self.patterns = patterns - - def get_api_endpoints(self, patterns=None, prefix=''): - """ - Return a list of all available API endpoints by inspecting the URL conf. - """ - if patterns is None: - patterns = self.patterns - - api_endpoints = [] - - for pattern in patterns: - path_regex = prefix + pattern.regex.pattern - if isinstance(pattern, RegexURLPattern): - path = self.get_path_from_regex(path_regex) - callback = pattern.callback - if self.should_include_endpoint(path, callback): - for method in self.get_allowed_methods(callback): - endpoint = (path, method, callback) - api_endpoints.append(endpoint) - - elif isinstance(pattern, RegexURLResolver): - nested_endpoints = self.get_api_endpoints( - patterns=pattern.url_patterns, - prefix=path_regex - ) - api_endpoints.extend(nested_endpoints) - - api_endpoints = sorted(api_endpoints, key=endpoint_ordering) - - return api_endpoints - - def get_path_from_regex(self, path_regex): - """ - Given a URL conf regex, return a URI template string. - """ - path = simplify_regex(path_regex) - path = path.replace('<', '{').replace('>', '}') - return path - - def should_include_endpoint(self, path, callback): - """ - Return `True` if the given endpoint should be included. - """ - if not is_api_view(callback): - return False # Ignore anything except REST framework views. - - if path.endswith('.{format}') or path.endswith('.{format}/'): - return False # Ignore .json style URLs. - - return True - - def get_allowed_methods(self, callback): - """ - Return a list of the valid HTTP methods for this endpoint. - """ - if hasattr(callback, 'actions'): - actions = set(callback.actions.keys()) - http_method_names = set(callback.cls.http_method_names) - return [method.upper() for method in actions & http_method_names] - - return [ - method for method in - callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') - ] - - -class ViewInspector(object): - """ - Descriptor class on APIView. - - Provide subclass for per-view schema generation - """ - def __get__(self, instance, owner): - """ - Enables `ViewInspector` as a Python _Descriptor_. - - This is how `view.schema` knows about `view`. - - `__get__` is called when the descriptor is accessed on the owner. - (That will be when view.schema is called in our case.) - - `owner` is always the owner class. (An APIView, or subclass for us.) - `instance` is the view instance or `None` if accessed from the class, - rather than an instance. - - See: https://docs.python.org/3/howto/descriptor.html for info on - descriptor usage. - """ - self.view = instance - return self - - @property - def view(self): - """View property.""" - assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)" - return self._view - - @view.setter - def view(self, value): - self._view = value - - @view.deleter - def view(self): - self._view = None - - def get_link(self, path, method, base_url): - """ - Generate `coreapi.Link` for self.view, path and method. - - This is the main _public_ access point. - - Parameters: - - * path: Route path for view from URLConf. - * method: The HTTP request method. - * base_url: The project "mount point" as given to SchemaGenerator - """ - raise NotImplementedError(".get_link() must be overridden.") - - -class AutoSchema(ViewInspector): - """ - Default inspector for APIView - - Responsible for per-view instrospection and schema generation. - """ - def __init__(self, manual_fields=None): - """ - Parameters: - - * `manual_fields`: list of `coreapi.Field` instances that - will be added to auto-generated fields, overwriting on `Field.name` - """ - - self._manual_fields = manual_fields - - def get_link(self, path, method, base_url): - fields = self.get_path_fields(path, method) - fields += self.get_serializer_fields(path, method) - fields += self.get_pagination_fields(path, method) - fields += self.get_filter_fields(path, method) - - if self._manual_fields is not None: - by_name = {f.name: f for f in fields} - for f in self._manual_fields: - by_name[f.name] = f - fields = list(by_name.values()) - - if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method) - else: - encoding = None - - description = self.get_description(path, method) - - if base_url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=urlparse.urljoin(base_url, path), - action=method.lower(), - encoding=encoding, - fields=fields, - description=description - ) - - def get_description(self, path, method): - """ - Determine a link description. - - This will be based on the method docstring if one exists, - or else the class docstring. - """ - view = self.view - - method_name = getattr(view, 'action', method.lower()) - method_docstring = getattr(view, method_name, None).__doc__ - if method_docstring: - # An explicit docstring on the method or action. - return formatting.dedent(smart_text(method_docstring)) - - description = view.get_view_description() - lines = [line.strip() for line in description.splitlines()] - current_section = '' - sections = {'': ''} - - for line in lines: - if header_regex.match(line): - current_section, seperator, lead = line.partition(':') - sections[current_section] = lead.strip() - else: - sections[current_section] += '\n' + line - - # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` - coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES - header = getattr(view, 'action', method.lower()) - if header in sections: - return sections[header].strip() - if header in coerce_method_names: - if coerce_method_names[header] in sections: - return sections[coerce_method_names[header]].strip() - return sections[''].strip() - - def get_path_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - templated path variables. - """ - view = self.view - model = getattr(getattr(view, 'queryset', None), 'model', None) - fields = [] - - for variable in uritemplate.variables(path): - title = '' - description = '' - schema_cls = coreschema.String - kwargs = {} - if model is not None: - # Attempt to infer a field description if possible. - try: - model_field = model._meta.get_field(variable) - except: - model_field = None - - if model_field is not None and model_field.verbose_name: - title = force_text(model_field.verbose_name) - - if model_field is not None and model_field.help_text: - description = force_text(model_field.help_text) - elif model_field is not None and model_field.primary_key: - description = get_pk_description(model, model_field) - - if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: - kwargs['pattern'] = view.lookup_value_regex - elif isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - - field = coreapi.Field( - name=variable, - location='path', - required=True, - schema=schema_cls(title=title, description=description, **kwargs) - ) - fields.append(field) - - return fields - - def get_serializer_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - request body input, as determined by the serializer class. - """ - view = self.view - - if method not in ('PUT', 'PATCH', 'POST'): - return [] - - if not hasattr(view, 'get_serializer'): - return [] - - serializer = view.get_serializer() - - if isinstance(serializer, serializers.ListSerializer): - return [ - coreapi.Field( - name='data', - location='body', - required=True, - schema=coreschema.Array() - ) - ] - - if not isinstance(serializer, serializers.Serializer): - return [] - - fields = [] - for field in serializer.fields.values(): - if field.read_only or isinstance(field, serializers.HiddenField): - continue - - required = field.required and method != 'PATCH' - field = coreapi.Field( - name=field.field_name, - location='form', - required=required, - schema=field_to_schema(field) - ) - fields.append(field) - - return fields - - def get_pagination_fields(self, path, method): - view = self.view - - if not is_list_view(path, method, view): - return [] - - pagination = getattr(view, 'pagination_class', None) - if not pagination: - return [] - - paginator = view.pagination_class() - return paginator.get_schema_fields(view) - - def get_filter_fields(self, path, method): - view = self.view - - if not is_list_view(path, method, view): - return [] - - if not getattr(view, 'filter_backends', None): - return [] - - fields = [] - for filter_backend in view.filter_backends: - fields += filter_backend().get_schema_fields(view) - return fields - - def get_encoding(self, path, method): - """ - Return the 'encoding' parameter to use for a given endpoint. - """ - view = self.view - - # Core API supports the following request encodings over HTTP... - supported_media_types = set(( - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data', - )) - parser_classes = getattr(view, 'parser_classes', []) - for parser_class in parser_classes: - media_type = getattr(parser_class, 'media_type', None) - if media_type in supported_media_types: - return media_type - # Raw binary uploads are supported with "application/octet-stream" - if media_type == '*/*': - return 'application/octet-stream' - - return None - -# Note: With `AutoSchema` defined we attach it to APIView. -# * We do this here to avoid the dependency cycle from SchemaView needing -# APIView (below). -# * This requires importing _something_ from `rest_framework.schemas` or -# `rest_framework.documentation` before `APIView.schema will be available. -# * ???: When would `APIView.schema` be needed and that NOT be the case? -# * The alternative is to import AutoSchema to `views`, make `schemas` a -# package, and move SchemaView to `schema.views`, importing APIView there. -APIView.schema = AutoSchema() - - -class ManualSchema(ViewInspector): - """ - Overrides get_link to return manually specified schema. - """ - def __init__(self, link): - assert isinstance(link, coreapi.Link) - self._link = link - - def get_link(self, *args): - return self._link - - -class SchemaGenerator(object): - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } - endpoint_inspector_cls = EndpointInspector - - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - - # 'pk' isn't great as an externally exposed name for an identifier, - # so by default we prefer to use the actual model field name for schemas. - # Set by 'SCHEMA_COERCE_PATH_PK'. - coerce_path_pk = None - - def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' - - if url and not url.endswith('/'): - url += '/' - - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES - self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK - - self.patterns = patterns - self.urlconf = urlconf - self.title = title - self.description = description - self.url = url - self.endpoints = None - - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ - if self.endpoints is None: - inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) - self.endpoints = inspector.get_api_endpoints() - - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - - def get_links(self, request=None): - """ - Return a dictionary containing all the links that should be - included in the API schema. - """ - links = OrderedDict() - - # Generate (path, method, view) given (path, method, callback). - paths = [] - view_endpoints = [] - for path, method, callback in self.endpoints: - view = self.create_view(callback, method, request) - if getattr(view, 'exclude_from_schema', False): - continue - path = self.coerce_path(path, method, view) - paths.append(path) - view_endpoints.append((path, method, view)) - - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) - - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) - return links - - # Methods used when we generate a view instance from the raw callback... - - def determine_path_prefix(self, paths): - """ - Given a list of all paths, return the common prefix which should be - discounted when generating a schema structure. - - This will be the longest common string that does not include that last - component of the URL, or the last component before a path parameter. - - For example: - - /api/v1/users/ - /api/v1/users/{pk}/ - - The path prefix is '/api/v1/' - """ - prefixes = [] - for path in paths: - components = path.strip('/').split('/') - initial_components = [] - for component in components: - if '{' in component: - break - initial_components.append(component) - prefix = '/'.join(initial_components[:-1]) - if not prefix: - # We can just break early in the case that there's at least - # one URL that doesn't have a path prefix. - return '/' - prefixes.append('/' + prefix + '/') - return common_path(prefixes) - - def create_view(self, callback, method, request=None): - """ - Given a callback, return an actual view instance. - """ - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) - view.args = () - view.kwargs = {} - view.format_kwarg = None - view.request = None - view.action_map = getattr(callback, 'actions', None) - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - - return view - - def has_view_permissions(self, path, method, view): - """ - Return `True` if the incoming request has the correct view permissions. - """ - if view.request is None: - return True - - try: - view.check_permissions(view.request) - except (exceptions.APIException, Http404, PermissionDenied): - return False - return True - - def coerce_path(self, path, method, view): - """ - Coerce {pk} path arguments into the name of the model field, - where possible. This is cleaner for an external representation. - (Ie. "this is an identifier", not "this is a database primary key") - """ - if not self.coerce_path_pk or '{pk}' not in path: - return path - model = getattr(getattr(view, 'queryset', None), 'model', None) - if model: - field_name = get_pk_name(model) - else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) - - # Method for generating the link layout.... - - def get_keys(self, subpath, method, view): - """ - Return a list of keys that should be used to layout a link within - the schema document. - - /users/ ("users", "list"), ("users", "create") - /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") - /users/enabled/ ("users", "enabled") # custom viewset list action - /users/{pk}/star/ ("users", "star") # custom viewset detail action - /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") - /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") - """ - if hasattr(view, 'action'): - # Viewsets have explicitly named actions. - action = view.action - else: - # Views have no associated action, so we determine one from the method. - if is_list_view(subpath, method, view): - action = 'list' - else: - action = self.default_mapping[method.lower()] - - named_path_components = [ - component for component - in subpath.strip('/').split('/') - if '{' not in component - ] - - if is_custom_action(action): - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - if len(view.action_map) > 1: - action = self.default_mapping[method.lower()] - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - return named_path_components + [action] - else: - return named_path_components[:-1] + [action] - - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - - # Default action, eg "/users/", "/users/{pk}/" - return named_path_components + [action] - - -class SchemaView(APIView): - _ignore_model_permissions = True - exclude_from_schema = True - renderer_classes = None - schema_generator = None - public = False - - def __init__(self, *args, **kwargs): - super(SchemaView, self).__init__(*args, **kwargs) - if self.renderer_classes is None: - if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: - self.renderer_classes = [ - renderers.CoreJSONRenderer, - renderers.BrowsableAPIRenderer, - ] - else: - self.renderer_classes = [renderers.CoreJSONRenderer] - - def get(self, request, *args, **kwargs): - schema = self.schema_generator.get_schema(request, self.public) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - - -def get_schema_view( - title=None, url=None, description=None, urlconf=None, renderer_classes=None, - public=False, patterns=None, generator_class=SchemaGenerator): - """ - Return a schema view. - """ - generator = generator_class( - title=title, url=url, description=description, - urlconf=urlconf, patterns=patterns, - ) - return SchemaView.as_view( - renderer_classes=renderer_classes, - schema_generator=generator, - public=public, - ) diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py new file mode 100644 index 000000000..544042c9c --- /dev/null +++ b/rest_framework/schemas/generators.py @@ -0,0 +1,386 @@ +from collections import OrderedDict +from importlib import import_module + +from django.conf import settings +from django.contrib.admindocs.views import simplify_regex +from django.core.exceptions import PermissionDenied +from django.http import Http404 +from django.utils import six + +from rest_framework import exceptions +from rest_framework.compat import ( + RegexURLPattern, RegexURLResolver, coreapi, coreschema +) +from rest_framework.request import clone_request +from rest_framework.settings import api_settings +from rest_framework.utils.model_meta import _get_pk +from rest_framework.views import APIView + +from . import is_list_view + + +def common_path(paths): + split_paths = [path.strip('/').split('/') for path in paths] + s1 = min(split_paths) + s2 = max(split_paths) + common = s1 + for i, c in enumerate(s1): + if c != s2[i]: + common = s1[:i] + break + return '/' + '/'.join(common) + + +def get_pk_name(model): + meta = model._meta.concrete_model._meta + return _get_pk(meta).name + + +def is_api_view(callback): + """ + Return `True` if the given view callback is a REST framework view/viewset. + """ + cls = getattr(callback, 'cls', None) + return (cls is not None) and issubclass(cls, APIView) + + +def insert_into(target, keys, value): + """ + Nested dictionary insertion. + + >>> example = {} + >>> insert_into(example, ['a', 'b', 'c'], 123) + >>> example + {'a': {'b': {'c': 123}}} + """ + for key in keys[:-1]: + if key not in target: + target[key] = {} + target = target[key] + target[keys[-1]] = value + + +def is_custom_action(action): + return action not in set([ + 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' + ]) + + +def endpoint_ordering(endpoint): + path, method, callback = endpoint + method_priority = { + 'GET': 0, + 'POST': 1, + 'PUT': 2, + 'PATCH': 3, + 'DELETE': 4 + }.get(method, 5) + return (path, method_priority) + + +class EndpointInspector(object): + """ + A class to determine the available API endpoints that a project exposes. + """ + def __init__(self, patterns=None, urlconf=None): + if patterns is None: + if urlconf is None: + # Use the default Django URL conf + urlconf = settings.ROOT_URLCONF + + # Load the given URLconf module + if isinstance(urlconf, six.string_types): + urls = import_module(urlconf) + else: + urls = urlconf + patterns = urls.urlpatterns + + self.patterns = patterns + + def get_api_endpoints(self, patterns=None, prefix=''): + """ + Return a list of all available API endpoints by inspecting the URL conf. + """ + if patterns is None: + patterns = self.patterns + + api_endpoints = [] + + for pattern in patterns: + path_regex = prefix + pattern.regex.pattern + if isinstance(pattern, RegexURLPattern): + path = self.get_path_from_regex(path_regex) + callback = pattern.callback + if self.should_include_endpoint(path, callback): + for method in self.get_allowed_methods(callback): + endpoint = (path, method, callback) + api_endpoints.append(endpoint) + + elif isinstance(pattern, RegexURLResolver): + nested_endpoints = self.get_api_endpoints( + patterns=pattern.url_patterns, + prefix=path_regex + ) + api_endpoints.extend(nested_endpoints) + + api_endpoints = sorted(api_endpoints, key=endpoint_ordering) + + return api_endpoints + + def get_path_from_regex(self, path_regex): + """ + Given a URL conf regex, return a URI template string. + """ + path = simplify_regex(path_regex) + path = path.replace('<', '{').replace('>', '}') + return path + + def should_include_endpoint(self, path, callback): + """ + Return `True` if the given endpoint should be included. + """ + if not is_api_view(callback): + return False # Ignore anything except REST framework views. + + if path.endswith('.{format}') or path.endswith('.{format}/'): + return False # Ignore .json style URLs. + + return True + + def get_allowed_methods(self, callback): + """ + Return a list of the valid HTTP methods for this endpoint. + """ + if hasattr(callback, 'actions'): + actions = set(callback.actions.keys()) + http_method_names = set(callback.cls.http_method_names) + return [method.upper() for method in actions & http_method_names] + + return [ + method for method in + callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') + ] + + +class SchemaGenerator(object): + # Map HTTP methods onto actions. + default_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + endpoint_inspector_cls = EndpointInspector + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + + # 'pk' isn't great as an externally exposed name for an identifier, + # so by default we prefer to use the actual model field name for schemas. + # Set by 'SCHEMA_COERCE_PATH_PK'. + coerce_path_pk = None + + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + assert coreschema, '`coreschema` must be installed for schema support.' + + if url and not url.endswith('/'): + url += '/' + + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK + + self.patterns = patterns + self.urlconf = urlconf + self.title = title + self.description = description + self.url = url + self.endpoints = None + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + if self.endpoints is None: + inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) + self.endpoints = inspector.get_api_endpoints() + + links = self.get_links(None if public else request) + if not links: + return None + + url = self.url + if not url and request is not None: + url = request.build_absolute_uri() + + return coreapi.Document( + title=self.title, description=self.description, + url=url, content=links + ) + + def get_links(self, request=None): + """ + Return a dictionary containing all the links that should be + included in the API schema. + """ + links = OrderedDict() + + # Generate (path, method, view) given (path, method, callback). + paths = [] + view_endpoints = [] + for path, method, callback in self.endpoints: + view = self.create_view(callback, method, request) + if getattr(view, 'exclude_from_schema', False): + continue + path = self.coerce_path(path, method, view) + paths.append(path) + view_endpoints.append((path, method, view)) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + link = view.schema.get_link(path, method, base_url=self.url) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) + insert_into(links, keys, link) + return links + + # Methods used when we generate a view instance from the raw callback... + + def determine_path_prefix(self, paths): + """ + Given a list of all paths, return the common prefix which should be + discounted when generating a schema structure. + + This will be the longest common string that does not include that last + component of the URL, or the last component before a path parameter. + + For example: + + /api/v1/users/ + /api/v1/users/{pk}/ + + The path prefix is '/api/v1/' + """ + prefixes = [] + for path in paths: + components = path.strip('/').split('/') + initial_components = [] + for component in components: + if '{' in component: + break + initial_components.append(component) + prefix = '/'.join(initial_components[:-1]) + if not prefix: + # We can just break early in the case that there's at least + # one URL that doesn't have a path prefix. + return '/' + prefixes.append('/' + prefix + '/') + return common_path(prefixes) + + def create_view(self, callback, method, request=None): + """ + Given a callback, return an actual view instance. + """ + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).items(): + setattr(view, attr, val) + view.args = () + view.kwargs = {} + view.format_kwarg = None + view.request = None + view.action_map = getattr(callback, 'actions', None) + + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + + if request is not None: + view.request = clone_request(request, method) + + return view + + def has_view_permissions(self, path, method, view): + """ + Return `True` if the incoming request has the correct view permissions. + """ + if view.request is None: + return True + + try: + view.check_permissions(view.request) + except (exceptions.APIException, Http404, PermissionDenied): + return False + return True + + def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ + if not self.coerce_path_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + + # Method for generating the link layout.... + + def get_keys(self, subpath, method, view): + """ + Return a list of keys that should be used to layout a link within + the schema document. + + /users/ ("users", "list"), ("users", "create") + /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") + /users/enabled/ ("users", "enabled") # custom viewset list action + /users/{pk}/star/ ("users", "star") # custom viewset detail action + /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") + /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") + """ + if hasattr(view, 'action'): + # Viewsets have explicitly named actions. + action = view.action + else: + # Views have no associated action, so we determine one from the method. + if is_list_view(subpath, method, view): + action = 'list' + else: + action = self.default_mapping[method.lower()] + + named_path_components = [ + component for component + in subpath.strip('/').split('/') + if '{' not in component + ] + + if is_custom_action(action): + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + if len(view.action_map) > 1: + action = self.default_mapping[method.lower()] + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + return named_path_components + [action] + else: + return named_path_components[:-1] + [action] + + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + + # Default action, eg "/users/", "/users/{pk}/" + return named_path_components + [action] diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py new file mode 100644 index 000000000..2a15f91e7 --- /dev/null +++ b/rest_framework/schemas/inspectors.py @@ -0,0 +1,374 @@ +import re +from collections import OrderedDict + +from django.db import models +from django.utils.encoding import force_text, smart_text +from django.utils.translation import ugettext_lazy as _ + +from rest_framework import serializers +from rest_framework.compat import coreapi, coreschema, uritemplate, urlparse +from rest_framework.settings import api_settings +from rest_framework.utils import formatting + +from . import is_list_view + +header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') + + +def field_to_schema(field): + title = force_text(field.label) if field.label else '' + description = force_text(field.help_text) if field.help_text else '' + + if isinstance(field, (serializers.ListSerializer, serializers.ListField)): + child_schema = field_to_schema(field.child) + return coreschema.Array( + items=child_schema, + title=title, + description=description + ) + elif isinstance(field, serializers.Serializer): + return coreschema.Object( + properties=OrderedDict([ + (key, field_to_schema(value)) + for key, value + in field.fields.items() + ]), + title=title, + description=description + ) + elif isinstance(field, serializers.ManyRelatedField): + return coreschema.Array( + items=coreschema.String(), + title=title, + description=description + ) + elif isinstance(field, serializers.RelatedField): + return coreschema.String(title=title, description=description) + elif isinstance(field, serializers.MultipleChoiceField): + return coreschema.Array( + items=coreschema.Enum(enum=list(field.choices.keys())), + title=title, + description=description + ) + elif isinstance(field, serializers.ChoiceField): + return coreschema.Enum( + enum=list(field.choices.keys()), + title=title, + description=description + ) + elif isinstance(field, serializers.BooleanField): + return coreschema.Boolean(title=title, description=description) + elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): + return coreschema.Number(title=title, description=description) + elif isinstance(field, serializers.IntegerField): + return coreschema.Integer(title=title, description=description) + + if field.style.get('base_template') == 'textarea.html': + return coreschema.String( + title=title, + description=description, + format='textarea' + ) + return coreschema.String(title=title, description=description) + + +def get_pk_description(model, model_field): + if isinstance(model_field, models.AutoField): + value_type = _('unique integer value') + elif isinstance(model_field, models.UUIDField): + value_type = _('UUID string') + else: + value_type = _('unique value') + + return _('A {value_type} identifying this {name}.').format( + value_type=value_type, + name=model._meta.verbose_name, + ) + + +class ViewInspector(object): + """ + Descriptor class on APIView. + + Provide subclass for per-view schema generation + """ + def __get__(self, instance, owner): + """ + Enables `ViewInspector` as a Python _Descriptor_. + + This is how `view.schema` knows about `view`. + + `__get__` is called when the descriptor is accessed on the owner. + (That will be when view.schema is called in our case.) + + `owner` is always the owner class. (An APIView, or subclass for us.) + `instance` is the view instance or `None` if accessed from the class, + rather than an instance. + + See: https://docs.python.org/3/howto/descriptor.html for info on + descriptor usage. + """ + self.view = instance + return self + + @property + def view(self): + """View property.""" + assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)" + return self._view + + @view.setter + def view(self, value): + self._view = value + + @view.deleter + def view(self): + self._view = None + + def get_link(self, path, method, base_url): + """ + Generate `coreapi.Link` for self.view, path and method. + + This is the main _public_ access point. + + Parameters: + + * path: Route path for view from URLConf. + * method: The HTTP request method. + * base_url: The project "mount point" as given to SchemaGenerator + """ + raise NotImplementedError(".get_link() must be overridden.") + + +class AutoSchema(ViewInspector): + """ + Default inspector for APIView + + Responsible for per-view instrospection and schema generation. + """ + def __init__(self, manual_fields=None): + """ + Parameters: + + * `manual_fields`: list of `coreapi.Field` instances that + will be added to auto-generated fields, overwriting on `Field.name` + """ + + self._manual_fields = manual_fields + + def get_link(self, path, method, base_url): + fields = self.get_path_fields(path, method) + fields += self.get_serializer_fields(path, method) + fields += self.get_pagination_fields(path, method) + fields += self.get_filter_fields(path, method) + + if self._manual_fields is not None: + by_name = {f.name: f for f in fields} + for f in self._manual_fields: + by_name[f.name] = f + fields = list(by_name.values()) + + if fields and any([field.location in ('form', 'body') for field in fields]): + encoding = self.get_encoding(path, method) + else: + encoding = None + + description = self.get_description(path, method) + + if base_url and path.startswith('/'): + path = path[1:] + + return coreapi.Link( + url=urlparse.urljoin(base_url, path), + action=method.lower(), + encoding=encoding, + fields=fields, + description=description + ) + + def get_description(self, path, method): + """ + Determine a link description. + + This will be based on the method docstring if one exists, + or else the class docstring. + """ + view = self.view + + method_name = getattr(view, 'action', method.lower()) + method_docstring = getattr(view, method_name, None).__doc__ + if method_docstring: + # An explicit docstring on the method or action. + return formatting.dedent(smart_text(method_docstring)) + + description = view.get_view_description() + lines = [line.strip() for line in description.splitlines()] + current_section = '' + sections = {'': ''} + + for line in lines: + if header_regex.match(line): + current_section, seperator, lead = line.partition(':') + sections[current_section] = lead.strip() + else: + sections[current_section] += '\n' + line + + # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` + coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + header = getattr(view, 'action', method.lower()) + if header in sections: + return sections[header].strip() + if header in coerce_method_names: + if coerce_method_names[header] in sections: + return sections[coerce_method_names[header]].strip() + return sections[''].strip() + + def get_path_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + templated path variables. + """ + view = self.view + model = getattr(getattr(view, 'queryset', None), 'model', None) + fields = [] + + for variable in uritemplate.variables(path): + title = '' + description = '' + schema_cls = coreschema.String + kwargs = {} + if model is not None: + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except: + model_field = None + + if model_field is not None and model_field.verbose_name: + title = force_text(model_field.verbose_name) + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: + kwargs['pattern'] = view.lookup_value_regex + elif isinstance(model_field, models.AutoField): + schema_cls = coreschema.Integer + + field = coreapi.Field( + name=variable, + location='path', + required=True, + schema=schema_cls(title=title, description=description, **kwargs) + ) + fields.append(field) + + return fields + + def get_serializer_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + request body input, as determined by the serializer class. + """ + view = self.view + + if method not in ('PUT', 'PATCH', 'POST'): + return [] + + if not hasattr(view, 'get_serializer'): + return [] + + serializer = view.get_serializer() + + if isinstance(serializer, serializers.ListSerializer): + return [ + coreapi.Field( + name='data', + location='body', + required=True, + schema=coreschema.Array() + ) + ] + + if not isinstance(serializer, serializers.Serializer): + return [] + + fields = [] + for field in serializer.fields.values(): + if field.read_only or isinstance(field, serializers.HiddenField): + continue + + required = field.required and method != 'PATCH' + field = coreapi.Field( + name=field.field_name, + location='form', + required=required, + schema=field_to_schema(field) + ) + fields.append(field) + + return fields + + def get_pagination_fields(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + pagination = getattr(view, 'pagination_class', None) + if not pagination: + return [] + + paginator = view.pagination_class() + return paginator.get_schema_fields(view) + + def get_filter_fields(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + if not getattr(view, 'filter_backends', None): + return [] + + fields = [] + for filter_backend in view.filter_backends: + fields += filter_backend().get_schema_fields(view) + return fields + + def get_encoding(self, path, method): + """ + Return the 'encoding' parameter to use for a given endpoint. + """ + view = self.view + + # Core API supports the following request encodings over HTTP... + supported_media_types = set(( + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data', + )) + parser_classes = getattr(view, 'parser_classes', []) + for parser_class in parser_classes: + media_type = getattr(parser_class, 'media_type', None) + if media_type in supported_media_types: + return media_type + # Raw binary uploads are supported with "application/octet-stream" + if media_type == '*/*': + return 'application/octet-stream' + + return None + + +class ManualSchema(ViewInspector): + """ + Overrides get_link to return manually specified schema. + """ + def __init__(self, link): + assert isinstance(link, coreapi.Link) + self._link = link + + def get_link(self, *args): + return self._link diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py new file mode 100644 index 000000000..cac4c4ca7 --- /dev/null +++ b/rest_framework/schemas/views.py @@ -0,0 +1,47 @@ +from rest_framework import exceptions, renderers +from rest_framework.response import Response +from rest_framework.schemas.generators import SchemaGenerator +from rest_framework.settings import api_settings +from rest_framework.views import APIView + + +class SchemaView(APIView): + _ignore_model_permissions = True + exclude_from_schema = True + renderer_classes = None + schema_generator = None + public = False + + def __init__(self, *args, **kwargs): + super(SchemaView, self).__init__(*args, **kwargs) + if self.renderer_classes is None: + if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: + self.renderer_classes = [ + renderers.CoreJSONRenderer, + renderers.BrowsableAPIRenderer, + ] + else: + self.renderer_classes = [renderers.CoreJSONRenderer] + + def get(self, request, *args, **kwargs): + schema = self.schema_generator.get_schema(request, self.public) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) + + +def get_schema_view( + title=None, url=None, description=None, urlconf=None, renderer_classes=None, + public=False, patterns=None, generator_class=SchemaGenerator): + """ + Return a schema view. + """ + generator = generator_class( + title=title, url=url, description=description, + urlconf=urlconf, patterns=patterns, + ) + return SchemaView.as_view( + renderer_classes=renderer_classes, + schema_generator=generator, + public=public, + )