Pass request to schema generation

This commit is contained in:
Tom Christie 2016-08-11 10:59:27 +01:00
parent f16e880167
commit 9161b53e4a

View File

@ -65,44 +65,52 @@ class SchemaGenerator(object):
urls = import_module(urlconf) urls = import_module(urlconf)
else: else:
urls = urlconf urls = urlconf
patterns = urls.urlpatterns self.patterns = urls.urlpatterns
elif patterns is None and urlconf is None: elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF) urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns self.patterns = urls.urlpatterns
else:
self.patterns = patterns
if url and not url.endswith('/'): if url and not url.endswith('/'):
url += '/' url += '/'
self.title = title self.title = title
self.url = url self.url = url
self.endpoints = self.get_api_endpoints(patterns) self.endpoints = None
def get_schema(self, request=None): def get_schema(self, request=None):
if request is None: if self.endpoints is None:
endpoints = self.endpoints self.endpoints = self.get_api_endpoints(self.patterns)
else:
# Filter the list of endpoints to only include those that links = []
# the user has permission on. for key, path, method, callback in self.endpoints:
endpoints = [] view = callback.cls()
for key, link, callback in self.endpoints: for attr, val in getattr(callback, 'initkwargs', {}).items():
method = link.action.upper() setattr(view, attr, val)
view = callback.cls() view.args = ()
view.kwargs = {}
view.format_kwarg = None
if request is not None:
view.request = clone_request(request, method) view.request = clone_request(request, method)
view.format_kwarg = None
try: try:
view.check_permissions(view.request) view.check_permissions(view.request)
except exceptions.APIException: except exceptions.APIException:
pass continue
else: else:
endpoints.append((key, link, callback)) view.request = None
if not endpoints: link = self.get_link(path, method, callback, view)
links.append((key, link))
if not link:
return None return None
# Generate the schema content structure, from the endpoints. # Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}} # ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {} content = {}
for key, link, callback in endpoints: for key, link in links:
insert_into(content, key, link) insert_into(content, key, link)
# Return the schema document. # Return the schema document.
@ -122,8 +130,7 @@ class SchemaGenerator(object):
if self.should_include_endpoint(path, callback): if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback): for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback) key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback) endpoint = (key, path, method, callback)
endpoint = (key, link, callback)
api_endpoints.append(endpoint) api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver): elif isinstance(pattern, RegexURLResolver):
@ -190,14 +197,10 @@ class SchemaGenerator(object):
# Methods for generating each individual `Link` instance... # Methods for generating each individual `Link` instance...
def get_link(self, path, method, callback): def get_link(self, path, method, callback, view):
""" """
Return a `coreapi.Link` instance for the given endpoint. Return a `coreapi.Link` instance for the given endpoint.
""" """
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
fields = self.get_path_fields(path, method, callback, view) fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view) fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view) fields += self.get_pagination_fields(path, method, callback, view)