add support for tag objects

This commit is contained in:
Dhaval Mehta 2020-02-09 13:14:26 +05:30
parent bb339f4947
commit 44c1c25bde

View File

@ -43,6 +43,8 @@ class SchemaGenerator(BaseSchemaGenerator):
# Iterate endpoints generating per method path operations.
# TODO: …and reference components.
paths = {}
tags = []
processed_views_for_tags = set()
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
@ -57,11 +59,16 @@ class SchemaGenerator(BaseSchemaGenerator):
paths.setdefault(path, {})
paths[path][method.lower()] = operation
if view.__class__.__name__ not in processed_views_for_tags:
tags.extend(view.schema.get_tag_objects())
processed_views_for_tags.add(view.__class__.__name__)
# Compile final schema.
schema = {
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
'tags': tags
}
return schema
@ -73,7 +80,10 @@ class SchemaGenerator(BaseSchemaGenerator):
class AutoSchema(ViewInspector):
def __init__(self, tags=None):
self.tags = tags
if tags is None:
tags = []
self._tag_objects = list(filter(lambda tag: isinstance(tag, (dict, OrderedDict)), tags))
self._tags = list(map(lambda tag: tag['name'] if isinstance(tag, (dict, OrderedDict)) else tag, tags))
super().__init__()
request_media_types = []
@ -103,10 +113,13 @@ class AutoSchema(ViewInspector):
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
operation['tags'] = self.get_tags(path, method)
operation['tags'] = self._get_tags(path, method)
return operation
def get_tag_objects(self):
return self._tag_objects
def _get_operation_id(self, path, method):
"""
Compute an operation ID from the model, serializer or view name.
@ -571,14 +584,10 @@ class AutoSchema(ViewInspector):
}
}
def get_tags(self, path, method):
def _get_tags(self, path, method):
# If user have specified tags, use them.
if self.tags:
if isinstance(self.tags, (list, set, tuple)):
return self.tags
if isinstance(self.tags, (dict, OrderedDict)):
return self.tags
raise ValueError('tags must be dict or list.')
if self._tags:
return self._tags
# Extract tag from viewset name
# UserViewSet tags = [User]