diff --git a/rest_framework/management/__init__.py b/rest_framework/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rest_framework/management/commands/__init__.py b/rest_framework/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rest_framework/management/commands/generate_schema.py b/rest_framework/management/commands/generate_schema.py new file mode 100644 index 000000000..9ac17ec9c --- /dev/null +++ b/rest_framework/management/commands/generate_schema.py @@ -0,0 +1,46 @@ +from django.core.management.base import BaseCommand + +from rest_framework.compat import coreapi +from rest_framework.renderers import CoreJSONRenderer, OpenAPIRenderer +from rest_framework.settings import api_settings + + +class Command(BaseCommand): + help = "Generates configured API schema for project." + + def add_arguments(self, parser): + # TODO + # SchemaGenerator allows passing: + # + # - title + # - url + # - description + # - urlconf + # - patterns + # + # Don't particularly want to pass these on the command-line. + # conf file? + # + # Other options to consider: + # - indent + # - ... + pass + + def handle(self, *args, **options): + assert coreapi is not None, 'coreapi must be installed.' + + generator_class = api_settings.DEFAULT_SCHEMA_GENERATOR_CLASS() + generator = generator_class() + + schema = generator.get_schema(request=None, public=True) + + renderer = self.get_renderer('openapi') + output = renderer.render(schema) + + self.stdout.write(output) + + def get_renderer(self, format): + return { + 'corejson': CoreJSONRenderer(), + 'openapi': OpenAPIRenderer() + } diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a9645cc89..1b850db92 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -9,6 +9,7 @@ REST framework also provides an HTML renderer that renders the browsable API. from __future__ import unicode_literals import base64 +import urllib.parse as urlparse from collections import OrderedDict from django import forms @@ -24,7 +25,7 @@ from django.utils.html import mark_safe from rest_framework import VERSION, exceptions, serializers, status from rest_framework.compat import ( - INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, + INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema, pygments_css ) from rest_framework.exceptions import ParseError @@ -932,3 +933,95 @@ class CoreJSONRenderer(BaseRenderer): indent = bool(renderer_context.get('indent', 0)) codec = coreapi.codecs.CoreJSONCodec() return codec.dump(data, indent=indent) + + +class OpenAPIRenderer: + CLASS_TO_TYPENAME = { + coreschema.Object: 'object', + coreschema.Array: 'array', + coreschema.Number: 'number', + coreschema.Integer: 'integer', + coreschema.String: 'string', + coreschema.Boolean: 'boolean', + } + + def __init__(self): + assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.' + + def get_schema(self, instance): + schema = {} + if instance.__class__ in self.CLASS_TO_TYPENAME: + schema['type'] = self.CLASS_TO_TYPENAME[instance.__class__] + schema['title'] = instance.title, + schema['description'] = instance.description + if hasattr(instance, 'enum'): + schema['enum'] = instance.enum + return schema + + def get_parameters(self, link): + parameters = [] + for field in link.fields: + if field.location not in ['path', 'query']: + continue + parameter = { + 'name': field.name, + 'in': field.location, + } + if field.required: + parameter['required'] = True + if field.description: + parameter['description'] = field.description + if field.schema: + parameter['schema'] = self.get_schema(field.schema) + parameters.append(parameter) + return parameters + + def get_operation(self, link, name, tag): + operation_id = "%s_%s" % (tag, name) if tag else name + parameters = self.get_parameters(link) + + operation = { + 'operationId': operation_id, + } + if link.title: + operation['summary'] = link.title + if link.description: + operation['description'] = link.description + if parameters: + operation['parameters'] = parameters + if tag: + operation['tags'] = [tag] + return operation + + def get_paths(self, document): + paths = {} + + tag = None + for name, link in document.links.items(): + path = urlparse.urlparse(link.url).path + method = link.action.lower() + paths.setdefault(path, {}) + paths[path][method] = self.get_operation(link, name, tag=tag) + + for tag, section in document.data.items(): + for name, link in section.links.items(): + path = urlparse.urlparse(link.url).path + method = link.action.lower() + paths.setdefault(path, {}) + paths[path][method] = self.get_operation(link, name, tag=tag) + + return paths + + def render(self, data, media_type=None, renderer_context=None): + return json.dumps({ + 'openapi': '3.0.0', + 'info': { + 'version': '', + 'title': data.title, + 'description': data.description + }, + 'servers': [{ + 'url': data.url + }], + 'paths': self.get_paths(data) + }, indent=4) diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 8794c9967..33eb874ac 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -241,35 +241,18 @@ class EndpointEnumerator(object): return [method for method in methods if method not in ('OPTIONS', 'HEAD')] -class SchemaGenerator(object): - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } +class BaseSchemaGenerator(object): endpoint_inspector_cls = EndpointEnumerator - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - # 'pk' isn't great as an externally exposed name for an identifier, # so by default we prefer to use the actual model field name for schemas. # Set by 'SCHEMA_COERCE_PATH_PK'. coerce_path_pk = None def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' - if url and not url.endswith('/'): url += '/' - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.patterns = patterns @@ -279,36 +262,15 @@ class SchemaGenerator(object): self.url = url self.endpoints = None - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ + def _initialise_endpoints(self): if self.endpoints is None: inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) self.endpoints = inspector.get_api_endpoints() - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - distribute_links(links) - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - - def get_links(self, request=None): + def _get_paths_and_endpoints(self, request): """ - Return a dictionary containing all the links that should be - included in the API schema. + Generate (path, method, view) given (path, method, callback) for paths. """ - links = LinkNode() - - # Generate (path, method, view) given (path, method, callback). paths = [] view_endpoints = [] for path, method, callback in self.endpoints: @@ -317,22 +279,48 @@ class SchemaGenerator(object): paths.append(path) view_endpoints.append((path, method, view)) - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) + return paths, view_endpoints - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) + def create_view(self, callback, method, request=None): + """ + Given a callback, return an actual view instance. + """ + view = callback.cls(**getattr(callback, 'initkwargs', {})) + view.args = () + view.kwargs = {} + view.format_kwarg = None + view.request = None + view.action_map = getattr(callback, 'actions', None) - return links + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) - # Methods used when we generate a view instance from the raw callback... + if request is not None: + view.request = clone_request(request, method) + + return view + + def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ + if not self.coerce_path_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + + def get_schema(self, request=None, public=False): + raise NotImplementedError(".get_schema() must be implemented in subclasses.") def determine_path_prefix(self, paths): """ @@ -365,29 +353,6 @@ class SchemaGenerator(object): prefixes.append('/' + prefix + '/') return common_path(prefixes) - def create_view(self, callback, method, request=None): - """ - Given a callback, return an actual view instance. - """ - view = callback.cls(**getattr(callback, 'initkwargs', {})) - view.args = () - view.kwargs = {} - view.format_kwarg = None - view.request = None - view.action_map = getattr(callback, 'actions', None) - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - - return view - def has_view_permissions(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. @@ -401,23 +366,77 @@ class SchemaGenerator(object): return False return True - def coerce_path(self, path, method, view): + +class SchemaGenerator(BaseSchemaGenerator): + """ + Original CoreAPI version. + """ + # Map HTTP methods onto actions. + default_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + assert coreschema, '`coreschema` must be installed for schema support.' + + super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf) + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + + def get_links(self, request=None): """ - Coerce {pk} path arguments into the name of the model field, - where possible. This is cleaner for an external representation. - (Ie. "this is an identifier", not "this is a database primary key") + Return a dictionary containing all the links that should be + included in the API schema. """ - if not self.coerce_path_pk or '{pk}' not in path: - return path - model = getattr(getattr(view, 'queryset', None), 'model', None) - if model: - field_name = get_pk_name(model) - else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) + links = LinkNode() + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + link = view.schema.get_link(path, method, base_url=self.url) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) + insert_into(links, keys, link) + + return links + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + self._initialise_endpoints() + + links = self.get_links(None if public else request) + if not links: + return None + + url = self.url + if not url and request is not None: + url = request.build_absolute_uri() + + distribute_links(links) + return coreapi.Document( + title=self.title, description=self.description, + url=url, content=links + ) # Method for generating the link layout.... - def get_keys(self, subpath, method, view): """ Return a list of keys that should be used to layout a link within diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index b90f60e08..4c2fc10ee 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -174,20 +174,6 @@ class ViewInspector(object): def view(self): self._view = None - def get_link(self, path, method, base_url): - """ - Generate `coreapi.Link` for self.view, path and method. - - This is the main _public_ access point. - - Parameters: - - * path: Route path for view from URLConf. - * method: The HTTP request method. - * base_url: The project "mount point" as given to SchemaGenerator - """ - raise NotImplementedError(".get_link() must be overridden.") - class AutoSchema(ViewInspector): """ @@ -208,6 +194,17 @@ class AutoSchema(ViewInspector): self._manual_fields = manual_fields def get_link(self, path, method, base_url): + """ + Generate `coreapi.Link` for self.view, path and method. + + This is the main _public_ access point. + + Parameters: + + * path: Route path for view from URLConf. + * method: The HTTP request method. + * base_url: The project "mount point" as given to SchemaGenerator + """ fields = self.get_path_fields(path, method) fields += self.get_serializer_fields(path, method) fields += self.get_pagination_fields(path, method) @@ -501,3 +498,44 @@ class DefaultSchema(ViewInspector): inspector = inspector_class() inspector.view = instance return inspector + + +class OpenAPIAutoSchema(ViewInspector): + + def get_operation(self, path, method): + return { + 'parameters': self.get_path_parameters(path, method), + } + + def get_path_parameters(self, path, method): + """ + Return a list of parameters from templated path variables. + """ + assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.' + + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + parameters = [] + + for variable in uritemplate.variables(path): + description = '' + if model is not None: + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + parameter = { + "name": variable, + "in": "path", + "required": True, + "description": description, + } + parameters.append(parameter) + + return parameters diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8db9c81ed..7b83612a4 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -57,6 +57,7 @@ DEFAULTS = { # Schema 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', + 'DEFAULT_SCHEMA_GENERATOR_CLASS': 'rest_framework.schemas.generators.SchemaGenerator', # Throttling 'DEFAULT_THROTTLE_RATES': { @@ -144,6 +145,7 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_CLASS', 'DEFAULT_FILTER_BACKENDS', 'DEFAULT_SCHEMA_CLASS', + 'DEFAULT_SCHEMA_GENERATOR_CLASS', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', diff --git a/tests/schemas/__init__.py b/tests/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_schemas.py b/tests/schemas/test_coreapi.py similarity index 76% rename from tests/test_schemas.py rename to tests/schemas/test_coreapi.py index 6d7091da2..241f36e0c 100644 --- a/tests/test_schemas.py +++ b/tests/schemas/test_coreapi.py @@ -2,15 +2,11 @@ import unittest import pytest from django.conf.urls import include, url -from django.core.exceptions import PermissionDenied -from django.http import Http404 from django.test import TestCase, override_settings -from rest_framework import ( - filters, generics, pagination, permissions, serializers -) -from rest_framework.compat import coreapi, coreschema, get_regex_pattern, path -from rest_framework.decorators import action, api_view, schema +from rest_framework import filters, generics, serializers +from rest_framework.compat import coreapi, coreschema, path +from rest_framework.decorators import action, api_view from rest_framework.request import Request from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework.schemas import ( @@ -24,7 +20,8 @@ from rest_framework.utils import formatting from rest_framework.views import APIView from rest_framework.viewsets import GenericViewSet, ModelViewSet -from .models import BasicModel, ForeignKeySource +from . import views +from ..models import BasicModel, ForeignKeySource factory = APIRequestFactory() @@ -34,87 +31,6 @@ class MockUser(object): return True -class ExamplePagination(pagination.PageNumberPagination): - page_size = 100 - page_size_query_param = 'page_size' - - -class EmptySerializer(serializers.Serializer): - pass - - -class ExampleSerializer(serializers.Serializer): - a = serializers.CharField(required=True, help_text='A field description') - b = serializers.CharField(required=False) - read_only = serializers.CharField(read_only=True) - hidden = serializers.HiddenField(default='hello') - - -class AnotherSerializerWithDictField(serializers.Serializer): - a = serializers.DictField() - - -class AnotherSerializerWithListFields(serializers.Serializer): - a = serializers.ListField(child=serializers.IntegerField()) - b = serializers.ListSerializer(child=serializers.CharField()) - - -class AnotherSerializer(serializers.Serializer): - c = serializers.CharField(required=True) - d = serializers.CharField(required=False) - - -class ExampleViewSet(ModelViewSet): - pagination_class = ExamplePagination - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - filter_backends = [filters.OrderingFilter] - serializer_class = ExampleSerializer - - @action(methods=['post'], detail=True, serializer_class=AnotherSerializer) - def custom_action(self, request, pk): - """ - A description of custom action. - """ - raise NotImplementedError - - @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField) - def custom_action_with_dict_field(self, request, pk): - """ - A custom action using a dict field in the serializer. - """ - raise NotImplementedError - - @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields) - def custom_action_with_list_fields(self, request, pk): - """ - A custom action using both list field and list serializer in the serializer. - """ - raise NotImplementedError - - @action(detail=False) - def custom_list_action(self, request): - raise NotImplementedError - - @action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer) - def custom_list_action_multiple_methods(self, request): - """Custom description.""" - raise NotImplementedError - - @custom_list_action_multiple_methods.mapping.delete - def custom_list_action_multiple_methods_delete(self, request): - """Deletion description.""" - raise NotImplementedError - - @action(detail=False, schema=None) - def excluded_action(self, request): - pass - - def get_serializer(self, *args, **kwargs): - assert self.request - assert self.action - return super(ExampleViewSet, self).get_serializer(*args, **kwargs) - - if coreapi: schema_view = get_schema_view(title='Example API') else: @@ -122,7 +38,7 @@ else: pass router = DefaultRouter() -router.register('example', ExampleViewSet, basename='example') +router.register('example', views.ExampleViewSet, basename='example') urlpatterns = [ url(r'^$', schema_view), url(r'^', include(router.urls)) @@ -130,7 +46,7 @@ urlpatterns = [ @unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(ROOT_URLCONF='tests.test_schemas') +@override_settings(ROOT_URLCONF='tests.schemas.test_coreapi') class TestRouterGeneratedSchema(TestCase): def test_anonymous_request(self): client = APIClient() @@ -299,61 +215,13 @@ class TestRouterGeneratedSchema(TestCase): assert response.data == expected -class DenyAllUsingHttp404(permissions.BasePermission): - - def has_permission(self, request, view): - raise Http404() - - def has_object_permission(self, request, view, obj): - raise Http404() - - -class DenyAllUsingPermissionDenied(permissions.BasePermission): - - def has_permission(self, request, view): - raise PermissionDenied() - - def has_object_permission(self, request, view, obj): - raise PermissionDenied() - - -class Http404ExampleViewSet(ExampleViewSet): - permission_classes = [DenyAllUsingHttp404] - - -class PermissionDeniedExampleViewSet(ExampleViewSet): - permission_classes = [DenyAllUsingPermissionDenied] - - -class MethodLimitedViewSet(ExampleViewSet): - permission_classes = [] - http_method_names = ['get', 'head', 'options'] - - -class ExampleListView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - -class ExampleDetailView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGenerator(TestCase): def setUp(self): self.patterns = [ - url(r'^example/?$', ExampleListView.as_view()), - url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^example/?$', views.ExampleListView.as_view()), + url(r'^example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -404,9 +272,9 @@ class TestSchemaGenerator(TestCase): class TestSchemaGeneratorDjango2(TestCase): def setUp(self): self.patterns = [ - path('example/', ExampleListView.as_view()), - path('example//', ExampleDetailView.as_view()), - path('example//sub/', ExampleDetailView.as_view()), + path('example/', views.ExampleListView.as_view()), + path('example//', views.ExampleDetailView.as_view()), + path('example//sub/', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -456,9 +324,9 @@ class TestSchemaGeneratorDjango2(TestCase): class TestSchemaGeneratorNotAtRoot(TestCase): def setUp(self): self.patterns = [ - url(r'^api/v1/example/?$', ExampleListView.as_view()), - url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^api/v1/example/?$', views.ExampleListView.as_view()), + url(r'^api/v1/example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^api/v1/example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -509,7 +377,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase): class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): def setUp(self): router = DefaultRouter() - router.register('example1', MethodLimitedViewSet, basename='example1') + router.register('example1', views.MethodLimitedViewSet, basename='example1') self.patterns = [ url(r'^', include(router.urls)) ] @@ -566,10 +434,10 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): class TestSchemaGeneratorWithRestrictedViewSets(TestCase): def setUp(self): router = DefaultRouter() - router.register('example1', Http404ExampleViewSet, basename='example1') - router.register('example2', PermissionDeniedExampleViewSet, basename='example2') + router.register('example1', views.Http404ExampleViewSet, basename='example1') + router.register('example2', views.PermissionDeniedExampleViewSet, basename='example2') self.patterns = [ - url('^example/?$', ExampleListView.as_view()), + url('^example/?$', views.ExampleListView.as_view()), url(r'^', include(router.urls)) ] @@ -597,29 +465,25 @@ class TestSchemaGeneratorWithRestrictedViewSets(TestCase): assert schema == expected -class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySource - fields = ('id', 'name', 'target') - - -class ForeignKeySourceView(generics.CreateAPIView): - queryset = ForeignKeySource.objects.all() - serializer_class = ForeignKeySourceSerializer - - @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGeneratorWithForeignKey(TestCase): - def setUp(self): - self.patterns = [ - url(r'^example/?$', ForeignKeySourceView.as_view()), - ] - def test_schema_for_regular_views(self): """ Ensure that AutoField foreign keys are output as Integer. """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) + class ForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySource + fields = ('id', 'name', 'target') + + class ForeignKeySourceView(generics.CreateAPIView): + queryset = ForeignKeySource.objects.all() + serializer_class = ForeignKeySourceSerializer + + patterns = [ + url(r'^example/?$', ForeignKeySourceView.as_view()), + ] + generator = SchemaGenerator(title='Example API', patterns=patterns) schema = generator.get_schema() expected = coreapi.Document( @@ -653,35 +517,8 @@ class Test4605Regression(TestCase): assert prefix == '/' -class CustomViewInspector(AutoSchema): - """A dummy AutoSchema subclass""" - pass - - class TestAutoSchema(TestCase): - def test_apiview_schema_descriptor(self): - view = APIView() - assert hasattr(view, 'schema') - assert isinstance(view.schema, AutoSchema) - - def test_set_custom_inspector_class_on_view(self): - class CustomView(APIView): - schema = CustomViewInspector() - - view = CustomView() - assert isinstance(view.schema, CustomViewInspector) - - def test_set_custom_inspector_class_via_settings(self): - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): - view = APIView() - assert isinstance(view.schema, CustomViewInspector) - - def test_get_link_requires_instance(self): - descriptor = APIView.schema # Accessed from class - with pytest.raises(AssertionError): - descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? - @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') def test_update_fields(self): """ @@ -902,158 +739,19 @@ def test_docstring_is_not_stripped_by_get_description(): assert descr == formatting.dedent(ExampleDocstringAPIView.__doc__[1:][:-1]) -# Views for SchemaGenerationExclusionTests -class ExcludedAPIView(APIView): - schema = None - - def get(self, request, *args, **kwargs): - pass - - -@api_view(['GET']) -@schema(None) -def excluded_fbv(request): - pass - - -@api_view(['GET']) -def included_fbv(request): - pass - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -class SchemaGenerationExclusionTests(TestCase): - def setUp(self): - self.patterns = [ - url('^excluded-cbv/$', ExcludedAPIView.as_view()), - url('^excluded-fbv/$', excluded_fbv), - url('^included-fbv/$', included_fbv), - ] - - def test_schema_generator_excludes_correctly(self): - """Schema should not include excluded views""" - generator = SchemaGenerator(title='Exclusions', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Exclusions', - content={ - 'included-fbv': { - 'list': coreapi.Link(url='/included-fbv/', action='get') - } - } - ) - - assert len(schema.data) == 1 - assert 'included-fbv' in schema.data - assert schema == expected - - def test_endpoint_enumerator_excludes_correctly(self): - """It is responsibility of EndpointEnumerator to exclude views""" - inspector = EndpointEnumerator(self.patterns) - endpoints = inspector.get_api_endpoints() - - assert len(endpoints) == 1 - path, method, callback = endpoints[0] - assert path == '/included-fbv/' - - def test_should_include_endpoint_excludes_correctly(self): - """This is the specific method that should handle the exclusion""" - inspector = EndpointEnumerator(self.patterns) - - # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test - pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback) - for pattern in self.patterns] - - should_include = [ - inspector.should_include_endpoint(*pair) for pair in pairs - ] - - expected = [False, False, True] - - assert should_include == expected - - def test_deprecations(self): - with pytest.warns(DeprecationWarning) as record: - @api_view(["GET"], exclude_from_schema=True) - def view(request): - pass - - assert len(record) == 1 - assert str(record[0].message) == ( - "The `exclude_from_schema` argument to `api_view` is deprecated. " - "Use the `schema` decorator instead, passing `None`." - ) - - class OldFashionedExcludedView(APIView): - exclude_from_schema = True - - def get(self, request, *args, **kwargs): - pass - - patterns = [ - url('^excluded-old-fashioned/$', OldFashionedExcludedView.as_view()), - ] - - inspector = EndpointEnumerator(patterns) - with pytest.warns(DeprecationWarning) as record: - inspector.get_api_endpoints() - - assert len(record) == 1 - assert str(record[0].message) == ( - "The `OldFashionedExcludedView.exclude_from_schema` attribute is " - "deprecated. Set `schema = None` instead." - ) - - -@api_view(["GET"]) -def simple_fbv(request): - pass - - -class BasicModelSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel - fields = "__all__" - - -class NamingCollisionView(generics.RetrieveUpdateDestroyAPIView): - queryset = BasicModel.objects.all() - serializer_class = BasicModelSerializer - - -class BasicNamingCollisionView(generics.RetrieveAPIView): - queryset = BasicModel.objects.all() - - -class NamingCollisionViewSet(GenericViewSet): - """ - Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/ - """ - permision_class = () - - @action(detail=False) - def detail(self, request): - return {} - - @action(detail=False, url_path='detail/export') - def detail_export(self, request): - return {} - - -naming_collisions_router = SimpleRouter() -naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision") - - @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') class TestURLNamingCollisions(TestCase): """ Ref: https://github.com/encode/django-rest-framework/issues/4704 """ + @api_view(["GET"]) + def simple_fbv(request): + pass + def test_manually_routing_nested_routes(self): patterns = [ - url(r'^test', simple_fbv), - url(r'^test/list/', simple_fbv), + url(r'^test', self.simple_fbv), + url(r'^test/list/', self.simple_fbv), ] generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) @@ -1088,6 +786,15 @@ class TestURLNamingCollisions(TestCase): assert loc[key].url == url def test_manually_routing_generic_view(self): + class BasicModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + fields = "__all__" + + class NamingCollisionView(generics.RetrieveUpdateDestroyAPIView): + queryset = BasicModel.objects.all() + serializer_class = BasicModelSerializer + patterns = [ url(r'^test', NamingCollisionView.as_view()), url(r'^test/retrieve/', NamingCollisionView.as_view()), @@ -1111,6 +818,23 @@ class TestURLNamingCollisions(TestCase): self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0')) def test_from_router(self): + class NamingCollisionViewSet(GenericViewSet): + """ + Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/ + """ + permision_class = () + + @action(detail=False) + def detail(self, request): + return {} + + @action(detail=False, url_path='detail/export') + def detail_export(self, request): + return {} + + naming_collisions_router = SimpleRouter() + naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision") + patterns = [ url(r'from-router', include(naming_collisions_router.urls)), ] @@ -1143,6 +867,9 @@ class TestURLNamingCollisions(TestCase): assert schema == expected def test_url_under_same_key_not_replaced(self): + class BasicNamingCollisionView(generics.RetrieveAPIView): + queryset = BasicModel.objects.all() + patterns = [ url(r'example/(?P\d+)/$', BasicNamingCollisionView.as_view()), url(r'example/(?P\w+)/$', BasicNamingCollisionView.as_view()), @@ -1157,8 +884,8 @@ class TestURLNamingCollisions(TestCase): def test_url_under_same_key_not_replaced_another(self): patterns = [ - url(r'^test/list/', simple_fbv), - url(r'^test/(?P\d+)/list/', simple_fbv), + url(r'^test/list/', self.simple_fbv), + url(r'^test/(?P\d+)/list/', self.simple_fbv), ] generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) diff --git a/tests/schemas/test_endpoint_enumerator.py b/tests/schemas/test_endpoint_enumerator.py new file mode 100644 index 000000000..6089a0265 --- /dev/null +++ b/tests/schemas/test_endpoint_enumerator.py @@ -0,0 +1,112 @@ +import unittest + +import pytest +from django.conf.urls import url +from django.test import TestCase + +from rest_framework.compat import coreapi, get_regex_pattern +from rest_framework.decorators import api_view, schema +from rest_framework.schemas.generators import ( + EndpointEnumerator, SchemaGenerator +) +from rest_framework.views import APIView + + +class EndpointExclusionTests(TestCase): + class ExcludedAPIView(APIView): + schema = None + + def get(self, request, *args, **kwargs): + pass + + @api_view(['GET']) + @schema(None) + def excluded_fbv(request): + pass + + @api_view(['GET']) + def included_fbv(request): + pass + + def setUp(self): + self.patterns = [ + url('^excluded-cbv/$', self.ExcludedAPIView.as_view()), + url('^excluded-fbv/$', self.excluded_fbv), + url('^included-fbv/$', self.included_fbv), + ] + + @unittest.skipUnless(coreapi, 'coreapi is not installed') + def test_schema_generator_excludes_correctly(self): + """Schema should not include excluded views""" + generator = SchemaGenerator(title='Exclusions', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( + url='', + title='Exclusions', + content={ + 'included-fbv': { + 'list': coreapi.Link(url='/included-fbv/', action='get') + } + } + ) + + assert len(schema.data) == 1 + assert 'included-fbv' in schema.data + assert schema == expected + + def test_endpoint_enumerator_excludes_correctly(self): + """It is responsibility of EndpointEnumerator to exclude views""" + inspector = EndpointEnumerator(self.patterns) + endpoints = inspector.get_api_endpoints() + + assert len(endpoints) == 1 + path, method, callback = endpoints[0] + assert path == '/included-fbv/' + + def test_should_include_endpoint_excludes_correctly(self): + """This is the specific method that should handle the exclusion""" + inspector = EndpointEnumerator(self.patterns) + + # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test + pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback) + for pattern in self.patterns] + + should_include = [ + inspector.should_include_endpoint(*pair) for pair in pairs + ] + + expected = [False, False, True] + + assert should_include == expected + + def test_deprecations(self): + with pytest.warns(DeprecationWarning) as record: + @api_view(["GET"], exclude_from_schema=True) + def view(request): + pass + + assert len(record) == 1 + assert str(record[0].message) == ( + "The `exclude_from_schema` argument to `api_view` is deprecated. " + "Use the `schema` decorator instead, passing `None`." + ) + + class OldFashionedExcludedView(APIView): + exclude_from_schema = True + + def get(self, request, *args, **kwargs): + pass + + patterns = [ + url('^excluded-old-fashioned/$', OldFashionedExcludedView.as_view()), + ] + + inspector = EndpointEnumerator(patterns) + with pytest.warns(DeprecationWarning) as record: + inspector.get_api_endpoints() + + assert len(record) == 1 + assert str(record[0].message) == ( + "The `OldFashionedExcludedView.exclude_from_schema` attribute is " + "deprecated. Set `schema = None` instead." + ) diff --git a/tests/schemas/test_view_inspector_descriptor.py b/tests/schemas/test_view_inspector_descriptor.py new file mode 100644 index 000000000..6bf2c2bf4 --- /dev/null +++ b/tests/schemas/test_view_inspector_descriptor.py @@ -0,0 +1,38 @@ +import pytest +from django.test import TestCase, override_settings + +from rest_framework.schemas.inspectors import AutoSchema, ViewInspector +from rest_framework.views import APIView + + +class CustomViewInspector(ViewInspector): + """A dummy ViewInspector subclass""" + pass + + +class TestViewInspector(TestCase): + """ + Tests for the descriptor behaviour of ViewInspector + (and subclasses.) + """ + def test_apiview_schema_descriptor(self): + view = APIView() + assert hasattr(view, 'schema') + assert isinstance(view.schema, AutoSchema) + + def test_set_custom_inspector_class_on_view(self): + class CustomView(APIView): + schema = CustomViewInspector() + + view = CustomView() + assert isinstance(view.schema, CustomViewInspector) + + def test_set_custom_inspector_class_via_settings(self): + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_view_inspector_descriptor.CustomViewInspector'}): + view = APIView() + assert isinstance(view.schema, CustomViewInspector) + + def test_get_link_requires_instance(self): + descriptor = APIView.schema # Accessed from class + with pytest.raises(AssertionError): + descriptor.get_link(None, None, None) diff --git a/tests/schemas/views.py b/tests/schemas/views.py new file mode 100644 index 000000000..21f6b2aa2 --- /dev/null +++ b/tests/schemas/views.py @@ -0,0 +1,139 @@ +from django.core.exceptions import PermissionDenied +from django.http import Http404 + +from rest_framework import filters, pagination, permissions, serializers +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.viewsets import ModelViewSet + + +# Simple APIViews: +class ExampleListView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + def post(self, request, *args, **kwargs): + pass + + +class ExampleDetailView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + +# Classes for ExampleViewSet +class ExamplePagination(pagination.PageNumberPagination): + page_size = 100 + page_size_query_param = 'page_size' + + +class EmptySerializer(serializers.Serializer): + pass + + +class ExampleSerializer(serializers.Serializer): + a = serializers.CharField(required=True, help_text='A field description') + b = serializers.CharField(required=False) + read_only = serializers.CharField(read_only=True) + hidden = serializers.HiddenField(default='hello') + + +class AnotherSerializerWithDictField(serializers.Serializer): + a = serializers.DictField() + + +class AnotherSerializerWithListFields(serializers.Serializer): + a = serializers.ListField(child=serializers.IntegerField()) + b = serializers.ListSerializer(child=serializers.CharField()) + + +class AnotherSerializer(serializers.Serializer): + c = serializers.CharField(required=True) + d = serializers.CharField(required=False) + + +class ExampleViewSet(ModelViewSet): + pagination_class = ExamplePagination + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + filter_backends = [filters.OrderingFilter] + serializer_class = ExampleSerializer + + @action(methods=['post'], detail=True, serializer_class=AnotherSerializer) + def custom_action(self, request, pk): + """ + A description of custom action. + """ + raise NotImplementedError + + @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField) + def custom_action_with_dict_field(self, request, pk): + """ + A custom action using a dict field in the serializer. + """ + raise NotImplementedError + + @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields) + def custom_action_with_list_fields(self, request, pk): + """ + A custom action using both list field and list serializer in the serializer. + """ + raise NotImplementedError + + @action(detail=False) + def custom_list_action(self, request): + raise NotImplementedError + + @action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer) + def custom_list_action_multiple_methods(self, request): + """Custom description.""" + raise NotImplementedError + + @custom_list_action_multiple_methods.mapping.delete + def custom_list_action_multiple_methods_delete(self, request): + """Deletion description.""" + raise NotImplementedError + + @action(detail=False, schema=None) + def excluded_action(self, request): + pass + + def get_serializer(self, *args, **kwargs): + assert self.request + assert self.action + return super(ExampleViewSet, self).get_serializer(*args, **kwargs) + + +# ExampleViewSet subclasses +class DenyAllUsingHttp404(permissions.BasePermission): + + def has_permission(self, request, view): + raise Http404() + + def has_object_permission(self, request, view, obj): + raise Http404() + + +class DenyAllUsingPermissionDenied(permissions.BasePermission): + + def has_permission(self, request, view): + raise PermissionDenied() + + def has_object_permission(self, request, view, obj): + raise PermissionDenied() + + +class Http404ExampleViewSet(ExampleViewSet): + permission_classes = [DenyAllUsingHttp404] + + +class PermissionDeniedExampleViewSet(ExampleViewSet): + permission_classes = [DenyAllUsingPermissionDenied] + + +class MethodLimitedViewSet(ExampleViewSet): + permission_classes = [] + http_method_names = ['get', 'head', 'options']