enable schema generation for APIView via decorator

This commit is contained in:
Maxim Kurnikov 2019-09-12 17:59:50 +03:00
parent 89ac0a1c7e
commit beb8c665bd
2 changed files with 126 additions and 5 deletions

View File

@ -1,4 +1,7 @@
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, Dict, Any, Type
from urllib.parse import urljoin
from django.core.validators import (
@ -7,6 +10,7 @@ from django.core.validators import (
)
from django.db import models
from django.utils.encoding import force_str
from rest_framework.serializers import Serializer
from rest_framework import exceptions, serializers
from rest_framework.compat import uritemplate
@ -16,6 +20,7 @@ from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
# Generator
@ -73,11 +78,20 @@ class SchemaGenerator(BaseSchemaGenerator):
return schema
# View Inspectors
# dictionary of
# label: str # endpoint name
# description: str # docs
# query_params: Serializer # for request query_params
# data: Serializer # for request data
# files: Serializer # for request files
# responses: Dict[int, Serializer] # responses
_MethodData = Dict[str, Any]
class AutoSchema(ViewInspector):
content_types = ['application/json']
method_mapping = {
'get': 'Retrieve',
@ -87,10 +101,43 @@ class AutoSchema(ViewInspector):
'delete': 'Destroy',
}
def __init__(self):
super().__init__()
self.apiview_methods = defaultdict(dict) # type: Dict[str, _MethodData]
def _set_apiview_method_data(self, method: str,
*,
label: Optional[str] = None,
description: Optional[str] = None,
query_params_serializer: Optional[Serializer] = None,
data_serializer: Optional[Serializer] = None,
responses: Optional[Dict[int, Serializer]] = None) -> None:
self.apiview_methods[method] = {
'label': label,
'description': description,
'query_params_serializer': query_params_serializer,
'data_serializer': data_serializer,
'responses': responses
}
@contextmanager
def _replace_get_serializer(self, serializer: Optional[Serializer]) -> None:
if serializer is not None:
old_get_serializer = self._get_serializer
def always_returns_serializer(method, path):
return serializer
self._get_serializer = always_returns_serializer
yield
if serializer is not None:
self._get_serializer = old_get_serializer
def get_operation(self, path, method):
operation = {}
operation['operationId'] = self._get_operation_id(path, method)
operation['operationId'] = self.apiview_methods[method].get('label') or self._get_operation_id(path, method)
operation['description'] = self.apiview_methods[method].get('description') or ''
parameters = []
parameters += self._get_path_parameters(path, method)
@ -98,10 +145,14 @@ class AutoSchema(ViewInspector):
parameters += self._get_filter_parameters(path, method)
operation['parameters'] = parameters
request_body = self._get_request_body(path, method)
with self._replace_get_serializer(self.apiview_methods[method].get('data_serializer')):
request_body = self._get_request_body(path, method)
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
with self._replace_get_serializer((self.apiview_methods[method].get('responses') or {}).get(200)):
operation['responses'] = self._get_responses(path, method)
return operation
@ -520,3 +571,39 @@ class AutoSchema(ViewInspector):
'description': ""
}
}
class generate_schema:
def __init__(self, *,
label: Optional[str] = None,
description: Optional[str] = None,
query_params: Optional[Type[Serializer]] = None,
data: Optional[Type[Serializer]] = None,
responses: Optional[Dict[int, Type[Serializer]]] = None):
self.label = label
self.description = description
self.query_params_serializer = query_params
self.data_serializer = data
self.responses = responses
def __call__(self, method):
_self = self
class Wrapper:
def __set_name__(self, owner, name):
# hack to bypass __get__ method of the DefaultSchema.
# i don't want to think of or change DefaultSchema internals for now
# no need to really call __get__ again
owner.schema = owner.schema
if isinstance(owner.schema, AutoSchema):
owner.schema._set_apiview_method_data(method.__name__.upper(),
label=_self.label,
description=_self.description or method.__doc__,
query_params_serializer=_self.query_params_serializer,
data_serializer=_self.data_serializer,
responses=_self.responses)
def __call__(self, request, *args, **kwargs):
return method(self, request, *args, **kwargs)
return Wrapper()

View File

@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.compat import uritemplate
from rest_framework.request import Request
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator, generate_schema
from . import views
@ -570,3 +570,37 @@ class TestGenerator(TestCase):
assert schema['info']['title'] == 'My title'
assert schema['info']['version'] == '1.2.3'
assert schema['info']['description'] == 'My description'
def test_generate_schema_with_decorator(self):
class MySerializer(serializers.Serializer):
username = serializers.CharField()
class MyView(views.APIView):
@generate_schema(label='My Label',
description='My Description',
data=MySerializer())
def post(self, request):
pass
patterns = [
url(r'^example/?$', MyView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert schema['paths']['/example/']['post']['operationId'] == 'My Label'
assert schema['paths']['/example/']['post']['description'] == 'My Description'
assert schema['paths']['/example/']['post']['requestBody'] == {
'content': {
'application/json': {
'schema': {
'properties': {
'username': {'type': 'string'}
},
'required': ['username']
}
}
}
}