From 37f210a455cc92cb3f61a23e194a1d0de58d149b Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Mon, 13 May 2019 16:07:03 +0200 Subject: [PATCH] Added OpenAPI Schema Generation. (#6532) Co-authored-by: Lucidiot Co-authored-by: dongfangtianyu --- rest_framework/filters.py | 29 + .../management/commands/generateschema.py | 49 +- rest_framework/pagination.py | 94 ++- rest_framework/renderers.py | 33 +- rest_framework/schemas/__init__.py | 18 +- rest_framework/schemas/coreapi.py | 616 ++++++++++++++++++ rest_framework/schemas/generators.py | 265 ++------ rest_framework/schemas/inspectors.py | 430 ------------ rest_framework/schemas/openapi.py | 377 +++++++++++ rest_framework/schemas/utils.py | 17 + rest_framework/schemas/views.py | 15 +- rest_framework/settings.py | 2 +- tests/schemas/__init__.py | 0 .../test_coreapi.py} | 98 +-- tests/schemas/test_get_schema_view.py | 20 + .../test_managementcommand.py} | 39 +- tests/schemas/test_openapi.py | 245 +++++++ tests/schemas/views.py | 58 ++ 18 files changed, 1671 insertions(+), 734 deletions(-) create mode 100644 rest_framework/schemas/coreapi.py create mode 100644 rest_framework/schemas/openapi.py create mode 100644 tests/schemas/__init__.py rename tests/{test_schemas.py => schemas/test_coreapi.py} (94%) create mode 100644 tests/schemas/test_get_schema_view.py rename tests/{test_generateschema.py => schemas/test_managementcommand.py} (57%) create mode 100644 tests/schemas/test_openapi.py create mode 100644 tests/schemas/views.py diff --git a/rest_framework/filters.py b/rest_framework/filters.py index d5fe36964..e3b0468c7 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -37,6 +37,9 @@ class BaseFilterBackend: assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' return [] + def get_schema_operation_parameters(self, view): + return [] + class SearchFilter(BaseFilterBackend): # The URL query parameter used for the search. @@ -156,6 +159,19 @@ class SearchFilter(BaseFilterBackend): ) ] + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.search_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.search_description), + 'schema': { + 'type': 'string', + }, + }, + ] + class OrderingFilter(BaseFilterBackend): # The URL query parameter used for the ordering. @@ -287,6 +303,19 @@ class OrderingFilter(BaseFilterBackend): ) ] + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.ordering_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.ordering_description), + 'schema': { + 'type': 'string', + }, + }, + ] + class DjangoObjectPermissionsFilter(BaseFilterBackend): """ diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py index 40909bd04..631f40290 100644 --- a/rest_framework/management/commands/generateschema.py +++ b/rest_framework/management/commands/generateschema.py @@ -1,41 +1,56 @@ from django.core.management.base import BaseCommand -from rest_framework.compat import coreapi -from rest_framework.renderers import ( - CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer -) -from rest_framework.schemas.generators import SchemaGenerator +from rest_framework import renderers +from rest_framework.schemas import coreapi +from rest_framework.schemas.openapi import SchemaGenerator + +OPENAPI_MODE = 'openapi' +COREAPI_MODE = 'coreapi' class Command(BaseCommand): help = "Generates configured API schema for project." + def get_mode(self): + return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE + def add_arguments(self, parser): - parser.add_argument('--title', dest="title", default=None, type=str) + parser.add_argument('--title', dest="title", default='', type=str) parser.add_argument('--url', dest="url", default=None, type=str) parser.add_argument('--description', dest="description", default=None, type=str) - parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) + if self.get_mode() == COREAPI_MODE: + parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) + else: + parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str) def handle(self, *args, **options): - assert coreapi is not None, 'coreapi must be installed.' - - generator = SchemaGenerator( + generator_class = self.get_generator_class() + generator = generator_class( url=options['url'], title=options['title'], description=options['description'] ) - schema = generator.get_schema(request=None, public=True) - renderer = self.get_renderer(options['format']) output = renderer.render(schema, renderer_context={}) self.stdout.write(output.decode()) def get_renderer(self, format): - renderer_cls = { - 'corejson': CoreJSONRenderer, - 'openapi': OpenAPIRenderer, - 'openapi-json': JSONOpenAPIRenderer, - }[format] + if self.get_mode() == COREAPI_MODE: + renderer_cls = { + 'corejson': renderers.CoreJSONRenderer, + 'openapi': renderers.CoreAPIOpenAPIRenderer, + 'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer, + }[format] + return renderer_cls() + renderer_cls = { + 'openapi': renderers.OpenAPIRenderer, + 'openapi-json': renderers.JSONOpenAPIRenderer, + }[format] return renderer_cls() + + def get_generator_class(self): + if self.get_mode() == COREAPI_MODE: + return coreapi.SchemaGenerator + return SchemaGenerator diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 0b2877a45..38d6b9e1c 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -148,6 +148,9 @@ class BasePagination: assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' return [] + def get_schema_operation_parameters(self, view): + return [] + class PageNumberPagination(BasePagination): """ @@ -301,6 +304,32 @@ class PageNumberPagination(BasePagination): ) return fields + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.page_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ] + if self.page_size_query_param is not None: + parameters.append( + { + 'name': self.page_size_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_size_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ) + return parameters + class LimitOffsetPagination(BasePagination): """ @@ -430,6 +459,15 @@ class LimitOffsetPagination(BasePagination): context = self.get_html_context() return template.render(context) + def get_count(self, queryset): + """ + Determine an object count, supporting either querysets or regular lists. + """ + try: + return queryset.count() + except (AttributeError, TypeError): + return len(queryset) + def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' @@ -454,14 +492,28 @@ class LimitOffsetPagination(BasePagination): ) ] - def get_count(self, queryset): - """ - Determine an object count, supporting either querysets or regular lists. - """ - try: - return queryset.count() - except (AttributeError, TypeError): - return len(queryset) + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.limit_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.limit_query_description), + 'schema': { + 'type': 'integer', + }, + }, + { + 'name': self.offset_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.offset_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ] + return parameters class CursorPagination(BasePagination): @@ -816,3 +868,29 @@ class CursorPagination(BasePagination): ) ) return fields + + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.cursor_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.cursor_query_description), + 'schema': { + 'type': 'integer', + }, + } + ] + if self.page_size_query_param is not None: + parameters.append( + { + 'name': self.page_size_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_size_query_description), + 'schema': { + 'type': 'integer', + }, + } + ) + return parameters diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 143d1b7e7..2a4ae5905 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -1013,28 +1013,49 @@ class _BaseOpenAPIRenderer: } -class OpenAPIRenderer(_BaseOpenAPIRenderer): +class CoreAPIOpenAPIRenderer(_BaseOpenAPIRenderer): media_type = 'application/vnd.oai.openapi' charset = None format = 'openapi' def __init__(self): - assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.' - assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.' + assert coreapi, 'Using CoreAPIOpenAPIRenderer, but `coreapi` is not installed.' + assert yaml, 'Using CoreAPIOpenAPIRenderer, but `pyyaml` is not installed.' def render(self, data, media_type=None, renderer_context=None): structure = self.get_structure(data) return yaml.dump(structure, default_flow_style=False).encode() -class JSONOpenAPIRenderer(_BaseOpenAPIRenderer): +class CoreAPIJSONOpenAPIRenderer(_BaseOpenAPIRenderer): media_type = 'application/vnd.oai.openapi+json' charset = None format = 'openapi-json' def __init__(self): - assert coreapi, 'Using JSONOpenAPIRenderer, but `coreapi` is not installed.' + assert coreapi, 'Using CoreAPIJSONOpenAPIRenderer, but `coreapi` is not installed.' def render(self, data, media_type=None, renderer_context=None): structure = self.get_structure(data) - return json.dumps(structure, indent=4).encode() + return json.dumps(structure, indent=4).encode('utf-8') + + +class OpenAPIRenderer(BaseRenderer): + media_type = 'application/vnd.oai.openapi' + charset = None + format = 'openapi' + + def __init__(self): + assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.' + + def render(self, data, media_type=None, renderer_context=None): + return yaml.dump(data, default_flow_style=False).encode('utf-8') + + +class JSONOpenAPIRenderer(BaseRenderer): + media_type = 'application/vnd.oai.openapi+json' + charset = None + format = 'openapi-json' + + def render(self, data, media_type=None, renderer_context=None): + return json.dumps(data, indent=2).encode('utf-8') diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index ba0ec6536..8fdb2d86a 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -22,24 +22,32 @@ Other access should target the submodules directly """ from rest_framework.settings import api_settings -from .generators import SchemaGenerator -from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa +from . import coreapi, openapi +from .inspectors import DefaultSchema # noqa +from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa def get_schema_view( title=None, url=None, description=None, urlconf=None, renderer_classes=None, - public=False, patterns=None, generator_class=SchemaGenerator, + public=False, patterns=None, generator_class=None, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): """ Return a schema view. """ - # Avoid import cycle on APIView - from .views import SchemaView + if generator_class is None: + if coreapi.is_enabled(): + generator_class = coreapi.SchemaGenerator + else: + generator_class = openapi.SchemaGenerator + generator = generator_class( title=title, url=url, description=description, urlconf=urlconf, patterns=patterns, ) + + # Avoid import cycle on APIView + from .views import SchemaView return SchemaView.as_view( renderer_classes=renderer_classes, schema_generator=generator, diff --git a/rest_framework/schemas/coreapi.py b/rest_framework/schemas/coreapi.py new file mode 100644 index 000000000..5cf789f9f --- /dev/null +++ b/rest_framework/schemas/coreapi.py @@ -0,0 +1,616 @@ +import re +import warnings +from collections import Counter, OrderedDict +from urllib import parse + +from django.db import models +from django.utils.encoding import force_text, smart_text + +from rest_framework import exceptions, serializers +from rest_framework.compat import coreapi, coreschema, uritemplate +from rest_framework.settings import api_settings +from rest_framework.utils import formatting + +from .generators import BaseSchemaGenerator +from .inspectors import ViewInspector +from .utils import get_pk_description, is_list_view + +# Used in _get_description_section() +# TODO: ???: move up to base. +header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') + +# Generator # +# TODO: Pull some of this into base. + + +def is_custom_action(action): + return action not in { + 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' + } + + +def distribute_links(obj): + for key, value in obj.items(): + distribute_links(value) + + for preferred_key, link in obj.links: + key = obj.get_available_key(preferred_key) + obj[key] = link + + +INSERT_INTO_COLLISION_FMT = """ +Schema Naming Collision. + +coreapi.Link for URL path {value_url} cannot be inserted into schema. +Position conflicts with coreapi.Link for URL path {target_url}. + +Attempted to insert link with keys: {keys}. + +Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()` +to customise schema structure. +""" + + +class LinkNode(OrderedDict): + def __init__(self): + self.links = [] + self.methods_counter = Counter() + super(LinkNode, self).__init__() + + def get_available_key(self, preferred_key): + if preferred_key not in self: + return preferred_key + + while True: + current_val = self.methods_counter[preferred_key] + self.methods_counter[preferred_key] += 1 + + key = '{}_{}'.format(preferred_key, current_val) + if key not in self: + return key + + +def insert_into(target, keys, value): + """ + Nested dictionary insertion. + + >>> example = {} + >>> insert_into(example, ['a', 'b', 'c'], 123) + >>> example + LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}}))) + """ + for key in keys[:-1]: + if key not in target: + target[key] = LinkNode() + target = target[key] + + try: + target.links.append((keys[-1], value)) + except TypeError: + msg = INSERT_INTO_COLLISION_FMT.format( + value_url=value.url, + target_url=target.url, + keys=keys + ) + raise ValueError(msg) + + +class SchemaGenerator(BaseSchemaGenerator): + """ + Original CoreAPI version. + """ + # Map HTTP methods onto actions. + default_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + assert coreschema, '`coreschema` must be installed for schema support.' + + super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf) + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + + def get_links(self, request=None): + """ + Return a dictionary containing all the links that should be + included in the API schema. + """ + links = LinkNode() + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + link = view.schema.get_link(path, method, base_url=self.url) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) + insert_into(links, keys, link) + + return links + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + self._initialise_endpoints() + + links = self.get_links(None if public else request) + if not links: + return None + + url = self.url + if not url and request is not None: + url = request.build_absolute_uri() + + distribute_links(links) + return coreapi.Document( + title=self.title, description=self.description, + url=url, content=links + ) + + # Method for generating the link layout.... + def get_keys(self, subpath, method, view): + """ + Return a list of keys that should be used to layout a link within + the schema document. + + /users/ ("users", "list"), ("users", "create") + /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") + /users/enabled/ ("users", "enabled") # custom viewset list action + /users/{pk}/star/ ("users", "star") # custom viewset detail action + /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") + /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") + """ + if hasattr(view, 'action'): + # Viewsets have explicitly named actions. + action = view.action + else: + # Views have no associated action, so we determine one from the method. + if is_list_view(subpath, method, view): + action = 'list' + else: + action = self.default_mapping[method.lower()] + + named_path_components = [ + component for component + in subpath.strip('/').split('/') + if '{' not in component + ] + + if is_custom_action(action): + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + if len(view.action_map) > 1: + action = self.default_mapping[method.lower()] + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + return named_path_components + [action] + else: + return named_path_components[:-1] + [action] + + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + + # Default action, eg "/users/", "/users/{pk}/" + return named_path_components + [action] + +# View Inspectors # + + +def field_to_schema(field): + title = force_text(field.label) if field.label else '' + description = force_text(field.help_text) if field.help_text else '' + + if isinstance(field, (serializers.ListSerializer, serializers.ListField)): + child_schema = field_to_schema(field.child) + return coreschema.Array( + items=child_schema, + title=title, + description=description + ) + elif isinstance(field, serializers.DictField): + return coreschema.Object( + title=title, + description=description + ) + elif isinstance(field, serializers.Serializer): + return coreschema.Object( + properties=OrderedDict([ + (key, field_to_schema(value)) + for key, value + in field.fields.items() + ]), + title=title, + description=description + ) + elif isinstance(field, serializers.ManyRelatedField): + related_field_schema = field_to_schema(field.child_relation) + + return coreschema.Array( + items=related_field_schema, + title=title, + description=description + ) + elif isinstance(field, serializers.PrimaryKeyRelatedField): + schema_cls = coreschema.String + model = getattr(field.queryset, 'model', None) + if model is not None: + model_field = model._meta.pk + if isinstance(model_field, models.AutoField): + schema_cls = coreschema.Integer + return schema_cls(title=title, description=description) + elif isinstance(field, serializers.RelatedField): + return coreschema.String(title=title, description=description) + elif isinstance(field, serializers.MultipleChoiceField): + return coreschema.Array( + items=coreschema.Enum(enum=list(field.choices)), + title=title, + description=description + ) + elif isinstance(field, serializers.ChoiceField): + return coreschema.Enum( + enum=list(field.choices), + title=title, + description=description + ) + elif isinstance(field, serializers.BooleanField): + return coreschema.Boolean(title=title, description=description) + elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): + return coreschema.Number(title=title, description=description) + elif isinstance(field, serializers.IntegerField): + return coreschema.Integer(title=title, description=description) + elif isinstance(field, serializers.DateField): + return coreschema.String( + title=title, + description=description, + format='date' + ) + elif isinstance(field, serializers.DateTimeField): + return coreschema.String( + title=title, + description=description, + format='date-time' + ) + elif isinstance(field, serializers.JSONField): + return coreschema.Object(title=title, description=description) + + if field.style.get('base_template') == 'textarea.html': + return coreschema.String( + title=title, + description=description, + format='textarea' + ) + + return coreschema.String(title=title, description=description) + + +class AutoSchema(ViewInspector): + """ + Default inspector for APIView + + Responsible for per-view introspection and schema generation. + """ + def __init__(self, manual_fields=None): + """ + Parameters: + + * `manual_fields`: list of `coreapi.Field` instances that + will be added to auto-generated fields, overwriting on `Field.name` + """ + super(AutoSchema, self).__init__() + if manual_fields is None: + manual_fields = [] + self._manual_fields = manual_fields + + def get_link(self, path, method, base_url): + """ + Generate `coreapi.Link` for self.view, path and method. + + This is the main _public_ access point. + + Parameters: + + * path: Route path for view from URLConf. + * method: The HTTP request method. + * base_url: The project "mount point" as given to SchemaGenerator + """ + fields = self.get_path_fields(path, method) + fields += self.get_serializer_fields(path, method) + fields += self.get_pagination_fields(path, method) + fields += self.get_filter_fields(path, method) + + manual_fields = self.get_manual_fields(path, method) + fields = self.update_fields(fields, manual_fields) + + if fields and any([field.location in ('form', 'body') for field in fields]): + encoding = self.get_encoding(path, method) + else: + encoding = None + + description = self.get_description(path, method) + + if base_url and path.startswith('/'): + path = path[1:] + + return coreapi.Link( + url=parse.urljoin(base_url, path), + action=method.lower(), + encoding=encoding, + fields=fields, + description=description + ) + + def get_description(self, path, method): + """ + Determine a link description. + + This will be based on the method docstring if one exists, + or else the class docstring. + """ + view = self.view + + method_name = getattr(view, 'action', method.lower()) + method_docstring = getattr(view, method_name, None).__doc__ + if method_docstring: + # An explicit docstring on the method or action. + return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring))) + else: + return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description()) + + def _get_description_section(self, view, header, description): + lines = [line for line in description.splitlines()] + current_section = '' + sections = {'': ''} + + for line in lines: + if header_regex.match(line): + current_section, seperator, lead = line.partition(':') + sections[current_section] = lead.strip() + else: + sections[current_section] += '\n' + line + + # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` + coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + if header in sections: + return sections[header].strip() + if header in coerce_method_names: + if coerce_method_names[header] in sections: + return sections[coerce_method_names[header]].strip() + return sections[''].strip() + + def get_path_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + templated path variables. + """ + view = self.view + model = getattr(getattr(view, 'queryset', None), 'model', None) + fields = [] + + for variable in uritemplate.variables(path): + title = '' + description = '' + schema_cls = coreschema.String + kwargs = {} + if model is not None: + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.verbose_name: + title = force_text(model_field.verbose_name) + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: + kwargs['pattern'] = view.lookup_value_regex + elif isinstance(model_field, models.AutoField): + schema_cls = coreschema.Integer + + field = coreapi.Field( + name=variable, + location='path', + required=True, + schema=schema_cls(title=title, description=description, **kwargs) + ) + fields.append(field) + + return fields + + def get_serializer_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + request body input, as determined by the serializer class. + """ + view = self.view + + if method not in ('PUT', 'PATCH', 'POST'): + return [] + + if not hasattr(view, 'get_serializer'): + return [] + + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if isinstance(serializer, serializers.ListSerializer): + return [ + coreapi.Field( + name='data', + location='body', + required=True, + schema=coreschema.Array() + ) + ] + + if not isinstance(serializer, serializers.Serializer): + return [] + + fields = [] + for field in serializer.fields.values(): + if field.read_only or isinstance(field, serializers.HiddenField): + continue + + required = field.required and method != 'PATCH' + field = coreapi.Field( + name=field.field_name, + location='form', + required=required, + schema=field_to_schema(field) + ) + fields.append(field) + + return fields + + def get_pagination_fields(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + pagination = getattr(view, 'pagination_class', None) + if not pagination: + return [] + + paginator = view.pagination_class() + return paginator.get_schema_fields(view) + + def _allows_filters(self, path, method): + """ + Determine whether to include filter Fields in schema. + + Default implementation looks for ModelViewSet or GenericAPIView + actions/methods that cause filtering on the default implementation. + + Override to adjust behaviour for your view. + + Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore) + to allow changes based on user experience. + """ + if getattr(self.view, 'filter_backends', None) is None: + return False + + if hasattr(self.view, 'action'): + return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] + + return method.lower() in ["get", "put", "patch", "delete"] + + def get_filter_fields(self, path, method): + if not self._allows_filters(path, method): + return [] + + fields = [] + for filter_backend in self.view.filter_backends: + fields += filter_backend().get_schema_fields(self.view) + return fields + + def get_manual_fields(self, path, method): + return self._manual_fields + + @staticmethod + def update_fields(fields, update_with): + """ + Update list of coreapi.Field instances, overwriting on `Field.name`. + + Utility function to handle replacing coreapi.Field fields + from a list by name. Used to handle `manual_fields`. + + Parameters: + + * `fields`: list of `coreapi.Field` instances to update + * `update_with: list of `coreapi.Field` instances to add or replace. + """ + if not update_with: + return fields + + by_name = OrderedDict((f.name, f) for f in fields) + for f in update_with: + by_name[f.name] = f + fields = list(by_name.values()) + return fields + + def get_encoding(self, path, method): + """ + Return the 'encoding' parameter to use for a given endpoint. + """ + view = self.view + + # Core API supports the following request encodings over HTTP... + supported_media_types = { + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data', + } + parser_classes = getattr(view, 'parser_classes', []) + for parser_class in parser_classes: + media_type = getattr(parser_class, 'media_type', None) + if media_type in supported_media_types: + return media_type + # Raw binary uploads are supported with "application/octet-stream" + if media_type == '*/*': + return 'application/octet-stream' + + return None + + +class ManualSchema(ViewInspector): + """ + Allows providing a list of coreapi.Fields, + plus an optional description. + """ + def __init__(self, fields, description='', encoding=None): + """ + Parameters: + + * `fields`: list of `coreapi.Field` instances. + * `description`: String description for view. Optional. + """ + super(ManualSchema, self).__init__() + assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" + self._fields = fields + self._description = description + self._encoding = encoding + + def get_link(self, path, method, base_url): + + if base_url and path.startswith('/'): + path = path[1:] + + return coreapi.Link( + url=parse.urljoin(base_url, path), + action=method.lower(), + encoding=self._encoding, + fields=self._fields, + description=self._description + ) + + +def is_enabled(): + """Is CoreAPI Mode enabled?""" + return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema) diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 66afcca94..ecb07f935 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -4,7 +4,6 @@ generators.py # Top-down schema generation See schemas.__init__.py for package overview. """ import re -from collections import Counter, OrderedDict from importlib import import_module from django.conf import settings @@ -13,15 +12,11 @@ from django.core.exceptions import PermissionDenied from django.http import Http404 from rest_framework import exceptions -from rest_framework.compat import ( - URLPattern, URLResolver, coreapi, coreschema, get_original_route -) +from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.request import clone_request from rest_framework.settings import api_settings from rest_framework.utils.model_meta import _get_pk -from .utils import is_list_view - def common_path(paths): split_paths = [path.strip('/').split('/') for path in paths] @@ -50,78 +45,6 @@ def is_api_view(callback): return (cls is not None) and issubclass(cls, APIView) -INSERT_INTO_COLLISION_FMT = """ -Schema Naming Collision. - -coreapi.Link for URL path {value_url} cannot be inserted into schema. -Position conflicts with coreapi.Link for URL path {target_url}. - -Attempted to insert link with keys: {keys}. - -Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()` -to customise schema structure. -""" - - -class LinkNode(OrderedDict): - def __init__(self): - self.links = [] - self.methods_counter = Counter() - super().__init__() - - def get_available_key(self, preferred_key): - if preferred_key not in self: - return preferred_key - - while True: - current_val = self.methods_counter[preferred_key] - self.methods_counter[preferred_key] += 1 - - key = '{}_{}'.format(preferred_key, current_val) - if key not in self: - return key - - -def insert_into(target, keys, value): - """ - Nested dictionary insertion. - - >>> example = {} - >>> insert_into(example, ['a', 'b', 'c'], 123) - >>> example - LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}}))) - """ - for key in keys[:-1]: - if key not in target: - target[key] = LinkNode() - target = target[key] - - try: - target.links.append((keys[-1], value)) - except TypeError: - msg = INSERT_INTO_COLLISION_FMT.format( - value_url=value.url, - target_url=target.url, - keys=keys - ) - raise ValueError(msg) - - -def distribute_links(obj): - for key, value in obj.items(): - distribute_links(value) - - for preferred_key, link in obj.links: - key = obj.get_available_key(preferred_key) - obj[key] = link - - -def is_custom_action(action): - return action not in { - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' - } - - def endpoint_ordering(endpoint): path, method, callback = endpoint method_priority = { @@ -190,6 +113,10 @@ class EndpointEnumerator: """ Given a URL conf regex, return a URI template string. """ + # ???: Would it be feasible to adjust this such that we generate the + # path, plus the kwargs, plus the type from the convertor, such that we + # could feed that straight into the parameter schema object? + path = simplify_regex(path_regex) # Strip Django 2.0 convertors as they are incompatible with uritemplate format @@ -228,35 +155,18 @@ class EndpointEnumerator: return [method for method in methods if method not in ('OPTIONS', 'HEAD')] -class SchemaGenerator: - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } +class BaseSchemaGenerator(object): endpoint_inspector_cls = EndpointEnumerator - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - # 'pk' isn't great as an externally exposed name for an identifier, # so by default we prefer to use the actual model field name for schemas. # Set by 'SCHEMA_COERCE_PATH_PK'. coerce_path_pk = None def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' - if url and not url.endswith('/'): url += '/' - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.patterns = patterns @@ -266,36 +176,15 @@ class SchemaGenerator: self.url = url self.endpoints = None - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ + def _initialise_endpoints(self): if self.endpoints is None: inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) self.endpoints = inspector.get_api_endpoints() - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - distribute_links(links) - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - - def get_links(self, request=None): + def _get_paths_and_endpoints(self, request): """ - Return a dictionary containing all the links that should be - included in the API schema. + Generate (path, method, view) given (path, method, callback) for paths. """ - links = LinkNode() - - # Generate (path, method, view) given (path, method, callback). paths = [] view_endpoints = [] for path, method, callback in self.endpoints: @@ -304,22 +193,48 @@ class SchemaGenerator: paths.append(path) view_endpoints.append((path, method, view)) - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) + return paths, view_endpoints - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) + def create_view(self, callback, method, request=None): + """ + Given a callback, return an actual view instance. + """ + view = callback.cls(**getattr(callback, 'initkwargs', {})) + view.args = () + view.kwargs = {} + view.format_kwarg = None + view.request = None + view.action_map = getattr(callback, 'actions', None) - return links + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) - # Methods used when we generate a view instance from the raw callback... + if request is not None: + view.request = clone_request(request, method) + + return view + + def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ + if not self.coerce_path_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + + def get_schema(self, request=None, public=False): + raise NotImplementedError(".get_schema() must be implemented in subclasses.") def determine_path_prefix(self, paths): """ @@ -352,29 +267,6 @@ class SchemaGenerator: prefixes.append('/' + prefix + '/') return common_path(prefixes) - def create_view(self, callback, method, request=None): - """ - Given a callback, return an actual view instance. - """ - view = callback.cls(**getattr(callback, 'initkwargs', {})) - view.args = () - view.kwargs = {} - view.format_kwarg = None - view.request = None - view.action_map = getattr(callback, 'actions', None) - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - - return view - def has_view_permissions(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. @@ -387,64 +279,3 @@ class SchemaGenerator: except (exceptions.APIException, Http404, PermissionDenied): return False return True - - def coerce_path(self, path, method, view): - """ - Coerce {pk} path arguments into the name of the model field, - where possible. This is cleaner for an external representation. - (Ie. "this is an identifier", not "this is a database primary key") - """ - if not self.coerce_path_pk or '{pk}' not in path: - return path - model = getattr(getattr(view, 'queryset', None), 'model', None) - if model: - field_name = get_pk_name(model) - else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) - - # Method for generating the link layout.... - - def get_keys(self, subpath, method, view): - """ - Return a list of keys that should be used to layout a link within - the schema document. - - /users/ ("users", "list"), ("users", "create") - /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") - /users/enabled/ ("users", "enabled") # custom viewset list action - /users/{pk}/star/ ("users", "star") # custom viewset detail action - /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") - /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") - """ - if hasattr(view, 'action'): - # Viewsets have explicitly named actions. - action = view.action - else: - # Views have no associated action, so we determine one from the method. - if is_list_view(subpath, method, view): - action = 'list' - else: - action = self.default_mapping[method.lower()] - - named_path_components = [ - component for component - in subpath.strip('/').split('/') - if '{' not in component - ] - - if is_custom_action(action): - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - if len(view.action_map) > 1: - action = self.default_mapping[method.lower()] - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - return named_path_components + [action] - else: - return named_path_components[:-1] + [action] - - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - - # Default action, eg "/users/", "/users/{pk}/" - return named_path_components + [action] diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 2858c8c5b..86fcdc435 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -3,125 +3,9 @@ inspectors.py # Per-endpoint view introspection See schemas.__init__.py for package overview. """ -import re -import warnings -from collections import OrderedDict -from urllib import parse from weakref import WeakKeyDictionary -from django.db import models -from django.utils.encoding import force_text, smart_text -from django.utils.translation import gettext_lazy as _ - -from rest_framework import exceptions, serializers -from rest_framework.compat import coreapi, coreschema, uritemplate from rest_framework.settings import api_settings -from rest_framework.utils import formatting - -from .utils import is_list_view - -header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') - - -def field_to_schema(field): - title = force_text(field.label) if field.label else '' - description = force_text(field.help_text) if field.help_text else '' - - if isinstance(field, (serializers.ListSerializer, serializers.ListField)): - child_schema = field_to_schema(field.child) - return coreschema.Array( - items=child_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.DictField): - return coreschema.Object( - title=title, - description=description - ) - elif isinstance(field, serializers.Serializer): - return coreschema.Object( - properties=OrderedDict([ - (key, field_to_schema(value)) - for key, value - in field.fields.items() - ]), - title=title, - description=description - ) - elif isinstance(field, serializers.ManyRelatedField): - related_field_schema = field_to_schema(field.child_relation) - - return coreschema.Array( - items=related_field_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.PrimaryKeyRelatedField): - schema_cls = coreschema.String - model = getattr(field.queryset, 'model', None) - if model is not None: - model_field = model._meta.pk - if isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - return schema_cls(title=title, description=description) - elif isinstance(field, serializers.RelatedField): - return coreschema.String(title=title, description=description) - elif isinstance(field, serializers.MultipleChoiceField): - return coreschema.Array( - items=coreschema.Enum(enum=list(field.choices)), - title=title, - description=description - ) - elif isinstance(field, serializers.ChoiceField): - return coreschema.Enum( - enum=list(field.choices), - title=title, - description=description - ) - elif isinstance(field, serializers.BooleanField): - return coreschema.Boolean(title=title, description=description) - elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): - return coreschema.Number(title=title, description=description) - elif isinstance(field, serializers.IntegerField): - return coreschema.Integer(title=title, description=description) - elif isinstance(field, serializers.DateField): - return coreschema.String( - title=title, - description=description, - format='date' - ) - elif isinstance(field, serializers.DateTimeField): - return coreschema.String( - title=title, - description=description, - format='date-time' - ) - elif isinstance(field, serializers.JSONField): - return coreschema.Object(title=title, description=description) - - if field.style.get('base_template') == 'textarea.html': - return coreschema.String( - title=title, - description=description, - format='textarea' - ) - - return coreschema.String(title=title, description=description) - - -def get_pk_description(model, model_field): - if isinstance(model_field, models.AutoField): - value_type = _('unique integer value') - elif isinstance(model_field, models.UUIDField): - value_type = _('UUID string') - else: - value_type = _('unique value') - - return _('A {value_type} identifying this {name}.').format( - value_type=value_type, - name=model._meta.verbose_name, - ) class ViewInspector: @@ -178,320 +62,6 @@ class ViewInspector: def view(self): self._view = None - def get_link(self, path, method, base_url): - """ - Generate `coreapi.Link` for self.view, path and method. - - This is the main _public_ access point. - - Parameters: - - * path: Route path for view from URLConf. - * method: The HTTP request method. - * base_url: The project "mount point" as given to SchemaGenerator - """ - raise NotImplementedError(".get_link() must be overridden.") - - -class AutoSchema(ViewInspector): - """ - Default inspector for APIView - - Responsible for per-view introspection and schema generation. - """ - def __init__(self, manual_fields=None): - """ - Parameters: - - * `manual_fields`: list of `coreapi.Field` instances that - will be added to auto-generated fields, overwriting on `Field.name` - """ - super().__init__() - if manual_fields is None: - manual_fields = [] - self._manual_fields = manual_fields - - def get_link(self, path, method, base_url): - fields = self.get_path_fields(path, method) - fields += self.get_serializer_fields(path, method) - fields += self.get_pagination_fields(path, method) - fields += self.get_filter_fields(path, method) - - manual_fields = self.get_manual_fields(path, method) - fields = self.update_fields(fields, manual_fields) - - if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method) - else: - encoding = None - - description = self.get_description(path, method) - - if base_url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=parse.urljoin(base_url, path), - action=method.lower(), - encoding=encoding, - fields=fields, - description=description - ) - - def get_description(self, path, method): - """ - Determine a link description. - - This will be based on the method docstring if one exists, - or else the class docstring. - """ - view = self.view - - method_name = getattr(view, 'action', method.lower()) - method_docstring = getattr(view, method_name, None).__doc__ - if method_docstring: - # An explicit docstring on the method or action. - return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring))) - else: - return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description()) - - def _get_description_section(self, view, header, description): - lines = [line for line in description.splitlines()] - current_section = '' - sections = {'': ''} - - for line in lines: - if header_regex.match(line): - current_section, seperator, lead = line.partition(':') - sections[current_section] = lead.strip() - else: - sections[current_section] += '\n' + line - - # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` - coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES - if header in sections: - return sections[header].strip() - if header in coerce_method_names: - if coerce_method_names[header] in sections: - return sections[coerce_method_names[header]].strip() - return sections[''].strip() - - def get_path_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - templated path variables. - """ - view = self.view - model = getattr(getattr(view, 'queryset', None), 'model', None) - fields = [] - - for variable in uritemplate.variables(path): - title = '' - description = '' - schema_cls = coreschema.String - kwargs = {} - if model is not None: - # Attempt to infer a field description if possible. - try: - model_field = model._meta.get_field(variable) - except Exception: - model_field = None - - if model_field is not None and model_field.verbose_name: - title = force_text(model_field.verbose_name) - - if model_field is not None and model_field.help_text: - description = force_text(model_field.help_text) - elif model_field is not None and model_field.primary_key: - description = get_pk_description(model, model_field) - - if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: - kwargs['pattern'] = view.lookup_value_regex - elif isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - - field = coreapi.Field( - name=variable, - location='path', - required=True, - schema=schema_cls(title=title, description=description, **kwargs) - ) - fields.append(field) - - return fields - - def get_serializer_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - request body input, as determined by the serializer class. - """ - view = self.view - - if method not in ('PUT', 'PATCH', 'POST'): - return [] - - if not hasattr(view, 'get_serializer'): - return [] - - try: - serializer = view.get_serializer() - except exceptions.APIException: - serializer = None - warnings.warn('{}.get_serializer() raised an exception during ' - 'schema generation. Serializer fields will not be ' - 'generated for {} {}.' - .format(view.__class__.__name__, method, path)) - - if isinstance(serializer, serializers.ListSerializer): - return [ - coreapi.Field( - name='data', - location='body', - required=True, - schema=coreschema.Array() - ) - ] - - if not isinstance(serializer, serializers.Serializer): - return [] - - fields = [] - for field in serializer.fields.values(): - if field.read_only or isinstance(field, serializers.HiddenField): - continue - - required = field.required and method != 'PATCH' - field = coreapi.Field( - name=field.field_name, - location='form', - required=required, - schema=field_to_schema(field) - ) - fields.append(field) - - return fields - - def get_pagination_fields(self, path, method): - view = self.view - - if not is_list_view(path, method, view): - return [] - - pagination = getattr(view, 'pagination_class', None) - if not pagination: - return [] - - paginator = view.pagination_class() - return paginator.get_schema_fields(view) - - def _allows_filters(self, path, method): - """ - Determine whether to include filter Fields in schema. - - Default implementation looks for ModelViewSet or GenericAPIView - actions/methods that cause filtering on the default implementation. - - Override to adjust behaviour for your view. - - Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore) - to allow changes based on user experience. - """ - if getattr(self.view, 'filter_backends', None) is None: - return False - - if hasattr(self.view, 'action'): - return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] - - return method.lower() in ["get", "put", "patch", "delete"] - - def get_filter_fields(self, path, method): - if not self._allows_filters(path, method): - return [] - - fields = [] - for filter_backend in self.view.filter_backends: - fields += filter_backend().get_schema_fields(self.view) - return fields - - def get_manual_fields(self, path, method): - return self._manual_fields - - @staticmethod - def update_fields(fields, update_with): - """ - Update list of coreapi.Field instances, overwriting on `Field.name`. - - Utility function to handle replacing coreapi.Field fields - from a list by name. Used to handle `manual_fields`. - - Parameters: - - * `fields`: list of `coreapi.Field` instances to update - * `update_with: list of `coreapi.Field` instances to add or replace. - """ - if not update_with: - return fields - - by_name = OrderedDict((f.name, f) for f in fields) - for f in update_with: - by_name[f.name] = f - return list(by_name.values()) - - def get_encoding(self, path, method): - """ - Return the 'encoding' parameter to use for a given endpoint. - """ - view = self.view - - # Core API supports the following request encodings over HTTP... - supported_media_types = { - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data', - } - parser_classes = getattr(view, 'parser_classes', []) - for parser_class in parser_classes: - media_type = getattr(parser_class, 'media_type', None) - if media_type in supported_media_types: - return media_type - # Raw binary uploads are supported with "application/octet-stream" - if media_type == '*/*': - return 'application/octet-stream' - - return None - - -class ManualSchema(ViewInspector): - """ - Allows providing a list of coreapi.Fields, - plus an optional description. - """ - def __init__(self, fields, description='', encoding=None): - """ - Parameters: - - * `fields`: list of `coreapi.Field` instances. - * `description`: String description for view. Optional. - """ - super().__init__() - assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" - self._fields = fields - self._description = description - self._encoding = encoding - - def get_link(self, path, method, base_url): - - if base_url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=parse.urljoin(base_url, path), - action=method.lower(), - encoding=self._encoding, - fields=self._fields, - description=self._description - ) - class DefaultSchema(ViewInspector): """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py new file mode 100644 index 000000000..44b281be8 --- /dev/null +++ b/rest_framework/schemas/openapi.py @@ -0,0 +1,377 @@ +import warnings + +from django.db import models +from django.utils.encoding import force_text + +from rest_framework import exceptions, serializers +from rest_framework.compat import uritemplate + +from .generators import BaseSchemaGenerator +from .inspectors import ViewInspector +from .utils import get_pk_description, is_list_view + +# Generator + + +class SchemaGenerator(BaseSchemaGenerator): + + def get_info(self): + info = { + 'title': self.title, + 'version': 'TODO', + } + + if self.description is not None: + info['description'] = self.description + + return info + + def get_paths(self, request=None): + result = {} + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + operation = view.schema.get_operation(path, method) + subpath = '/' + path[len(prefix):] + result.setdefault(subpath, {}) + result[subpath][method.lower()] = operation + + return result + + def get_schema(self, request=None, public=False): + """ + Generate a OpenAPI schema. + """ + self._initialise_endpoints() + + paths = self.get_paths(None if public else request) + if not paths: + return None + + schema = { + 'openapi': '3.0.2', + 'info': self.get_info(), + 'paths': paths, + } + + return schema + +# View Inspectors + + +class AutoSchema(ViewInspector): + + content_types = ['application/json'] + method_mapping = { + 'get': 'Retrieve', + 'post': 'Create', + 'put': 'Update', + 'patch': 'PartialUpdate', + 'delete': 'Destroy', + } + + def get_operation(self, path, method): + operation = {} + + operation['operationId'] = self._get_operation_id(path, method) + + parameters = [] + parameters += self._get_path_parameters(path, method) + parameters += self._get_pagination_parameters(path, method) + parameters += self._get_filter_parameters(path, method) + operation['parameters'] = parameters + + request_body = self._get_request_body(path, method) + if request_body: + operation['requestBody'] = request_body + operation['responses'] = self._get_responses(path, method) + + return operation + + def _get_operation_id(self, path, method): + """ + Compute an operation ID from the model, serializer or view name. + """ + method_name = getattr(self.view, 'action', method.lower()) + if is_list_view(path, method, self.view): + action = 'List' + elif method_name not in self.method_mapping: + action = method_name + else: + action = self.method_mapping[method.lower()] + + # Try to deduce the ID from the view's model + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + if model is not None: + name = model.__name__ + + # Try with the serializer class name + elif hasattr(self.view, 'get_serializer_class'): + name = self.view.get_serializer_class().__name__ + if name.endswith('Serializer'): + name = name[:-10] + + # Fallback to the view name + else: + name = self.view.__class__.__name__ + if name.endswith('APIView'): + name = name[:-7] + elif name.endswith('View'): + name = name[:-4] + if name.endswith(action): # ListView, UpdateAPIView, ThingDelete ... + name = name[:-len(action)] + + if action == 'List' and not name.endswith('s'): # ListThings instead of ListThing + name += 's' + + return action + name + + def _get_path_parameters(self, path, method): + """ + Return a list of parameters from templated path variables. + """ + assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.' + + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + parameters = [] + + for variable in uritemplate.variables(path): + description = '' + if model is not None: # TODO: test this. + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + parameter = { + "name": variable, + "in": "path", + "required": True, + "description": description, + 'schema': { + 'type': 'string', # TODO: integer, pattern, ... + }, + } + parameters.append(parameter) + + return parameters + + def _get_filter_parameters(self, path, method): + if not self._allows_filters(path, method): + return [] + parameters = [] + for filter_backend in self.view.filter_backends: + parameters += filter_backend().get_schema_operation_parameters(self.view) + return parameters + + def _allows_filters(self, path, method): + """ + Determine whether to include filter Fields in schema. + + Default implementation looks for ModelViewSet or GenericAPIView + actions/methods that cause filtering on the default implementation. + """ + if getattr(self.view, 'filter_backends', None) is None: + return False + if hasattr(self.view, 'action'): + return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] + return method.lower() in ["get", "put", "patch", "delete"] + + def _get_pagination_parameters(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + pagination = getattr(view, 'pagination_class', None) + if not pagination: + return [] + + paginator = view.pagination_class() + return paginator.get_schema_operation_parameters(view) + + def _map_field(self, field): + + # Nested Serializers, `many` or not. + if isinstance(field, serializers.ListSerializer): + return { + 'type': 'array', + 'items': self._map_serializer(field.child) + } + if isinstance(field, serializers.Serializer): + data = self._map_serializer(field) + data['type'] = 'object' + return data + + # Related fields. + if isinstance(field, serializers.ManyRelatedField): + return { + 'type': 'array', + 'items': self._map_field(field.child_relation) + } + if isinstance(field, serializers.PrimaryKeyRelatedField): + model = getattr(field.queryset, 'model', None) + if model is not None: + model_field = model._meta.pk + if isinstance(model_field, models.AutoField): + return {'type': 'integer'} + + # ChoiceFields (single and multiple). + # Q: + # - Is 'type' required? + # - can we determine the TYPE of a choicefield? + if isinstance(field, serializers.MultipleChoiceField): + return { + 'type': 'array', + 'items': { + 'enum': list(field.choices) + }, + } + + if isinstance(field, serializers.ChoiceField): + return { + 'enum': list(field.choices), + } + + # ListField. + if isinstance(field, serializers.ListField): + return { + 'type': 'array', + } + + # DateField and DateTimeField type is string + if isinstance(field, serializers.DateField): + return { + 'type': 'string', + 'format': 'date', + } + + if isinstance(field, serializers.DateTimeField): + return { + 'type': 'string', + 'format': 'date-time', + } + + # Simplest cases, default to 'string' type: + FIELD_CLASS_SCHEMA_TYPE = { + serializers.BooleanField: 'boolean', + serializers.DecimalField: 'number', + serializers.FloatField: 'number', + serializers.IntegerField: 'integer', + + serializers.JSONField: 'object', + serializers.DictField: 'object', + } + return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} + + def _map_serializer(self, serializer): + # Assuming we have a valid serializer instance. + # TODO: + # - field is Nested or List serializer. + # - Handle read_only/write_only for request/response differences. + # - could do this with readOnly/writeOnly and then filter dict. + required = [] + properties = {} + + for field in serializer.fields.values(): + if isinstance(field, serializers.HiddenField): + continue + + if field.required: + required.append(field.field_name) + + schema = self._map_field(field) + if field.read_only: + schema['readOnly'] = True + if field.write_only: + schema['writeOnly'] = True + if field.allow_null: + schema['nullable'] = True + + properties[field.field_name] = schema + return { + 'required': required, + 'properties': properties, + } + + def _get_request_body(self, path, method): + view = self.view + + if method not in ('PUT', 'PATCH', 'POST'): + return {} + + if not hasattr(view, 'get_serializer'): + return {} + + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if not isinstance(serializer, serializers.Serializer): + return {} + + content = self._map_serializer(serializer) + # No required fields for PATCH + if method == 'PATCH': + del content['required'] + # No read_only fields for request. + for name, schema in content['properties'].copy().items(): + if 'readOnly' in schema: + del content['properties'][name] + + return { + 'content': { + ct: {'schema': content} + for ct in self.content_types + } + } + + def _get_responses(self, path, method): + # TODO: Handle multiple codes. + content = {} + view = self.view + if hasattr(view, 'get_serializer'): + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if isinstance(serializer, serializers.Serializer): + content = self._map_serializer(serializer) + # No write_only fields for response. + for name, schema in content['properties'].copy().items(): + if 'writeOnly' in schema: + del content['properties'][name] + content['required'] = [f for f in content['required'] if f != name] + + return { + '200': { + 'content': { + ct: {'schema': content} + for ct in self.content_types + } + } + } diff --git a/rest_framework/schemas/utils.py b/rest_framework/schemas/utils.py index 76437a20a..6724eb428 100644 --- a/rest_framework/schemas/utils.py +++ b/rest_framework/schemas/utils.py @@ -3,6 +3,9 @@ utils.py # Shared helper functions See schemas.__init__.py for package overview. """ +from django.db import models +from django.utils.translation import ugettext_lazy as _ + from rest_framework.mixins import RetrieveModelMixin @@ -22,3 +25,17 @@ def is_list_view(path, method, view): if path_components and '{' in path_components[-1]: return False return True + + +def get_pk_description(model, model_field): + if isinstance(model_field, models.AutoField): + value_type = _('unique integer value') + elif isinstance(model_field, models.UUIDField): + value_type = _('UUID string') + else: + value_type = _('unique value') + + return _('A {value_type} identifying this {name}.').format( + value_type=value_type, + name=model._meta.verbose_name, + ) diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py index fa5cdbdc7..527a23236 100644 --- a/rest_framework/schemas/views.py +++ b/rest_framework/schemas/views.py @@ -5,6 +5,7 @@ See schemas.__init__.py for package overview. """ from rest_framework import exceptions, renderers from rest_framework.response import Response +from rest_framework.schemas import coreapi from rest_framework.settings import api_settings from rest_framework.views import APIView @@ -19,10 +20,16 @@ class SchemaView(APIView): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.renderer_classes is None: - self.renderer_classes = [ - renderers.OpenAPIRenderer, - renderers.CoreJSONRenderer - ] + if coreapi.is_enabled(): + self.renderer_classes = [ + renderers.CoreAPIOpenAPIRenderer, + renderers.CoreJSONRenderer + ] + else: + self.renderer_classes = [ + renderers.OpenAPIRenderer, + renderers.JSONOpenAPIRenderer, + ] if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: self.renderer_classes += [renderers.BrowsableAPIRenderer] diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 1d5dc036f..3520eae36 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -52,7 +52,7 @@ DEFAULTS = { 'DEFAULT_FILTER_BACKENDS': (), # Schema - 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', + 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema', # Throttling 'DEFAULT_THROTTLE_RATES': { diff --git a/tests/schemas/__init__.py b/tests/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_schemas.py b/tests/schemas/test_coreapi.py similarity index 94% rename from tests/test_schemas.py rename to tests/schemas/test_coreapi.py index 230f8f012..66275ade9 100644 --- a/tests/test_schemas.py +++ b/tests/schemas/test_coreapi.py @@ -16,15 +16,16 @@ from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework.schemas import ( AutoSchema, ManualSchema, SchemaGenerator, get_schema_view ) +from rest_framework.schemas.coreapi import field_to_schema from rest_framework.schemas.generators import EndpointEnumerator -from rest_framework.schemas.inspectors import field_to_schema from rest_framework.schemas.utils import is_list_view from rest_framework.test import APIClient, APIRequestFactory from rest_framework.utils import formatting from rest_framework.views import APIView from rest_framework.viewsets import GenericViewSet, ModelViewSet -from .models import BasicModel, ForeignKeySource, ManyToManySource +from . import views +from ..models import BasicModel, ForeignKeySource, ManyToManySource factory = APIRequestFactory() @@ -133,11 +134,12 @@ class ExampleViewSet(ModelViewSet): pass -if coreapi: - schema_view = get_schema_view(title='Example API') -else: - def schema_view(request): - pass +with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): + if coreapi: + schema_view = get_schema_view(title='Example API') + else: + def schema_view(request): + pass router = DefaultRouter() router.register('example', ExampleViewSet, basename='example') @@ -148,7 +150,7 @@ urlpatterns = [ @unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(ROOT_URLCONF='tests.test_schemas') +@override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestRouterGeneratedSchema(TestCase): def test_anonymous_request(self): client = APIClient() @@ -400,12 +402,13 @@ class ExampleDetailView(APIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGenerator(TestCase): def setUp(self): self.patterns = [ - url(r'^example/?$', ExampleListView.as_view()), - url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^example/?$', views.ExampleListView.as_view()), + url(r'^example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -453,12 +456,13 @@ class TestSchemaGenerator(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(path, 'needs Django 2') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorDjango2(TestCase): def setUp(self): self.patterns = [ - path('example/', ExampleListView.as_view()), - path('example//', ExampleDetailView.as_view()), - path('example//sub/', ExampleDetailView.as_view()), + path('example/', views.ExampleListView.as_view()), + path('example//', views.ExampleDetailView.as_view()), + path('example//sub/', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -505,12 +509,13 @@ class TestSchemaGeneratorDjango2(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorNotAtRoot(TestCase): def setUp(self): self.patterns = [ - url(r'^api/v1/example/?$', ExampleListView.as_view()), - url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^api/v1/example/?$', views.ExampleListView.as_view()), + url(r'^api/v1/example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^api/v1/example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -558,6 +563,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): def setUp(self): router = DefaultRouter() @@ -622,13 +628,14 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithRestrictedViewSets(TestCase): def setUp(self): router = DefaultRouter() router.register('example1', Http404ExampleViewSet, basename='example1') router.register('example2', PermissionDeniedExampleViewSet, basename='example2') self.patterns = [ - url('^example/?$', ExampleListView.as_view()), + url('^example/?$', views.ExampleListView.as_view()), url(r'^', include(router.urls)) ] @@ -668,6 +675,7 @@ class ForeignKeySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithForeignKey(TestCase): def setUp(self): self.patterns = [ @@ -713,6 +721,7 @@ class ManyToManySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithManyToMany(TestCase): def setUp(self): self.patterns = [ @@ -747,6 +756,7 @@ class TestSchemaGeneratorWithManyToMany(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class Test4605Regression(TestCase): def test_4605_regression(self): generator = SchemaGenerator() @@ -762,6 +772,7 @@ class CustomViewInspector(AutoSchema): pass +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestAutoSchema(TestCase): def test_apiview_schema_descriptor(self): @@ -777,7 +788,7 @@ class TestAutoSchema(TestCase): assert isinstance(view.schema, CustomViewInspector) def test_set_custom_inspector_class_via_settings(self): - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}): view = APIView() assert isinstance(view.schema, CustomViewInspector) @@ -971,6 +982,7 @@ class TestAutoSchema(TestCase): self.assertEqual(field_to_schema(case[0]), case[1]) +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) def test_docstring_is_not_stripped_by_get_description(): class ExampleDocstringAPIView(APIView): """ @@ -1007,25 +1019,25 @@ def test_docstring_is_not_stripped_by_get_description(): # Views for SchemaGenerationExclusionTests -class ExcludedAPIView(APIView): - schema = None +with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): + class ExcludedAPIView(APIView): + schema = None - def get(self, request, *args, **kwargs): + def get(self, request, *args, **kwargs): + pass + + @api_view(['GET']) + @schema(None) + def excluded_fbv(request): + pass + + @api_view(['GET']) + def included_fbv(request): pass -@api_view(['GET']) -@schema(None) -def excluded_fbv(request): - pass - - -@api_view(['GET']) -def included_fbv(request): - pass - - @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class SchemaGenerationExclusionTests(TestCase): def setUp(self): self.patterns = [ @@ -1078,11 +1090,6 @@ class SchemaGenerationExclusionTests(TestCase): assert should_include == expected -@api_view(["GET"]) -def simple_fbv(request): - pass - - class BasicModelSerializer(serializers.ModelSerializer): class Meta: model = BasicModel @@ -1118,11 +1125,16 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestURLNamingCollisions(TestCase): """ Ref: https://github.com/encode/django-rest-framework/issues/4704 """ def test_manually_routing_nested_routes(self): + @api_view(["GET"]) + def simple_fbv(request): + pass + patterns = [ url(r'^test', simple_fbv), url(r'^test/list/', simple_fbv), @@ -1228,6 +1240,10 @@ class TestURLNamingCollisions(TestCase): def test_url_under_same_key_not_replaced_another(self): + @api_view(["GET"]) + def simple_fbv(request): + pass + patterns = [ url(r'^test/list/', simple_fbv), url(r'^test/(?P\d+)/list/', simple_fbv), @@ -1302,10 +1318,8 @@ def test_head_and_options_methods_are_excluded(): assert inspector.get_allowed_methods(callback) == ["GET"] -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestAutoSchemaAllowsFilters: - class MockAPIView(APIView): - filter_backends = [filters.OrderingFilter] +class MockAPIView(APIView): + filter_backends = [filters.OrderingFilter] def _test(self, method): view = self.MockAPIView() diff --git a/tests/schemas/test_get_schema_view.py b/tests/schemas/test_get_schema_view.py new file mode 100644 index 000000000..f582c6495 --- /dev/null +++ b/tests/schemas/test_get_schema_view.py @@ -0,0 +1,20 @@ +import pytest +from django.test import TestCase, override_settings + +from rest_framework import renderers +from rest_framework.schemas import coreapi, get_schema_view, openapi + + +class GetSchemaViewTests(TestCase): + """For the get_schema_view() helper.""" + def test_openapi(self): + schema_view = get_schema_view(title="With OpenAPI") + assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator) + assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes + + @pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed') + def test_coreapi(self): + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): + schema_view = get_schema_view(title="With CoreAPI") + assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator) + assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes diff --git a/tests/test_generateschema.py b/tests/schemas/test_managementcommand.py similarity index 57% rename from tests/test_generateschema.py rename to tests/schemas/test_managementcommand.py index a6a1f2bed..e5960f2b0 100644 --- a/tests/test_generateschema.py +++ b/tests/schemas/test_managementcommand.py @@ -6,7 +6,8 @@ from django.core.management import call_command from django.test import TestCase from django.test.utils import override_settings -from rest_framework.compat import coreapi +from rest_framework.compat import uritemplate, yaml +from rest_framework.management.commands import generateschema from rest_framework.utils import formatting, json from rest_framework.views import APIView @@ -21,15 +22,43 @@ urlpatterns = [ ] -@override_settings(ROOT_URLCONF='tests.test_generateschema') -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') +@override_settings(ROOT_URLCONF=__name__) +@pytest.mark.skipif(not uritemplate, reason='uritemplate is not installed') class GenerateSchemaTests(TestCase): """Tests for management command generateschema.""" def setUp(self): self.out = io.StringIO() + def test_command_detects_schema_generation_mode(self): + """Switching between CoreAPI & OpenAPI""" + command = generateschema.Command() + assert command.get_mode() == generateschema.OPENAPI_MODE + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): + assert command.get_mode() == generateschema.COREAPI_MODE + + @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') def test_renders_default_schema_with_custom_title_url_and_description(self): + call_command('generateschema', + '--title=SampleAPI', + '--url=http://api.sample.com', + '--description=Sample description', + stdout=self.out) + # Check valid YAML was output. + schema = yaml.load(self.out.getvalue()) + assert schema['openapi'] == '3.0.2' + + def test_renders_openapi_json_schema(self): + call_command('generateschema', + '--format=openapi-json', + stdout=self.out) + # Check valid JSON was output. + out_json = json.loads(self.out.getvalue()) + assert out_json['openapi'] == '3.0.2' + + @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') + @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) + def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self): expected_out = """info: description: Sample description title: SampleAPI @@ -50,7 +79,8 @@ class GenerateSchemaTests(TestCase): self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) - def test_renders_openapi_json_schema(self): + @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) + def test_coreapi_renders_openapi_json_schema(self): expected_out = { "openapi": "3.0.0", "info": { @@ -78,6 +108,7 @@ class GenerateSchemaTests(TestCase): self.assertDictEqual(out_json, expected_out) + @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) def test_renders_corejson_schema(self): expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" call_command('generateschema', diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py new file mode 100644 index 000000000..2ddf54f01 --- /dev/null +++ b/tests/schemas/test_openapi.py @@ -0,0 +1,245 @@ +import pytest +from django.conf.urls import url +from django.test import RequestFactory, TestCase, override_settings + +from rest_framework import filters, generics, pagination, routers, serializers +from rest_framework.compat import uritemplate +from rest_framework.request import Request +from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator + +from . import views + + +def create_request(path): + factory = RequestFactory() + request = Request(factory.get(path)) + return request + + +def create_view(view_cls, method, request): + generator = SchemaGenerator() + view = generator.create_view(view_cls.as_view(), method, request) + return view + + +class TestBasics(TestCase): + def dummy_view(request): + pass + + def test_filters(self): + classes = [filters.SearchFilter, filters.OrderingFilter] + for c in classes: + f = c() + assert f.get_schema_operation_parameters(self.dummy_view) + + def test_pagination(self): + classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination] + for c in classes: + f = c() + assert f.get_schema_operation_parameters(self.dummy_view) + + +@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.') +class TestOperationIntrospection(TestCase): + + def test_path_without_parameters(self): + path = '/example/' + method = 'GET' + + view = create_view( + views.ExampleListView, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + operation = inspector.get_operation(path, method) + assert operation == { + 'operationId': 'ListExamples', + 'parameters': [], + 'responses': {'200': {'content': {'application/json': {'schema': {}}}}}, + } + + def test_path_with_id_parameter(self): + path = '/example/{id}/' + method = 'GET' + + view = create_view( + views.ExampleDetailView, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + parameters = inspector._get_path_parameters(path, method) + assert parameters == [{ + 'description': '', + 'in': 'path', + 'name': 'id', + 'required': True, + 'schema': { + 'type': 'string', + }, + }] + + def test_request_body(self): + path = '/' + method = 'POST' + + class Serializer(serializers.Serializer): + text = serializers.CharField() + read_only = serializers.CharField(read_only=True) + + class View(generics.GenericAPIView): + serializer_class = Serializer + + view = create_view( + View, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + request_body = inspector._get_request_body(path, method) + assert request_body['content']['application/json']['schema']['required'] == ['text'] + assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text'] + + def test_response_body_generation(self): + path = '/' + method = 'POST' + + class Serializer(serializers.Serializer): + text = serializers.CharField() + write_only = serializers.CharField(write_only=True) + + class View(generics.GenericAPIView): + serializer_class = Serializer + + view = create_view( + View, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + responses = inspector._get_responses(path, method) + assert responses['200']['content']['application/json']['schema']['required'] == ['text'] + assert list(responses['200']['content']['application/json']['schema']['properties'].keys()) == ['text'] + + def test_response_body_nested_serializer(self): + path = '/' + method = 'POST' + + class NestedSerializer(serializers.Serializer): + number = serializers.IntegerField() + + class Serializer(serializers.Serializer): + text = serializers.CharField() + nested = NestedSerializer() + + class View(generics.GenericAPIView): + serializer_class = Serializer + + view = create_view( + View, + method, + create_request(path), + ) + inspector = AutoSchema() + inspector.view = view + + responses = inspector._get_responses(path, method) + schema = responses['200']['content']['application/json']['schema'] + assert sorted(schema['required']) == ['nested', 'text'] + assert sorted(list(schema['properties'].keys())) == ['nested', 'text'] + assert schema['properties']['nested']['type'] == 'object' + assert list(schema['properties']['nested']['properties'].keys()) == ['number'] + assert schema['properties']['nested']['required'] == ['number'] + + def test_operation_id_generation(self): + path = '/' + method = 'GET' + + view = create_view( + views.ExampleGenericAPIView, + method, + create_request(path), + ) + inspector = AutoSchema() + inspector.view = view + + operationId = inspector._get_operation_id(path, method) + assert operationId == 'ListExamples' + + def test_repeat_operation_ids(self): + router = routers.SimpleRouter() + router.register('account', views.ExampleGenericViewSet, basename="account") + urlpatterns = router.urls + + generator = SchemaGenerator(patterns=urlpatterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + schema_str = str(schema) + print(schema_str) + assert schema_str.count("operationId") == 2 + assert schema_str.count("newExample") == 1 + assert schema_str.count("oldExample") == 1 + + +@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'}) +class TestGenerator(TestCase): + + def test_override_settings(self): + assert isinstance(views.ExampleListView.schema, AutoSchema) + + def test_paths_construction(self): + """Construction of the `paths` key.""" + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + generator._initialise_endpoints() + + paths = generator.get_paths() + + assert '/example/' in paths + example_operations = paths['/example/'] + assert len(example_operations) == 2 + assert 'get' in example_operations + assert 'post' in example_operations + + def test_schema_construction(self): + """Construction of the top level dictionary.""" + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + assert 'openapi' in schema + assert 'paths' in schema + + def test_serializer_datefield(self): + patterns = [ + url(r'^example/?$', views.ExampleGenericViewSet.as_view({"get": "get"})), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + response = schema['paths']['/example/']['get']['responses'] + response_schema = response['200']['content']['application/json']['schema']['properties'] + + assert response_schema['date']['type'] == response_schema['datetime']['type'] == 'string' + + assert response_schema['date']['format'] == 'date' + assert response_schema['datetime']['format'] == 'date-time' diff --git a/tests/schemas/views.py b/tests/schemas/views.py new file mode 100644 index 000000000..dc0d6065b --- /dev/null +++ b/tests/schemas/views.py @@ -0,0 +1,58 @@ +from rest_framework import generics, permissions, serializers +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework.viewsets import GenericViewSet + + +class ExampleListView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + def post(self, request, *args, **kwargs): + pass + + +class ExampleDetailView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + +# Generics. +class ExampleSerializer(serializers.Serializer): + date = serializers.DateField() + datetime = serializers.DateTimeField() + + +class ExampleGenericAPIView(generics.GenericAPIView): + serializer_class = ExampleSerializer + + def get(self, *args, **kwargs): + from datetime import datetime + now = datetime.now() + + serializer = self.get_serializer(data=now.date(), datetime=now) + return Response(serializer.data) + + +class ExampleGenericViewSet(GenericViewSet): + serializer_class = ExampleSerializer + + def get(self, *args, **kwargs): + from datetime import datetime + now = datetime.now() + + serializer = self.get_serializer(data=now.date(), datetime=now) + return Response(serializer.data) + + @action(detail=False) + def new(self, *args, **kwargs): + pass + + @action(detail=False) + def old(self, *args, **kwargs): + pass