mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-29 17:39:48 +03:00
enable schema generation for APIView via decorator
This commit is contained in:
parent
89ac0a1c7e
commit
beb8c665bd
|
@ -1,4 +1,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional, Dict, Any, Type
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
from django.core.validators import (
|
from django.core.validators import (
|
||||||
|
@ -7,6 +10,7 @@ from django.core.validators import (
|
||||||
)
|
)
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
from rest_framework.serializers import Serializer
|
||||||
|
|
||||||
from rest_framework import exceptions, serializers
|
from rest_framework import exceptions, serializers
|
||||||
from rest_framework.compat import uritemplate
|
from rest_framework.compat import uritemplate
|
||||||
|
@ -16,6 +20,7 @@ from .generators import BaseSchemaGenerator
|
||||||
from .inspectors import ViewInspector
|
from .inspectors import ViewInspector
|
||||||
from .utils import get_pk_description, is_list_view
|
from .utils import get_pk_description, is_list_view
|
||||||
|
|
||||||
|
|
||||||
# Generator
|
# Generator
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,11 +78,20 @@ class SchemaGenerator(BaseSchemaGenerator):
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
# View Inspectors
|
# View Inspectors
|
||||||
|
|
||||||
|
# dictionary of
|
||||||
|
# label: str # endpoint name
|
||||||
|
# description: str # docs
|
||||||
|
# query_params: Serializer # for request query_params
|
||||||
|
# data: Serializer # for request data
|
||||||
|
# files: Serializer # for request files
|
||||||
|
# responses: Dict[int, Serializer] # responses
|
||||||
|
_MethodData = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class AutoSchema(ViewInspector):
|
class AutoSchema(ViewInspector):
|
||||||
|
|
||||||
content_types = ['application/json']
|
content_types = ['application/json']
|
||||||
method_mapping = {
|
method_mapping = {
|
||||||
'get': 'Retrieve',
|
'get': 'Retrieve',
|
||||||
|
@ -87,10 +101,43 @@ class AutoSchema(ViewInspector):
|
||||||
'delete': 'Destroy',
|
'delete': 'Destroy',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.apiview_methods = defaultdict(dict) # type: Dict[str, _MethodData]
|
||||||
|
|
||||||
|
def _set_apiview_method_data(self, method: str,
|
||||||
|
*,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
query_params_serializer: Optional[Serializer] = None,
|
||||||
|
data_serializer: Optional[Serializer] = None,
|
||||||
|
responses: Optional[Dict[int, Serializer]] = None) -> None:
|
||||||
|
self.apiview_methods[method] = {
|
||||||
|
'label': label,
|
||||||
|
'description': description,
|
||||||
|
'query_params_serializer': query_params_serializer,
|
||||||
|
'data_serializer': data_serializer,
|
||||||
|
'responses': responses
|
||||||
|
}
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _replace_get_serializer(self, serializer: Optional[Serializer]) -> None:
|
||||||
|
if serializer is not None:
|
||||||
|
old_get_serializer = self._get_serializer
|
||||||
|
|
||||||
|
def always_returns_serializer(method, path):
|
||||||
|
return serializer
|
||||||
|
|
||||||
|
self._get_serializer = always_returns_serializer
|
||||||
|
yield
|
||||||
|
if serializer is not None:
|
||||||
|
self._get_serializer = old_get_serializer
|
||||||
|
|
||||||
def get_operation(self, path, method):
|
def get_operation(self, path, method):
|
||||||
operation = {}
|
operation = {}
|
||||||
|
|
||||||
operation['operationId'] = self._get_operation_id(path, method)
|
operation['operationId'] = self.apiview_methods[method].get('label') or self._get_operation_id(path, method)
|
||||||
|
operation['description'] = self.apiview_methods[method].get('description') or ''
|
||||||
|
|
||||||
parameters = []
|
parameters = []
|
||||||
parameters += self._get_path_parameters(path, method)
|
parameters += self._get_path_parameters(path, method)
|
||||||
|
@ -98,10 +145,14 @@ class AutoSchema(ViewInspector):
|
||||||
parameters += self._get_filter_parameters(path, method)
|
parameters += self._get_filter_parameters(path, method)
|
||||||
operation['parameters'] = parameters
|
operation['parameters'] = parameters
|
||||||
|
|
||||||
request_body = self._get_request_body(path, method)
|
with self._replace_get_serializer(self.apiview_methods[method].get('data_serializer')):
|
||||||
|
request_body = self._get_request_body(path, method)
|
||||||
|
|
||||||
if request_body:
|
if request_body:
|
||||||
operation['requestBody'] = request_body
|
operation['requestBody'] = request_body
|
||||||
operation['responses'] = self._get_responses(path, method)
|
|
||||||
|
with self._replace_get_serializer((self.apiview_methods[method].get('responses') or {}).get(200)):
|
||||||
|
operation['responses'] = self._get_responses(path, method)
|
||||||
|
|
||||||
return operation
|
return operation
|
||||||
|
|
||||||
|
@ -520,3 +571,39 @@ class AutoSchema(ViewInspector):
|
||||||
'description': ""
|
'description': ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class generate_schema:
|
||||||
|
def __init__(self, *,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
query_params: Optional[Type[Serializer]] = None,
|
||||||
|
data: Optional[Type[Serializer]] = None,
|
||||||
|
responses: Optional[Dict[int, Type[Serializer]]] = None):
|
||||||
|
self.label = label
|
||||||
|
self.description = description
|
||||||
|
self.query_params_serializer = query_params
|
||||||
|
self.data_serializer = data
|
||||||
|
self.responses = responses
|
||||||
|
|
||||||
|
def __call__(self, method):
|
||||||
|
_self = self
|
||||||
|
|
||||||
|
class Wrapper:
|
||||||
|
def __set_name__(self, owner, name):
|
||||||
|
# hack to bypass __get__ method of the DefaultSchema.
|
||||||
|
# i don't want to think of or change DefaultSchema internals for now
|
||||||
|
# no need to really call __get__ again
|
||||||
|
owner.schema = owner.schema
|
||||||
|
if isinstance(owner.schema, AutoSchema):
|
||||||
|
owner.schema._set_apiview_method_data(method.__name__.upper(),
|
||||||
|
label=_self.label,
|
||||||
|
description=_self.description or method.__doc__,
|
||||||
|
query_params_serializer=_self.query_params_serializer,
|
||||||
|
data_serializer=_self.data_serializer,
|
||||||
|
responses=_self.responses)
|
||||||
|
|
||||||
|
def __call__(self, request, *args, **kwargs):
|
||||||
|
return method(self, request, *args, **kwargs)
|
||||||
|
|
||||||
|
return Wrapper()
|
||||||
|
|
|
@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
|
||||||
from rest_framework import filters, generics, pagination, routers, serializers
|
from rest_framework import filters, generics, pagination, routers, serializers
|
||||||
from rest_framework.compat import uritemplate
|
from rest_framework.compat import uritemplate
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
|
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator, generate_schema
|
||||||
|
|
||||||
from . import views
|
from . import views
|
||||||
|
|
||||||
|
@ -570,3 +570,37 @@ class TestGenerator(TestCase):
|
||||||
assert schema['info']['title'] == 'My title'
|
assert schema['info']['title'] == 'My title'
|
||||||
assert schema['info']['version'] == '1.2.3'
|
assert schema['info']['version'] == '1.2.3'
|
||||||
assert schema['info']['description'] == 'My description'
|
assert schema['info']['description'] == 'My description'
|
||||||
|
|
||||||
|
def test_generate_schema_with_decorator(self):
|
||||||
|
class MySerializer(serializers.Serializer):
|
||||||
|
username = serializers.CharField()
|
||||||
|
|
||||||
|
class MyView(views.APIView):
|
||||||
|
@generate_schema(label='My Label',
|
||||||
|
description='My Description',
|
||||||
|
data=MySerializer())
|
||||||
|
def post(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
url(r'^example/?$', MyView.as_view()),
|
||||||
|
]
|
||||||
|
generator = SchemaGenerator(patterns=patterns)
|
||||||
|
|
||||||
|
request = create_request('/')
|
||||||
|
schema = generator.get_schema(request=request)
|
||||||
|
|
||||||
|
assert schema['paths']['/example/']['post']['operationId'] == 'My Label'
|
||||||
|
assert schema['paths']['/example/']['post']['description'] == 'My Description'
|
||||||
|
assert schema['paths']['/example/']['post']['requestBody'] == {
|
||||||
|
'content': {
|
||||||
|
'application/json': {
|
||||||
|
'schema': {
|
||||||
|
'properties': {
|
||||||
|
'username': {'type': 'string'}
|
||||||
|
},
|
||||||
|
'required': ['username']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user