diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 8794c9967..92cec8aa1 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -308,14 +308,7 @@ class SchemaGenerator(object): """ links = LinkNode() - # Generate (path, method, view) given (path, method, callback). - paths = [] - view_endpoints = [] - for path, method, callback in self.endpoints: - view = self.create_view(callback, method, request) - path = self.coerce_path(path, method, view) - paths.append(path) - view_endpoints.append((path, method, view)) + paths, view_endpoints = self._get_paths_and_endpoints(request) # Only generate the path prefix for paths that will be included if not paths: @@ -332,6 +325,20 @@ class SchemaGenerator(object): return links + def _get_paths_and_endpoints(self, request): + """ + Generate (path, method, view) given (path, method, callback) for paths. + """ + paths = [] + view_endpoints = [] + for path, method, callback in self.endpoints: + view = self.create_view(callback, method, request) + path = self.coerce_path(path, method, view) + paths.append(path) + view_endpoints.append((path, method, view)) + + return paths, view_endpoints + # Methods used when we generate a view instance from the raw callback... def determine_path_prefix(self, paths): @@ -461,3 +468,26 @@ class SchemaGenerator(object): # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] + + +class OpenAPISchemaGenerator(SchemaGenerator): + + def get_paths(self, request=None): + result = OrderedDict() + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + operation = view.schema.get_operation(path, method) + subpath = path[len(prefix):] + result.setdefault(subpath, {}) + result[subpath][method.lower()] = operation + + return result diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index b90f60e08..52f0fe3d6 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -501,3 +501,10 @@ class DefaultSchema(ViewInspector): inspector = inspector_class() inspector.view = instance return inspector + + +class OpenAPIAutoSchema(ViewInspector): + + def get_operation(self, path, method): + # TODO: fill in details here. + return {} diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py new file mode 100644 index 000000000..bdd5dc9ba --- /dev/null +++ b/tests/schemas/test_openapi.py @@ -0,0 +1,65 @@ +from django.conf.urls import url +from django.test import RequestFactory, TestCase, override_settings + +from rest_framework.request import Request +from rest_framework.schemas.generators import OpenAPISchemaGenerator +from rest_framework.schemas.inspectors import OpenAPIAutoSchema + +from . import views + + +def create_request(path): + factory = RequestFactory() + request = Request(factory.get(path)) + return request + + +def create_view(view_cls, method, request): + generator = OpenAPISchemaGenerator() + view = generator.create_view(view_cls.as_view(), method, request) + return view + + +class TestInspector(TestCase): + + def test_path_without_parameters(self): + path = '/example/' + method = 'GET' + + view = create_view( + views.ExampleListView, + method, + create_request(path) + ) + inspector = OpenAPIAutoSchema() + inspector.view = view + + operation = inspector.get_operation(path, method) + assert operation == {} + + # TODO: parameters, operationID, responses, etc ??? + + +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.inspectors.OpenAPIAutoSchema'}) +class TestGenerator(TestCase): + + def test_override_settings(self): + assert isinstance(views.ExampleListView.schema, OpenAPIAutoSchema) + + def test_paths_construction(self): + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = OpenAPISchemaGenerator(patterns=patterns) + + # This happens in get_schema() + inspector = generator.endpoint_inspector_cls(generator.patterns, generator.urlconf) + generator.endpoints = inspector.get_api_endpoints() + + paths = generator.get_paths() + + assert 'example/' in paths + example_operations = paths['example/'] + assert len(example_operations) == 2 + assert 'get' in example_operations + assert 'post' in example_operations diff --git a/tests/schemas/test_schemas.py b/tests/schemas/test_schemas.py index 0c7540427..ddee68a1f 100644 --- a/tests/schemas/test_schemas.py +++ b/tests/schemas/test_schemas.py @@ -24,6 +24,7 @@ from rest_framework.utils import formatting from rest_framework.views import APIView from rest_framework.viewsets import GenericViewSet, ModelViewSet +from . import views from ..models import BasicModel, ForeignKeySource factory = APIRequestFactory() @@ -330,30 +331,13 @@ class MethodLimitedViewSet(ExampleViewSet): http_method_names = ['get', 'head', 'options'] -class ExampleListView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - -class ExampleDetailView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGenerator(TestCase): def setUp(self): self.patterns = [ - url(r'^example/?$', ExampleListView.as_view()), - url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^example/?$', views.ExampleListView.as_view()), + url(r'^example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -404,9 +388,9 @@ class TestSchemaGenerator(TestCase): class TestSchemaGeneratorDjango2(TestCase): def setUp(self): self.patterns = [ - path('example/', ExampleListView.as_view()), - path('example//', ExampleDetailView.as_view()), - path('example//sub/', ExampleDetailView.as_view()), + path('example/', views.ExampleListView.as_view()), + path('example//', views.ExampleDetailView.as_view()), + path('example//sub/', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -456,9 +440,9 @@ class TestSchemaGeneratorDjango2(TestCase): class TestSchemaGeneratorNotAtRoot(TestCase): def setUp(self): self.patterns = [ - url(r'^api/v1/example/?$', ExampleListView.as_view()), - url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^api/v1/example/?$', views.ExampleListView.as_view()), + url(r'^api/v1/example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^api/v1/example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -569,7 +553,7 @@ class TestSchemaGeneratorWithRestrictedViewSets(TestCase): router.register('example1', Http404ExampleViewSet, basename='example1') router.register('example2', PermissionDeniedExampleViewSet, basename='example2') self.patterns = [ - url('^example/?$', ExampleListView.as_view()), + url('^example/?$', views.ExampleListView.as_view()), url(r'^', include(router.urls)) ] diff --git a/tests/schemas/views.py b/tests/schemas/views.py new file mode 100644 index 000000000..c368ba7e5 --- /dev/null +++ b/tests/schemas/views.py @@ -0,0 +1,19 @@ +from rest_framework import permissions +from rest_framework.views import APIView + + +class ExampleListView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + def post(self, request, *args, **kwargs): + pass + + +class ExampleDetailView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass