From c0a31ed0a335556fa036733f84f549c8bd41e0a2 Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Tue, 19 Mar 2019 15:51:59 +0100 Subject: [PATCH] Added OpenAPI Schema Generation. --- rest_framework/filters.py | 29 ++ .../management/commands/generateschema.py | 33 +-- rest_framework/pagination.py | 94 +++++- rest_framework/schemas/generators.py | 257 +++++++++++------ rest_framework/schemas/inspectors.py | 268 +++++++++++++++++- rest_framework/settings.py | 2 +- tests/schemas/__init__.py | 0 .../test_coreapi.py} | 92 +++--- tests/schemas/test_openapi.py | 118 ++++++++ tests/schemas/views.py | 19 ++ tests/test_generateschema.py | 53 +--- 11 files changed, 738 insertions(+), 227 deletions(-) create mode 100644 tests/schemas/__init__.py rename tests/{test_schemas.py => schemas/test_coreapi.py} (94%) create mode 100644 tests/schemas/test_openapi.py create mode 100644 tests/schemas/views.py diff --git a/rest_framework/filters.py b/rest_framework/filters.py index bb1b86586..53b77ff39 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -40,6 +40,9 @@ class BaseFilterBackend(object): assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' return [] + def get_schema_operation_parameters(self, view): + return [] + class SearchFilter(BaseFilterBackend): # The URL query parameter used for the search. @@ -159,6 +162,19 @@ class SearchFilter(BaseFilterBackend): ) ] + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.search_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.search_description), + 'schema': { + 'type': 'string', + }, + }, + ] + class OrderingFilter(BaseFilterBackend): # The URL query parameter used for the ordering. @@ -290,6 +306,19 @@ class OrderingFilter(BaseFilterBackend): ) ] + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.ordering_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.ordering_description), + 'schema': { + 'type': 'string', + }, + }, + ] + class DjangoObjectPermissionsFilter(BaseFilterBackend): """ diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py index 591073ba0..926e8db39 100644 --- a/rest_framework/management/commands/generateschema.py +++ b/rest_framework/management/commands/generateschema.py @@ -1,25 +1,21 @@ from django.core.management.base import BaseCommand -from rest_framework.compat import coreapi -from rest_framework.renderers import ( - CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer -) -from rest_framework.schemas.generators import SchemaGenerator +from rest_framework.compat import yaml +from rest_framework.schemas.generators import OpenAPISchemaGenerator +from rest_framework.utils import json class Command(BaseCommand): help = "Generates configured API schema for project." def add_arguments(self, parser): - parser.add_argument('--title', dest="title", default=None, type=str) + parser.add_argument('--title', dest="title", default='', type=str) parser.add_argument('--url', dest="url", default=None, type=str) parser.add_argument('--description', dest="description", default=None, type=str) - parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) + parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str) def handle(self, *args, **options): - assert coreapi is not None, 'coreapi must be installed.' - - generator = SchemaGenerator( + generator = OpenAPISchemaGenerator( url=options['url'], title=options['title'], description=options['description'] @@ -27,15 +23,10 @@ class Command(BaseCommand): schema = generator.get_schema(request=None, public=True) - renderer = self.get_renderer(options['format']) - output = renderer.render(schema, renderer_context={}) - self.stdout.write(output.decode('utf-8')) + # TODO: Handle via renderer? More options? + if options['format'] == 'openapi': + output = yaml.dump(schema, default_flow_style=False) + else: + output = json.dumps(schema, indent=2) - def get_renderer(self, format): - renderer_cls = { - 'corejson': CoreJSONRenderer, - 'openapi': OpenAPIRenderer, - 'openapi-json': JSONOpenAPIRenderer, - }[format] - - return renderer_cls() + self.stdout.write(output) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index b11d7cdf3..e93095d10 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -152,6 +152,9 @@ class BasePagination(object): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' return [] + def get_schema_operation_parameters(self, view): + return [] + class PageNumberPagination(BasePagination): """ @@ -305,6 +308,32 @@ class PageNumberPagination(BasePagination): ) return fields + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.page_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ] + if self.page_size_query_param is not None: + parameters.append( + { + 'name': self.page_size_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_size_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ) + return parameters + class LimitOffsetPagination(BasePagination): """ @@ -434,6 +463,15 @@ class LimitOffsetPagination(BasePagination): context = self.get_html_context() return template.render(context) + def get_count(self, queryset): + """ + Determine an object count, supporting either querysets or regular lists. + """ + try: + return queryset.count() + except (AttributeError, TypeError): + return len(queryset) + def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' @@ -458,14 +496,28 @@ class LimitOffsetPagination(BasePagination): ) ] - def get_count(self, queryset): - """ - Determine an object count, supporting either querysets or regular lists. - """ - try: - return queryset.count() - except (AttributeError, TypeError): - return len(queryset) + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.limit_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.limit_query_description), + 'schema': { + 'type': 'integer', + }, + }, + { + 'name': self.offset_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.offset_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ] + return parameters class CursorPagination(BasePagination): @@ -820,3 +872,29 @@ class CursorPagination(BasePagination): ) ) return fields + + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.cursor_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.cursor_query_description), + 'schema': { + 'type': 'integer', + }, + } + ] + if self.page_size_query_param is not None: + parameters.append( + { + 'name': self.page_size_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_size_query_description), + 'schema': { + 'type': 'integer', + }, + } + ) + return parameters diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index db226a6c1..57cf91d16 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -193,6 +193,10 @@ class EndpointEnumerator(object): """ Given a URL conf regex, return a URI template string. """ + # ???: Would it be feasible to adjust this such that we generate the + # path, plus the kwargs, plus the type from the convertor, such that we + # could feed that straight into the parameter schema object? + path = simplify_regex(path_regex) # Strip Django 2.0 convertors as they are incompatible with uritemplate format @@ -232,35 +236,18 @@ class EndpointEnumerator(object): return [method for method in methods if method not in ('OPTIONS', 'HEAD')] -class SchemaGenerator(object): - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } +class BaseSchemaGenerator(object): endpoint_inspector_cls = EndpointEnumerator - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - # 'pk' isn't great as an externally exposed name for an identifier, # so by default we prefer to use the actual model field name for schemas. # Set by 'SCHEMA_COERCE_PATH_PK'. coerce_path_pk = None def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' - if url and not url.endswith('/'): url += '/' - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.patterns = patterns @@ -270,36 +257,15 @@ class SchemaGenerator(object): self.url = url self.endpoints = None - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ + def _initialise_endpoints(self): if self.endpoints is None: inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) self.endpoints = inspector.get_api_endpoints() - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - distribute_links(links) - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - - def get_links(self, request=None): + def _get_paths_and_endpoints(self, request): """ - Return a dictionary containing all the links that should be - included in the API schema. + Generate (path, method, view) given (path, method, callback) for paths. """ - links = LinkNode() - - # Generate (path, method, view) given (path, method, callback). paths = [] view_endpoints = [] for path, method, callback in self.endpoints: @@ -308,22 +274,48 @@ class SchemaGenerator(object): paths.append(path) view_endpoints.append((path, method, view)) - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) + return paths, view_endpoints - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) + def create_view(self, callback, method, request=None): + """ + Given a callback, return an actual view instance. + """ + view = callback.cls(**getattr(callback, 'initkwargs', {})) + view.args = () + view.kwargs = {} + view.format_kwarg = None + view.request = None + view.action_map = getattr(callback, 'actions', None) - return links + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) - # Methods used when we generate a view instance from the raw callback... + if request is not None: + view.request = clone_request(request, method) + + return view + + def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ + if not self.coerce_path_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + + def get_schema(self, request=None, public=False): + raise NotImplementedError(".get_schema() must be implemented in subclasses.") def determine_path_prefix(self, paths): """ @@ -356,29 +348,6 @@ class SchemaGenerator(object): prefixes.append('/' + prefix + '/') return common_path(prefixes) - def create_view(self, callback, method, request=None): - """ - Given a callback, return an actual view instance. - """ - view = callback.cls(**getattr(callback, 'initkwargs', {})) - view.args = () - view.kwargs = {} - view.format_kwarg = None - view.request = None - view.action_map = getattr(callback, 'actions', None) - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - - return view - def has_view_permissions(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. @@ -392,23 +361,77 @@ class SchemaGenerator(object): return False return True - def coerce_path(self, path, method, view): + +class SchemaGenerator(BaseSchemaGenerator): + """ + Original CoreAPI version. + """ + # Map HTTP methods onto actions. + default_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + assert coreschema, '`coreschema` must be installed for schema support.' + + super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf) + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + + def get_links(self, request=None): """ - Coerce {pk} path arguments into the name of the model field, - where possible. This is cleaner for an external representation. - (Ie. "this is an identifier", not "this is a database primary key") + Return a dictionary containing all the links that should be + included in the API schema. """ - if not self.coerce_path_pk or '{pk}' not in path: - return path - model = getattr(getattr(view, 'queryset', None), 'model', None) - if model: - field_name = get_pk_name(model) - else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) + links = LinkNode() + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + link = view.schema.get_link(path, method, base_url=self.url) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) + insert_into(links, keys, link) + + return links + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + self._initialise_endpoints() + + links = self.get_links(None if public else request) + if not links: + return None + + url = self.url + if not url and request is not None: + url = request.build_absolute_uri() + + distribute_links(links) + return coreapi.Document( + title=self.title, description=self.description, + url=url, content=links + ) # Method for generating the link layout.... - def get_keys(self, subpath, method, view): """ Return a list of keys that should be used to layout a link within @@ -452,3 +475,55 @@ class SchemaGenerator(object): # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] + + +class OpenAPISchemaGenerator(BaseSchemaGenerator): + + def get_info(self): + info = { + 'title': self.title, + 'version': 'TODO', + } + + if self.description is not None: + info['description'] = self.description + + return info + + def get_paths(self, request=None): + result = {} + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + operation = view.schema.get_operation(path, method) + subpath = '/' + path[len(prefix):] + result.setdefault(subpath, {}) + result[subpath][method.lower()] = operation + + return result + + def get_schema(self, request=None, public=False): + """ + Generate a OpenAPI schema. + """ + self._initialise_endpoints() + + paths = self.get_paths(None if public else request) + if not paths: + return None + + schema = { + 'openapi': '3.0.2', + 'info': self.get_info(), + 'paths': paths, + } + + return schema diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 85142edce..1bf418d72 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -179,20 +179,6 @@ class ViewInspector(object): def view(self): self._view = None - def get_link(self, path, method, base_url): - """ - Generate `coreapi.Link` for self.view, path and method. - - This is the main _public_ access point. - - Parameters: - - * path: Route path for view from URLConf. - * method: The HTTP request method. - * base_url: The project "mount point" as given to SchemaGenerator - """ - raise NotImplementedError(".get_link() must be overridden.") - class AutoSchema(ViewInspector): """ @@ -213,6 +199,17 @@ class AutoSchema(ViewInspector): self._manual_fields = manual_fields def get_link(self, path, method, base_url): + """ + Generate `coreapi.Link` for self.view, path and method. + + This is the main _public_ access point. + + Parameters: + + * path: Route path for view from URLConf. + * method: The HTTP request method. + * base_url: The project "mount point" as given to SchemaGenerator + """ fields = self.get_path_fields(path, method) fields += self.get_serializer_fields(path, method) fields += self.get_pagination_fields(path, method) @@ -509,3 +506,246 @@ class DefaultSchema(ViewInspector): inspector = inspector_class() inspector.view = instance return inspector + + +class OpenAPIAutoSchema(ViewInspector): + + content_types = ['application/json'] + + def get_operation(self, path, method): + operation = {} + + parameters = [] + parameters += self._get_path_parameters(path, method) + parameters += self._get_pagination_parameters(path, method) + parameters += self._get_filter_parameters(path, method) + operation['parameters'] = parameters + + request_body = self._get_request_body(path, method) + if request_body: + operation['requestBody'] = request_body + operation['responses'] = self._get_responses(path, method) + + return operation + + def _get_path_parameters(self, path, method): + """ + Return a list of parameters from templated path variables. + """ + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + parameters = [] + + for variable in uritemplate.variables(path): + description = '' + if model is not None: # TODO: test this. + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + parameter = { + "name": variable, + "in": "path", + "required": True, + "description": description, + 'schema': { + 'type': 'string', # TODO: integer, pattern, ... + }, + } + parameters.append(parameter) + + return parameters + + def _get_filter_parameters(self, path, method): + if not self._allows_filters(path, method): + return [] + parameters = [] + for filter_backend in self.view.filter_backends: + parameters += filter_backend().get_schema_operation_parameters(self.view) + return parameters + + def _allows_filters(self, path, method): + """ + Determine whether to include filter Fields in schema. + + Default implementation looks for ModelViewSet or GenericAPIView + actions/methods that cause filtering on the default implementation. + """ + if getattr(self.view, 'filter_backends', None) is None: + return False + if hasattr(self.view, 'action'): + return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] + return method.lower() in ["get", "put", "patch", "delete"] + + def _get_pagination_parameters(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + pagination = getattr(view, 'pagination_class', None) + if not pagination: + return [] + + paginator = view.pagination_class() + return paginator.get_schema_operation_parameters(view) + + def _map_field(self, field): + + # Nested Serializers, `many` or not. + if isinstance(field, serializers.ListSerializer): + return { + 'type': 'array', + 'items': self._map_serializer(field.child) + } + if isinstance(field, serializers.Serializer): + return { + 'type': 'object', + 'properties': self._map_serializer(field) + } + + # Related fields. + if isinstance(field, serializers.ManyRelatedField): + return { + 'type': 'array', + 'items': self._map_field(field.child_relation) + } + if isinstance(field, serializers.PrimaryKeyRelatedField): + model = getattr(field.queryset, 'model', None) + if model is not None: + model_field = model._meta.pk + if isinstance(model_field, models.AutoField): + return {'type': 'integer'} + + # ChoiceFields (single and multiple). + # Q: + # - Is 'type' required? + # - can we determine the TYPE of a choicefield? + if isinstance(field, serializers.MultipleChoiceField): + return { + 'type': 'array', + 'items': { + 'enum': list(field.choices) + }, + } + + if isinstance(field, serializers.ChoiceField): + return { + 'enum': list(field.choices), + } + + # ListField. + if isinstance(field, serializers.ListField): + return { + 'type': 'array', + } + + # Simplest cases, default to 'string' type: + FIELD_CLASS_SCHEMA_TYPE = { + serializers.BooleanField: 'boolean', + serializers.DecimalField: 'number', + serializers.FloatField: 'number', + serializers.IntegerField: 'integer', + serializers.DateField: 'date', + serializers.DateTimeField: 'date-time', + serializers.JSONField: 'object', + serializers.DictField: 'object', + } + return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} + + def _map_serializer(self, serializer): + # Assuming we have a valid serializer instance. + # TODO: + # - field is Nested or List serializer. + # - Handle read_only/write_only for request/response differences. + # - could do this with readOnly/writeOnly and then filter dict. + required = [] + properties = {} + + for field in serializer.fields.values(): + if isinstance(field, serializers.HiddenField): + continue + + if field.required: + required.append(field.field_name) + + schema = self._map_field(field) + if field.read_only: + schema['readOnly'] = True + if field.write_only: + schema['writeOnly'] = True + if field.allow_null: + schema['nullable'] = True + + properties[field.field_name] = schema + return { + 'required': required, + 'properties': properties, + } + + def _get_request_body(self, path, method): + view = self.view + + if method not in ('PUT', 'PATCH', 'POST'): + return {} + + if not hasattr(view, 'get_serializer'): + return {} + + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if not isinstance(serializer, serializers.Serializer): + return {} + + content = self._map_serializer(serializer) + # No required fields for PATCH + if method == 'PATCH': + del content['required'] + # No read_only fields for request. + for name, schema in content['properties'].items(): + if 'readOnly' in schema: + del content['properties']['name'] + + return { + 'content': {ct: content for ct in self.content_types} + } + + def _get_responses(self, path, method): + # TODO: Handle multiple codes. + content = {} + view = self.view + if hasattr(view, 'get_serializer'): + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if isinstance(serializer, serializers.Serializer): + content = self._map_serializer(serializer) + # No write_only fields for response. + for name, schema in content['properties'].items(): + if 'writeOnly' in schema: + del content['properties']['name'] + + return { + '200': { + 'content': {ct: content for ct in self.content_types} + } + } diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8db9c81ed..a22050ea9 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -56,7 +56,7 @@ DEFAULTS = { 'DEFAULT_FILTER_BACKENDS': (), # Schema - 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', + 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.inspectors.OpenAPIAutoSchema', # Throttling 'DEFAULT_THROTTLE_RATES': { diff --git a/tests/schemas/__init__.py b/tests/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_schemas.py b/tests/schemas/test_coreapi.py similarity index 94% rename from tests/test_schemas.py rename to tests/schemas/test_coreapi.py index 3cb9e0cda..ee8fdd007 100644 --- a/tests/test_schemas.py +++ b/tests/schemas/test_coreapi.py @@ -24,7 +24,8 @@ from rest_framework.utils import formatting from rest_framework.views import APIView from rest_framework.viewsets import GenericViewSet, ModelViewSet -from .models import BasicModel, ForeignKeySource, ManyToManySource +from . import views +from ..models import BasicModel, ForeignKeySource, ManyToManySource factory = APIRequestFactory() @@ -148,7 +149,7 @@ urlpatterns = [ @unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(ROOT_URLCONF='tests.test_schemas') +@override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestRouterGeneratedSchema(TestCase): def test_anonymous_request(self): client = APIClient() @@ -382,30 +383,14 @@ class MethodLimitedViewSet(ExampleViewSet): http_method_names = ['get', 'head', 'options'] -class ExampleListView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - -class ExampleDetailView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGenerator(TestCase): def setUp(self): self.patterns = [ - url(r'^example/?$', ExampleListView.as_view()), - url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^example/?$', views.ExampleListView.as_view()), + url(r'^example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -453,12 +438,13 @@ class TestSchemaGenerator(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(path, 'needs Django 2') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorDjango2(TestCase): def setUp(self): self.patterns = [ - path('example/', ExampleListView.as_view()), - path('example//', ExampleDetailView.as_view()), - path('example//sub/', ExampleDetailView.as_view()), + path('example/', views.ExampleListView.as_view()), + path('example//', views.ExampleDetailView.as_view()), + path('example//sub/', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -505,12 +491,13 @@ class TestSchemaGeneratorDjango2(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorNotAtRoot(TestCase): def setUp(self): self.patterns = [ - url(r'^api/v1/example/?$', ExampleListView.as_view()), - url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^api/v1/example/?$', views.ExampleListView.as_view()), + url(r'^api/v1/example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^api/v1/example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -558,6 +545,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): def setUp(self): router = DefaultRouter() @@ -622,13 +610,14 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithRestrictedViewSets(TestCase): def setUp(self): router = DefaultRouter() router.register('example1', Http404ExampleViewSet, basename='example1') router.register('example2', PermissionDeniedExampleViewSet, basename='example2') self.patterns = [ - url('^example/?$', ExampleListView.as_view()), + url('^example/?$', views.ExampleListView.as_view()), url(r'^', include(router.urls)) ] @@ -668,6 +657,7 @@ class ForeignKeySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithForeignKey(TestCase): def setUp(self): self.patterns = [ @@ -713,6 +703,7 @@ class ManyToManySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithManyToMany(TestCase): def setUp(self): self.patterns = [ @@ -747,6 +738,7 @@ class TestSchemaGeneratorWithManyToMany(TestCase): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class Test4605Regression(TestCase): def test_4605_regression(self): generator = SchemaGenerator() @@ -762,6 +754,7 @@ class CustomViewInspector(AutoSchema): pass +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestAutoSchema(TestCase): def test_apiview_schema_descriptor(self): @@ -777,7 +770,7 @@ class TestAutoSchema(TestCase): assert isinstance(view.schema, CustomViewInspector) def test_set_custom_inspector_class_via_settings(self): - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}): view = APIView() assert isinstance(view.schema, CustomViewInspector) @@ -971,6 +964,7 @@ class TestAutoSchema(TestCase): self.assertEqual(field_to_schema(case[0]), case[1]) +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) def test_docstring_is_not_stripped_by_get_description(): class ExampleDocstringAPIView(APIView): """ @@ -1014,20 +1008,19 @@ class ExcludedAPIView(APIView): pass -@api_view(['GET']) -@schema(None) -def excluded_fbv(request): - pass - - -@api_view(['GET']) -def included_fbv(request): - pass - - @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class SchemaGenerationExclusionTests(TestCase): def setUp(self): + @api_view(['GET']) + @schema(None) + def excluded_fbv(request): + pass + + @api_view(['GET']) + def included_fbv(request): + pass + self.patterns = [ url('^excluded-cbv/$', ExcludedAPIView.as_view()), url('^excluded-fbv/$', excluded_fbv), @@ -1078,11 +1071,6 @@ class SchemaGenerationExclusionTests(TestCase): assert should_include == expected -@api_view(["GET"]) -def simple_fbv(request): - pass - - class BasicModelSerializer(serializers.ModelSerializer): class Meta: model = BasicModel @@ -1118,11 +1106,16 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestURLNamingCollisions(TestCase): """ Ref: https://github.com/encode/django-rest-framework/issues/4704 """ def test_manually_routing_nested_routes(self): + @api_view(["GET"]) + def simple_fbv(request): + pass + patterns = [ url(r'^test', simple_fbv), url(r'^test/list/', simple_fbv), @@ -1228,6 +1221,10 @@ class TestURLNamingCollisions(TestCase): def test_url_under_same_key_not_replaced_another(self): + @api_view(["GET"]) + def simple_fbv(request): + pass + patterns = [ url(r'^test/list/', simple_fbv), url(r'^test/(?P\d+)/list/', simple_fbv), @@ -1303,7 +1300,8 @@ def test_head_and_options_methods_are_excluded(): @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestAutoSchemaAllowsFilters(object): +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) +class TestAutoSchemaAllowsFilters(TestCase): class MockAPIView(APIView): filter_backends = [filters.OrderingFilter] diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py new file mode 100644 index 000000000..23937a7d1 --- /dev/null +++ b/tests/schemas/test_openapi.py @@ -0,0 +1,118 @@ +from django.conf.urls import url +from django.test import RequestFactory, TestCase, override_settings + +from rest_framework import filters, pagination +from rest_framework.request import Request +from rest_framework.schemas.generators import OpenAPISchemaGenerator +from rest_framework.schemas.inspectors import OpenAPIAutoSchema + +from . import views + + +def create_request(path): + factory = RequestFactory() + request = Request(factory.get(path)) + return request + + +def create_view(view_cls, method, request): + generator = OpenAPISchemaGenerator() + view = generator.create_view(view_cls.as_view(), method, request) + return view + + +class TestBasics(TestCase): + def dummy_view(request): + pass + + def test_filters(self): + classes = [filters.SearchFilter, filters.OrderingFilter] + for c in classes: + f = c() + assert f.get_schema_operation_parameters(self.dummy_view) + + def test_pagination(self): + classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination] + for c in classes: + f = c() + assert f.get_schema_operation_parameters(self.dummy_view) + + +class TestOperationIntrospection(TestCase): + + def test_path_without_parameters(self): + path = '/example/' + method = 'GET' + + view = create_view( + views.ExampleListView, + method, + create_request(path) + ) + inspector = OpenAPIAutoSchema() + inspector.view = view + + operation = inspector.get_operation(path, method) + assert operation == { + 'parameters': [], + 'responses': {'200': {'content': {'application/json': {}}}}, + } + + def test_path_with_id_parameter(self): + path = '/example/{id}/' + method = 'GET' + + view = create_view( + views.ExampleDetailView, + method, + create_request(path) + ) + inspector = OpenAPIAutoSchema() + inspector.view = view + + parameters = inspector._get_path_parameters(path, method) + assert parameters == [{ + 'description': '', + 'in': 'path', + 'name': 'id', + 'required': True, + 'schema': { + 'type': 'string', + }, + }] + + +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.inspectors.OpenAPIAutoSchema'}) +class TestGenerator(TestCase): + + def test_override_settings(self): + assert isinstance(views.ExampleListView.schema, OpenAPIAutoSchema) + + def test_paths_construction(self): + """Construction of the `paths` key.""" + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = OpenAPISchemaGenerator(patterns=patterns) + generator._initialise_endpoints() + + paths = generator.get_paths() + + assert '/example/' in paths + example_operations = paths['/example/'] + assert len(example_operations) == 2 + assert 'get' in example_operations + assert 'post' in example_operations + + def test_schema_construction(self): + """Construction of the top level dictionary.""" + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = OpenAPISchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + assert 'openapi' in schema + assert 'paths' in schema diff --git a/tests/schemas/views.py b/tests/schemas/views.py new file mode 100644 index 000000000..c368ba7e5 --- /dev/null +++ b/tests/schemas/views.py @@ -0,0 +1,19 @@ +from rest_framework import permissions +from rest_framework.views import APIView + + +class ExampleListView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + def post(self, request, *args, **kwargs): + pass + + +class ExampleDetailView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass diff --git a/tests/test_generateschema.py b/tests/test_generateschema.py index 915c6ea05..978869620 100644 --- a/tests/test_generateschema.py +++ b/tests/test_generateschema.py @@ -7,8 +7,8 @@ from django.test import TestCase from django.test.utils import override_settings from django.utils import six -from rest_framework.compat import coreapi -from rest_framework.utils import formatting, json +from rest_framework.compat import coreapi, yaml +from rest_framework.utils import json from rest_framework.views import APIView @@ -31,58 +31,21 @@ class GenerateSchemaTests(TestCase): self.out = six.StringIO() @pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.') + @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') def test_renders_default_schema_with_custom_title_url_and_description(self): - expected_out = """info: - description: Sample description - title: SampleAPI - version: '' - openapi: 3.0.0 - paths: - /: - get: - operationId: list - servers: - - url: http://api.sample.com/ - """ call_command('generateschema', '--title=SampleAPI', '--url=http://api.sample.com', '--description=Sample description', stdout=self.out) - - self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) + # Check valid YAML was output. + schema = yaml.load(self.out.getvalue()) + assert schema['openapi'] == '3.0.2' def test_renders_openapi_json_schema(self): - expected_out = { - "openapi": "3.0.0", - "info": { - "version": "", - "title": "", - "description": "" - }, - "servers": [ - { - "url": "" - } - ], - "paths": { - "/": { - "get": { - "operationId": "list" - } - } - } - } call_command('generateschema', '--format=openapi-json', stdout=self.out) + # Check valid JSON was output. out_json = json.loads(self.out.getvalue()) - - self.assertDictEqual(out_json, expected_out) - - def test_renders_corejson_schema(self): - expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" - call_command('generateschema', - '--format=corejson', - stdout=self.out) - self.assertIn(expected_out, self.out.getvalue()) + assert out_json['openapi'] == '3.0.2'