From beb8c665bd2f4a0c0df57ae8422176e39a06f974 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Thu, 12 Sep 2019 17:59:50 +0300 Subject: [PATCH] enable schema generation for APIView via decorator --- rest_framework/schemas/openapi.py | 95 +++++++++++++++++++++++++++++-- tests/schemas/test_openapi.py | 36 +++++++++++- 2 files changed, 126 insertions(+), 5 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index ac846bf80..8a715a75e 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -1,4 +1,7 @@ import warnings +from collections import defaultdict +from contextlib import contextmanager +from typing import Optional, Dict, Any, Type from urllib.parse import urljoin from django.core.validators import ( @@ -7,6 +10,7 @@ from django.core.validators import ( ) from django.db import models from django.utils.encoding import force_str +from rest_framework.serializers import Serializer from rest_framework import exceptions, serializers from rest_framework.compat import uritemplate @@ -16,6 +20,7 @@ from .generators import BaseSchemaGenerator from .inspectors import ViewInspector from .utils import get_pk_description, is_list_view + # Generator @@ -73,11 +78,20 @@ class SchemaGenerator(BaseSchemaGenerator): return schema + # View Inspectors +# dictionary of +# label: str # endpoint name +# description: str # docs +# query_params: Serializer # for request query_params +# data: Serializer # for request data +# files: Serializer # for request files +# responses: Dict[int, Serializer] # responses +_MethodData = Dict[str, Any] + class AutoSchema(ViewInspector): - content_types = ['application/json'] method_mapping = { 'get': 'Retrieve', @@ -87,10 +101,43 @@ class AutoSchema(ViewInspector): 'delete': 'Destroy', } + def __init__(self): + super().__init__() + self.apiview_methods = defaultdict(dict) # type: Dict[str, _MethodData] + + def _set_apiview_method_data(self, method: str, + *, + label: Optional[str] = None, + description: Optional[str] = None, + query_params_serializer: Optional[Serializer] = None, + data_serializer: Optional[Serializer] = None, + responses: Optional[Dict[int, Serializer]] = None) -> None: + self.apiview_methods[method] = { + 'label': label, + 'description': description, + 'query_params_serializer': query_params_serializer, + 'data_serializer': data_serializer, + 'responses': responses + } + + @contextmanager + def _replace_get_serializer(self, serializer: Optional[Serializer]) -> None: + if serializer is not None: + old_get_serializer = self._get_serializer + + def always_returns_serializer(method, path): + return serializer + + self._get_serializer = always_returns_serializer + yield + if serializer is not None: + self._get_serializer = old_get_serializer + def get_operation(self, path, method): operation = {} - operation['operationId'] = self._get_operation_id(path, method) + operation['operationId'] = self.apiview_methods[method].get('label') or self._get_operation_id(path, method) + operation['description'] = self.apiview_methods[method].get('description') or '' parameters = [] parameters += self._get_path_parameters(path, method) @@ -98,10 +145,14 @@ class AutoSchema(ViewInspector): parameters += self._get_filter_parameters(path, method) operation['parameters'] = parameters - request_body = self._get_request_body(path, method) + with self._replace_get_serializer(self.apiview_methods[method].get('data_serializer')): + request_body = self._get_request_body(path, method) + if request_body: operation['requestBody'] = request_body - operation['responses'] = self._get_responses(path, method) + + with self._replace_get_serializer((self.apiview_methods[method].get('responses') or {}).get(200)): + operation['responses'] = self._get_responses(path, method) return operation @@ -520,3 +571,39 @@ class AutoSchema(ViewInspector): 'description': "" } } + + +class generate_schema: + def __init__(self, *, + label: Optional[str] = None, + description: Optional[str] = None, + query_params: Optional[Type[Serializer]] = None, + data: Optional[Type[Serializer]] = None, + responses: Optional[Dict[int, Type[Serializer]]] = None): + self.label = label + self.description = description + self.query_params_serializer = query_params + self.data_serializer = data + self.responses = responses + + def __call__(self, method): + _self = self + + class Wrapper: + def __set_name__(self, owner, name): + # hack to bypass __get__ method of the DefaultSchema. + # i don't want to think of or change DefaultSchema internals for now + # no need to really call __get__ again + owner.schema = owner.schema + if isinstance(owner.schema, AutoSchema): + owner.schema._set_apiview_method_data(method.__name__.upper(), + label=_self.label, + description=_self.description or method.__doc__, + query_params_serializer=_self.query_params_serializer, + data_serializer=_self.data_serializer, + responses=_self.responses) + + def __call__(self, request, *args, **kwargs): + return method(self, request, *args, **kwargs) + + return Wrapper() diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index d9375585b..fb5cac71a 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import filters, generics, pagination, routers, serializers from rest_framework.compat import uritemplate from rest_framework.request import Request -from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator +from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator, generate_schema from . import views @@ -570,3 +570,37 @@ class TestGenerator(TestCase): assert schema['info']['title'] == 'My title' assert schema['info']['version'] == '1.2.3' assert schema['info']['description'] == 'My description' + + def test_generate_schema_with_decorator(self): + class MySerializer(serializers.Serializer): + username = serializers.CharField() + + class MyView(views.APIView): + @generate_schema(label='My Label', + description='My Description', + data=MySerializer()) + def post(self, request): + pass + + patterns = [ + url(r'^example/?$', MyView.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + assert schema['paths']['/example/']['post']['operationId'] == 'My Label' + assert schema['paths']['/example/']['post']['description'] == 'My Description' + assert schema['paths']['/example/']['post']['requestBody'] == { + 'content': { + 'application/json': { + 'schema': { + 'properties': { + 'username': {'type': 'string'} + }, + 'required': ['username'] + } + } + } + }