add tag_objects

This commit is contained in:
Dhaval Mehta 2020-05-26 01:33:05 +05:30
parent acbd9d8222
commit 9fb504f4af

View File

@ -3,6 +3,7 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from decimal import Decimal from decimal import Decimal
from operator import attrgetter from operator import attrgetter
from typing import List, Dict
from urllib.parse import urljoin from urllib.parse import urljoin
from django.core.validators import ( from django.core.validators import (
@ -18,7 +19,6 @@ from rest_framework import (
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from .generators import BaseSchemaGenerator from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
@ -26,6 +26,14 @@ from .utils import get_pk_description, is_list_view
class SchemaGenerator(BaseSchemaGenerator): class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None,
tag_objects: List[Dict] = None):
self.tag_objects = tag_objects
super().__init__(title, url, description, patterns, urlconf, version)
if tag_objects:
self.tag_objects = tag_objects
def get_info(self): def get_info(self):
# Title and version are required by openapi specification 3.x # Title and version are required by openapi specification 3.x
info = { info = {
@ -106,6 +114,9 @@ class SchemaGenerator(BaseSchemaGenerator):
'paths': paths, 'paths': paths,
} }
if self.tag_objects:
schema['tags'] = self.tag_objects
if len(components_schemas) > 0: if len(components_schemas) > 0:
schema['components'] = { schema['components'] = {
'schemas': components_schemas 'schemas': components_schemas
@ -113,9 +124,8 @@ class SchemaGenerator(BaseSchemaGenerator):
return schema return schema
# View Inspectors # View Inspectors
class AutoSchema(ViewInspector): class AutoSchema(ViewInspector):
def __init__(self, tags=None, operation_id_base=None, component_name=None): def __init__(self, tags=None, operation_id_base=None, component_name=None):
@ -179,7 +189,7 @@ class AutoSchema(ViewInspector):
raise Exception( raise Exception(
'"{}" is an invalid class name for schema generation. ' '"{}" is an invalid class name for schema generation. '
'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"' 'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
.format(serializer.__class__.__name__) .format(serializer.__class__.__name__)
) )
return component_name return component_name