mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-10 16:40:55 +03:00
Pass request to schema generation (#4383)
Pass request to schema generation
This commit is contained in:
parent
3698d9ea2e
commit
b50d8950ee
|
@ -65,44 +65,52 @@ class SchemaGenerator(object):
|
||||||
urls = import_module(urlconf)
|
urls = import_module(urlconf)
|
||||||
else:
|
else:
|
||||||
urls = urlconf
|
urls = urlconf
|
||||||
patterns = urls.urlpatterns
|
self.patterns = urls.urlpatterns
|
||||||
elif patterns is None and urlconf is None:
|
elif patterns is None and urlconf is None:
|
||||||
urls = import_module(settings.ROOT_URLCONF)
|
urls = import_module(settings.ROOT_URLCONF)
|
||||||
patterns = urls.urlpatterns
|
self.patterns = urls.urlpatterns
|
||||||
|
else:
|
||||||
|
self.patterns = patterns
|
||||||
|
|
||||||
if url and not url.endswith('/'):
|
if url and not url.endswith('/'):
|
||||||
url += '/'
|
url += '/'
|
||||||
|
|
||||||
self.title = title
|
self.title = title
|
||||||
self.url = url
|
self.url = url
|
||||||
self.endpoints = self.get_api_endpoints(patterns)
|
self.endpoints = None
|
||||||
|
|
||||||
def get_schema(self, request=None):
|
def get_schema(self, request=None):
|
||||||
if request is None:
|
if self.endpoints is None:
|
||||||
endpoints = self.endpoints
|
self.endpoints = self.get_api_endpoints(self.patterns)
|
||||||
else:
|
|
||||||
# Filter the list of endpoints to only include those that
|
links = []
|
||||||
# the user has permission on.
|
for key, path, method, callback in self.endpoints:
|
||||||
endpoints = []
|
|
||||||
for key, link, callback in self.endpoints:
|
|
||||||
method = link.action.upper()
|
|
||||||
view = callback.cls()
|
view = callback.cls()
|
||||||
view.request = clone_request(request, method)
|
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||||
|
setattr(view, attr, val)
|
||||||
|
view.args = ()
|
||||||
|
view.kwargs = {}
|
||||||
view.format_kwarg = None
|
view.format_kwarg = None
|
||||||
|
|
||||||
|
if request is not None:
|
||||||
|
view.request = clone_request(request, method)
|
||||||
try:
|
try:
|
||||||
view.check_permissions(view.request)
|
view.check_permissions(view.request)
|
||||||
except exceptions.APIException:
|
except exceptions.APIException:
|
||||||
pass
|
continue
|
||||||
else:
|
else:
|
||||||
endpoints.append((key, link, callback))
|
view.request = None
|
||||||
|
|
||||||
if not endpoints:
|
link = self.get_link(path, method, callback, view)
|
||||||
|
links.append((key, link))
|
||||||
|
|
||||||
|
if not link:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Generate the schema content structure, from the endpoints.
|
# Generate the schema content structure, from the endpoints.
|
||||||
# ('users', 'list'), Link -> {'users': {'list': Link()}}
|
# ('users', 'list'), Link -> {'users': {'list': Link()}}
|
||||||
content = {}
|
content = {}
|
||||||
for key, link, callback in endpoints:
|
for key, link in links:
|
||||||
insert_into(content, key, link)
|
insert_into(content, key, link)
|
||||||
|
|
||||||
# Return the schema document.
|
# Return the schema document.
|
||||||
|
@ -122,8 +130,7 @@ class SchemaGenerator(object):
|
||||||
if self.should_include_endpoint(path, callback):
|
if self.should_include_endpoint(path, callback):
|
||||||
for method in self.get_allowed_methods(callback):
|
for method in self.get_allowed_methods(callback):
|
||||||
key = self.get_key(path, method, callback)
|
key = self.get_key(path, method, callback)
|
||||||
link = self.get_link(path, method, callback)
|
endpoint = (key, path, method, callback)
|
||||||
endpoint = (key, link, callback)
|
|
||||||
api_endpoints.append(endpoint)
|
api_endpoints.append(endpoint)
|
||||||
|
|
||||||
elif isinstance(pattern, RegexURLResolver):
|
elif isinstance(pattern, RegexURLResolver):
|
||||||
|
@ -190,14 +197,10 @@ class SchemaGenerator(object):
|
||||||
|
|
||||||
# Methods for generating each individual `Link` instance...
|
# Methods for generating each individual `Link` instance...
|
||||||
|
|
||||||
def get_link(self, path, method, callback):
|
def get_link(self, path, method, callback, view):
|
||||||
"""
|
"""
|
||||||
Return a `coreapi.Link` instance for the given endpoint.
|
Return a `coreapi.Link` instance for the given endpoint.
|
||||||
"""
|
"""
|
||||||
view = callback.cls()
|
|
||||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
|
||||||
setattr(view, attr, val)
|
|
||||||
|
|
||||||
fields = self.get_path_fields(path, method, callback, view)
|
fields = self.get_path_fields(path, method, callback, view)
|
||||||
fields += self.get_serializer_fields(path, method, callback, view)
|
fields += self.get_serializer_fields(path, method, callback, view)
|
||||||
fields += self.get_pagination_fields(path, method, callback, view)
|
fields += self.get_pagination_fields(path, method, callback, view)
|
||||||
|
@ -260,20 +263,18 @@ class SchemaGenerator(object):
|
||||||
if method not in ('PUT', 'PATCH', 'POST'):
|
if method not in ('PUT', 'PATCH', 'POST'):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not hasattr(view, 'get_serializer_class'):
|
if not hasattr(view, 'get_serializer'):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
fields = []
|
serializer = view.get_serializer()
|
||||||
|
|
||||||
serializer_class = view.get_serializer_class()
|
|
||||||
serializer = serializer_class()
|
|
||||||
|
|
||||||
if isinstance(serializer, serializers.ListSerializer):
|
if isinstance(serializer, serializers.ListSerializer):
|
||||||
return coreapi.Field(name='data', location='body', required=True)
|
return [coreapi.Field(name='data', location='body', required=True)]
|
||||||
|
|
||||||
if not isinstance(serializer, serializers.Serializer):
|
if not isinstance(serializer, serializers.Serializer):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
fields = []
|
||||||
for field in serializer.fields.values():
|
for field in serializer.fields.values():
|
||||||
if field.read_only:
|
if field.read_only:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
|
||||||
def custom_action(self, request, pk):
|
def custom_action(self, request, pk):
|
||||||
return super(ExampleSerializer, self).retrieve(self, request)
|
return super(ExampleSerializer, self).retrieve(self, request)
|
||||||
|
|
||||||
|
def get_serializer(self, *args, **kwargs):
|
||||||
|
assert self.request
|
||||||
|
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ExampleView(APIView):
|
class ExampleView(APIView):
|
||||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user