Pass request to schema generation (#4383)

Pass request to schema generation
This commit is contained in:
Tom Christie 2016-08-11 11:27:28 +01:00 committed by GitHub
parent 3698d9ea2e
commit b50d8950ee
2 changed files with 36 additions and 31 deletions

View File

@ -65,44 +65,52 @@ class SchemaGenerator(object):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
else:
self.patterns = patterns
if url and not url.endswith('/'):
url += '/'
self.title = title
self.url = url
self.endpoints = self.get_api_endpoints(patterns)
self.endpoints = None
def get_schema(self, request=None):
if request is None:
endpoints = self.endpoints
else:
# Filter the list of endpoints to only include those that
# the user has permission on.
endpoints = []
for key, link, callback in self.endpoints:
method = link.action.upper()
view = callback.cls()
if self.endpoints is None:
self.endpoints = self.get_api_endpoints(self.patterns)
links = []
for key, path, method, callback in self.endpoints:
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view.args = ()
view.kwargs = {}
view.format_kwarg = None
if request is not None:
view.request = clone_request(request, method)
view.format_kwarg = None
try:
view.check_permissions(view.request)
except exceptions.APIException:
pass
else:
endpoints.append((key, link, callback))
continue
else:
view.request = None
if not endpoints:
link = self.get_link(path, method, callback, view)
links.append((key, link))
if not link:
return None
# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {}
for key, link, callback in endpoints:
for key, link in links:
insert_into(content, key, link)
# Return the schema document.
@ -122,8 +130,7 @@ class SchemaGenerator(object):
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback)
endpoint = (key, link, callback)
endpoint = (key, path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
@ -190,14 +197,10 @@ class SchemaGenerator(object):
# Methods for generating each individual `Link` instance...
def get_link(self, path, method, callback):
def get_link(self, path, method, callback, view):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view)
@ -260,20 +263,18 @@ class SchemaGenerator(object):
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer_class'):
if not hasattr(view, 'get_serializer'):
return []
fields = []
serializer_class = view.get_serializer_class()
serializer = serializer_class()
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
return coreapi.Field(name='data', location='body', required=True)
return [coreapi.Field(name='data', location='body', required=True)]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only:
continue

View File

@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
def custom_action(self, request, pk):
return super(ExampleSerializer, self).retrieve(self, request)
def get_serializer(self, *args, **kwargs):
assert self.request
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
class ExampleView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]