Pass request to schema generation

This commit is contained in:
Tom Christie 2016-08-11 10:59:27 +01:00
parent f16e880167
commit 9161b53e4a

View File

@ -65,44 +65,52 @@ class SchemaGenerator(object):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
else:
self.patterns = patterns
if url and not url.endswith('/'):
url += '/'
self.title = title
self.url = url
self.endpoints = self.get_api_endpoints(patterns)
self.endpoints = None
def get_schema(self, request=None):
if request is None:
endpoints = self.endpoints
else:
# Filter the list of endpoints to only include those that
# the user has permission on.
endpoints = []
for key, link, callback in self.endpoints:
method = link.action.upper()
view = callback.cls()
if self.endpoints is None:
self.endpoints = self.get_api_endpoints(self.patterns)
links = []
for key, path, method, callback in self.endpoints:
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view.args = ()
view.kwargs = {}
view.format_kwarg = None
if request is not None:
view.request = clone_request(request, method)
view.format_kwarg = None
try:
view.check_permissions(view.request)
except exceptions.APIException:
pass
else:
endpoints.append((key, link, callback))
continue
else:
view.request = None
if not endpoints:
link = self.get_link(path, method, callback, view)
links.append((key, link))
if not link:
return None
# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {}
for key, link, callback in endpoints:
for key, link in links:
insert_into(content, key, link)
# Return the schema document.
@ -122,8 +130,7 @@ class SchemaGenerator(object):
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback)
endpoint = (key, link, callback)
endpoint = (key, path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
@ -190,14 +197,10 @@ class SchemaGenerator(object):
# Methods for generating each individual `Link` instance...
def get_link(self, path, method, callback):
def get_link(self, path, method, callback, view):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view)