From 44c1c25bdef8086a8c47d2f4a2e87fa5cb26e1c6 Mon Sep 17 00:00:00 2001 From: Dhaval Mehta <20968146+dhaval-mehta@users.noreply.github.com> Date: Sun, 9 Feb 2020 13:14:26 +0530 Subject: [PATCH] add support for tag objects --- rest_framework/schemas/openapi.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 33bc31ac1..d49a7e14e 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -43,6 +43,8 @@ class SchemaGenerator(BaseSchemaGenerator): # Iterate endpoints generating per method path operations. # TODO: …and reference components. paths = {} + tags = [] + processed_views_for_tags = set() _, view_endpoints = self._get_paths_and_endpoints(None if public else request) for path, method, view in view_endpoints: if not self.has_view_permissions(path, method, view): @@ -57,11 +59,16 @@ class SchemaGenerator(BaseSchemaGenerator): paths.setdefault(path, {}) paths[path][method.lower()] = operation + if view.__class__.__name__ not in processed_views_for_tags: + tags.extend(view.schema.get_tag_objects()) + processed_views_for_tags.add(view.__class__.__name__) + # Compile final schema. schema = { 'openapi': '3.0.2', 'info': self.get_info(), 'paths': paths, + 'tags': tags } return schema @@ -73,7 +80,10 @@ class SchemaGenerator(BaseSchemaGenerator): class AutoSchema(ViewInspector): def __init__(self, tags=None): - self.tags = tags + if tags is None: + tags = [] + self._tag_objects = list(filter(lambda tag: isinstance(tag, (dict, OrderedDict)), tags)) + self._tags = list(map(lambda tag: tag['name'] if isinstance(tag, (dict, OrderedDict)) else tag, tags)) super().__init__() request_media_types = [] @@ -103,10 +113,13 @@ class AutoSchema(ViewInspector): if request_body: operation['requestBody'] = request_body operation['responses'] = self._get_responses(path, method) - operation['tags'] = self.get_tags(path, method) + operation['tags'] = self._get_tags(path, method) return operation + def get_tag_objects(self): + return self._tag_objects + def _get_operation_id(self, path, method): """ Compute an operation ID from the model, serializer or view name. @@ -571,14 +584,10 @@ class AutoSchema(ViewInspector): } } - def get_tags(self, path, method): + def _get_tags(self, path, method): # If user have specified tags, use them. - if self.tags: - if isinstance(self.tags, (list, set, tuple)): - return self.tags - if isinstance(self.tags, (dict, OrderedDict)): - return self.tags - raise ValueError('tags must be dict or list.') + if self._tags: + return self._tags # Extract tag from viewset name # UserViewSet tags = [User]