diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 48fb2a392..ac81782a9 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -1,3 +1,4 @@ +import os import re from collections import OrderedDict from importlib import import_module @@ -237,18 +238,63 @@ class SchemaGenerator(object): included in the API schema. """ links = OrderedDict() + + # Generate (path, method, view) given (path, method, callback). + paths = [] + view_endpoints = [] for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) + if getattr(view, 'exclude_from_schema', False): + continue path = self.coerce_path(path, method, view) - if not self.should_include_view(path, method, view): + paths.append(path) + view_endpoints.append((path, method, view)) + + # Only generate the path prefix for paths that will be included + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): continue link = self.get_link(path, method, view) - keys = self.get_keys(path, method, view) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) insert_into(links, keys, link) return links # Methods used when we generate a view instance from the raw callback... + def determine_path_prefix(self, paths): + """ + Given a list of all paths, return the common prefix which should be + discounted when generating a schema structure. + + This will be the longest common string that does not include that last + component of the URL, or the last component before a path parameter. + + For example: + + /api/v1/users/ + /api/v1/users/{pk}/ + + The path prefix is '/api/v1/' + """ + prefixes = [] + for path in paths: + components = path.strip('/').split('/') + initial_components = [] + for component in components: + if '{' in component: + break + initial_components.append(component) + prefix = '/'.join(initial_components[:-1]) + if not prefix: + # We can just break early in the case that there's at least + # one URL that doesn't have a path prefix. + return '/' + prefixes.append('/' + prefix + '/') + return os.path.commonprefix(prefixes) + def create_view(self, callback, method, request=None): """ Given a callback, return an actual view instance. @@ -274,13 +320,10 @@ class SchemaGenerator(object): return view - def should_include_view(self, path, method, view): + def has_view_permissions(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. """ - if getattr(view, 'exclude_from_schema', False): - return False - if view.request is None: return True @@ -291,6 +334,11 @@ class SchemaGenerator(object): return True def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ if not self.coerce_pk or '{pk}' not in path: return path model = getattr(getattr(view, 'queryset', None), 'model', None) @@ -461,7 +509,7 @@ class SchemaGenerator(object): # Method for generating the link layout.... - def get_keys(self, path, method, view): + def get_keys(self, subpath, method, view): """ Return a list of keys that should be used to layout a link within the schema document. @@ -478,14 +526,14 @@ class SchemaGenerator(object): action = view.action else: # Views have no associated action, so we determine one from the method. - if is_list_view(path, method, view): + if is_list_view(subpath, method, view): action = 'list' else: action = self.default_mapping[method.lower()] named_path_components = [ component for component - in path.strip('/').split('/') + in subpath.strip('/').split('/') if '{' not in component ] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 330f6f4d5..0a422f078 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -280,3 +280,56 @@ class TestSchemaGenerator(TestCase): } ) self.assertEqual(schema, expected) + + +@unittest.skipUnless(coreapi, 'coreapi is not installed') +class TestSchemaGeneratorNotAtRoot(TestCase): + def setUp(self): + self.patterns = [ + url('^api/v1/example/?$', ExampleListView.as_view()), + url('^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), + url('^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + ] + + def test_schema_for_regular_views(self): + """ + Ensure that schema generation with an API that is not at the URL + root continues to use correct structure for link keys. + """ + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( + url='', + title='Example API', + content={ + 'example': { + 'create': coreapi.Link( + url='/api/v1/example/', + action='post', + fields=[] + ), + 'list': coreapi.Link( + url='/api/v1/example/', + action='get', + fields=[] + ), + 'retrieve': coreapi.Link( + url='/api/v1/example/{id}/', + action='get', + fields=[ + coreapi.Field('id', required=True, location='path') + ] + ), + 'sub': { + 'list': coreapi.Link( + url='/api/v1/example/{id}/sub/', + action='get', + fields=[ + coreapi.Field('id', required=True, location='path') + ] + ) + } + } + } + ) + self.assertEqual(schema, expected)