mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-28 17:09:59 +03:00
enable schema generation for APIView via decorator
This commit is contained in:
parent
89ac0a1c7e
commit
beb8c665bd
|
@ -1,4 +1,7 @@
|
|||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Any, Type
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from django.core.validators import (
|
||||
|
@ -7,6 +10,7 @@ from django.core.validators import (
|
|||
)
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
from rest_framework.serializers import Serializer
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import uritemplate
|
||||
|
@ -16,6 +20,7 @@ from .generators import BaseSchemaGenerator
|
|||
from .inspectors import ViewInspector
|
||||
from .utils import get_pk_description, is_list_view
|
||||
|
||||
|
||||
# Generator
|
||||
|
||||
|
||||
|
@ -73,11 +78,20 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
|
||||
return schema
|
||||
|
||||
|
||||
# View Inspectors
|
||||
|
||||
# dictionary of
|
||||
# label: str # endpoint name
|
||||
# description: str # docs
|
||||
# query_params: Serializer # for request query_params
|
||||
# data: Serializer # for request data
|
||||
# files: Serializer # for request files
|
||||
# responses: Dict[int, Serializer] # responses
|
||||
_MethodData = Dict[str, Any]
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
|
||||
content_types = ['application/json']
|
||||
method_mapping = {
|
||||
'get': 'Retrieve',
|
||||
|
@ -87,10 +101,43 @@ class AutoSchema(ViewInspector):
|
|||
'delete': 'Destroy',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.apiview_methods = defaultdict(dict) # type: Dict[str, _MethodData]
|
||||
|
||||
def _set_apiview_method_data(self, method: str,
|
||||
*,
|
||||
label: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
query_params_serializer: Optional[Serializer] = None,
|
||||
data_serializer: Optional[Serializer] = None,
|
||||
responses: Optional[Dict[int, Serializer]] = None) -> None:
|
||||
self.apiview_methods[method] = {
|
||||
'label': label,
|
||||
'description': description,
|
||||
'query_params_serializer': query_params_serializer,
|
||||
'data_serializer': data_serializer,
|
||||
'responses': responses
|
||||
}
|
||||
|
||||
@contextmanager
|
||||
def _replace_get_serializer(self, serializer: Optional[Serializer]) -> None:
|
||||
if serializer is not None:
|
||||
old_get_serializer = self._get_serializer
|
||||
|
||||
def always_returns_serializer(method, path):
|
||||
return serializer
|
||||
|
||||
self._get_serializer = always_returns_serializer
|
||||
yield
|
||||
if serializer is not None:
|
||||
self._get_serializer = old_get_serializer
|
||||
|
||||
def get_operation(self, path, method):
|
||||
operation = {}
|
||||
|
||||
operation['operationId'] = self._get_operation_id(path, method)
|
||||
operation['operationId'] = self.apiview_methods[method].get('label') or self._get_operation_id(path, method)
|
||||
operation['description'] = self.apiview_methods[method].get('description') or ''
|
||||
|
||||
parameters = []
|
||||
parameters += self._get_path_parameters(path, method)
|
||||
|
@ -98,10 +145,14 @@ class AutoSchema(ViewInspector):
|
|||
parameters += self._get_filter_parameters(path, method)
|
||||
operation['parameters'] = parameters
|
||||
|
||||
request_body = self._get_request_body(path, method)
|
||||
with self._replace_get_serializer(self.apiview_methods[method].get('data_serializer')):
|
||||
request_body = self._get_request_body(path, method)
|
||||
|
||||
if request_body:
|
||||
operation['requestBody'] = request_body
|
||||
operation['responses'] = self._get_responses(path, method)
|
||||
|
||||
with self._replace_get_serializer((self.apiview_methods[method].get('responses') or {}).get(200)):
|
||||
operation['responses'] = self._get_responses(path, method)
|
||||
|
||||
return operation
|
||||
|
||||
|
@ -520,3 +571,39 @@ class AutoSchema(ViewInspector):
|
|||
'description': ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class generate_schema:
|
||||
def __init__(self, *,
|
||||
label: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
query_params: Optional[Type[Serializer]] = None,
|
||||
data: Optional[Type[Serializer]] = None,
|
||||
responses: Optional[Dict[int, Type[Serializer]]] = None):
|
||||
self.label = label
|
||||
self.description = description
|
||||
self.query_params_serializer = query_params
|
||||
self.data_serializer = data
|
||||
self.responses = responses
|
||||
|
||||
def __call__(self, method):
|
||||
_self = self
|
||||
|
||||
class Wrapper:
|
||||
def __set_name__(self, owner, name):
|
||||
# hack to bypass __get__ method of the DefaultSchema.
|
||||
# i don't want to think of or change DefaultSchema internals for now
|
||||
# no need to really call __get__ again
|
||||
owner.schema = owner.schema
|
||||
if isinstance(owner.schema, AutoSchema):
|
||||
owner.schema._set_apiview_method_data(method.__name__.upper(),
|
||||
label=_self.label,
|
||||
description=_self.description or method.__doc__,
|
||||
query_params_serializer=_self.query_params_serializer,
|
||||
data_serializer=_self.data_serializer,
|
||||
responses=_self.responses)
|
||||
|
||||
def __call__(self, request, *args, **kwargs):
|
||||
return method(self, request, *args, **kwargs)
|
||||
|
||||
return Wrapper()
|
||||
|
|
|
@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework import filters, generics, pagination, routers, serializers
|
||||
from rest_framework.compat import uritemplate
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
|
||||
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator, generate_schema
|
||||
|
||||
from . import views
|
||||
|
||||
|
@ -570,3 +570,37 @@ class TestGenerator(TestCase):
|
|||
assert schema['info']['title'] == 'My title'
|
||||
assert schema['info']['version'] == '1.2.3'
|
||||
assert schema['info']['description'] == 'My description'
|
||||
|
||||
def test_generate_schema_with_decorator(self):
|
||||
class MySerializer(serializers.Serializer):
|
||||
username = serializers.CharField()
|
||||
|
||||
class MyView(views.APIView):
|
||||
@generate_schema(label='My Label',
|
||||
description='My Description',
|
||||
data=MySerializer())
|
||||
def post(self, request):
|
||||
pass
|
||||
|
||||
patterns = [
|
||||
url(r'^example/?$', MyView.as_view()),
|
||||
]
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
|
||||
assert schema['paths']['/example/']['post']['operationId'] == 'My Label'
|
||||
assert schema['paths']['/example/']['post']['description'] == 'My Description'
|
||||
assert schema['paths']['/example/']['post']['requestBody'] == {
|
||||
'content': {
|
||||
'application/json': {
|
||||
'schema': {
|
||||
'properties': {
|
||||
'username': {'type': 'string'}
|
||||
},
|
||||
'required': ['username']
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user