mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-03 13:14:30 +03:00
Pass request to schema generation (#4383)
Pass request to schema generation
This commit is contained in:
parent
3698d9ea2e
commit
b50d8950ee
|
@ -65,44 +65,52 @@ class SchemaGenerator(object):
|
|||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
patterns = urls.urlpatterns
|
||||
self.patterns = urls.urlpatterns
|
||||
elif patterns is None and urlconf is None:
|
||||
urls = import_module(settings.ROOT_URLCONF)
|
||||
patterns = urls.urlpatterns
|
||||
self.patterns = urls.urlpatterns
|
||||
else:
|
||||
self.patterns = patterns
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.title = title
|
||||
self.url = url
|
||||
self.endpoints = self.get_api_endpoints(patterns)
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None):
|
||||
if request is None:
|
||||
endpoints = self.endpoints
|
||||
else:
|
||||
# Filter the list of endpoints to only include those that
|
||||
# the user has permission on.
|
||||
endpoints = []
|
||||
for key, link, callback in self.endpoints:
|
||||
method = link.action.upper()
|
||||
view = callback.cls()
|
||||
if self.endpoints is None:
|
||||
self.endpoints = self.get_api_endpoints(self.patterns)
|
||||
|
||||
links = []
|
||||
for key, path, method, callback in self.endpoints:
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
view.format_kwarg = None
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except exceptions.APIException:
|
||||
pass
|
||||
else:
|
||||
endpoints.append((key, link, callback))
|
||||
continue
|
||||
else:
|
||||
view.request = None
|
||||
|
||||
if not endpoints:
|
||||
link = self.get_link(path, method, callback, view)
|
||||
links.append((key, link))
|
||||
|
||||
if not link:
|
||||
return None
|
||||
|
||||
# Generate the schema content structure, from the endpoints.
|
||||
# ('users', 'list'), Link -> {'users': {'list': Link()}}
|
||||
content = {}
|
||||
for key, link, callback in endpoints:
|
||||
for key, link in links:
|
||||
insert_into(content, key, link)
|
||||
|
||||
# Return the schema document.
|
||||
|
@ -122,8 +130,7 @@ class SchemaGenerator(object):
|
|||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
key = self.get_key(path, method, callback)
|
||||
link = self.get_link(path, method, callback)
|
||||
endpoint = (key, link, callback)
|
||||
endpoint = (key, path, method, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, RegexURLResolver):
|
||||
|
@ -190,14 +197,10 @@ class SchemaGenerator(object):
|
|||
|
||||
# Methods for generating each individual `Link` instance...
|
||||
|
||||
def get_link(self, path, method, callback):
|
||||
def get_link(self, path, method, callback, view):
|
||||
"""
|
||||
Return a `coreapi.Link` instance for the given endpoint.
|
||||
"""
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
|
||||
fields = self.get_path_fields(path, method, callback, view)
|
||||
fields += self.get_serializer_fields(path, method, callback, view)
|
||||
fields += self.get_pagination_fields(path, method, callback, view)
|
||||
|
@ -260,20 +263,18 @@ class SchemaGenerator(object):
|
|||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer_class'):
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
|
||||
serializer_class = view.get_serializer_class()
|
||||
serializer = serializer_class()
|
||||
serializer = view.get_serializer()
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return coreapi.Field(name='data', location='body', required=True)
|
||||
return [coreapi.Field(name='data', location='body', required=True)]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only:
|
||||
continue
|
||||
|
|
|
@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
|
|||
def custom_action(self, request, pk):
|
||||
return super(ExampleSerializer, self).retrieve(self, request)
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
assert self.request
|
||||
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
||||
|
||||
|
||||
class ExampleView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
|
Loading…
Reference in New Issue
Block a user