This commit is contained in:
Asif Saif Uddin (Auvi) 2019-11-07 18:12:10 +06:00
commit 49b774e2ad
9 changed files with 235 additions and 74 deletions

View File

@ -94,5 +94,5 @@ As with any other `TemplateResponse`, this method is called to render the serial
You won't typically need to call `.render()` yourself, as it's handled by Django's standard response cycle. You won't typically need to call `.render()` yourself, as it's handled by Django's standard response cycle.
[cite]: https://docs.djangoproject.com/en/stable/stable/template-response/ [cite]: https://docs.djangoproject.com/en/stable/ref/template-response/
[statuscodes]: status-codes.md [statuscodes]: status-codes.md

View File

@ -73,7 +73,7 @@ The `get_schema_view()` helper takes the following keyword arguments:
* `title`: May be used to provide a descriptive title for the schema definition. * `title`: May be used to provide a descriptive title for the schema definition.
* `description`: Longer descriptive text. * `description`: Longer descriptive text.
* `version`: The version of the API. Defaults to `0.1.0`. * `version`: The version of the API.
* `url`: May be used to pass a canonical base URL for the schema. * `url`: May be used to pass a canonical base URL for the schema.
schema_view = get_schema_view( schema_view = get_schema_view(

View File

@ -101,7 +101,7 @@ Default: `'rest_framework.negotiation.DefaultContentNegotiation'`
A view inspector class that will be used for schema generation. A view inspector class that will be used for schema generation.
Default: `'rest_framework.schemas.AutoSchema'` Default: `'rest_framework.schemas.openapi.AutoSchema'`
--- ---

View File

@ -1,26 +1,18 @@
import re
import warnings import warnings
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from urllib import parse from urllib import parse
from django.db import models from django.db import models
from django.utils.encoding import force_str, smart_text from django.utils.encoding import force_str
from rest_framework import exceptions, serializers from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .generators import BaseSchemaGenerator from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
# Used in _get_description_section()
# TODO: ???: move up to base.
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
# Generator #
def common_path(paths): def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths] split_paths = [path.strip('/').split('/') for path in paths]
@ -397,44 +389,6 @@ class AutoSchema(ViewInspector):
description=description description=description
) )
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method): def get_path_fields(self, path, method):
""" """
Return a list of `coreapi.Field` instances corresponding to any Return a list of `coreapi.Field` instances corresponding to any

View File

@ -151,7 +151,7 @@ class BaseSchemaGenerator(object):
# Set by 'SCHEMA_COERCE_PATH_PK'. # Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=''): def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
if url and not url.endswith('/'): if url and not url.endswith('/'):
url += '/' url += '/'

View File

@ -3,9 +3,13 @@ inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview. See schemas.__init__.py for package overview.
""" """
import re
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
from django.utils.encoding import smart_text
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
class ViewInspector: class ViewInspector:
@ -15,6 +19,9 @@ class ViewInspector:
Provide subclass for per-view schema generation Provide subclass for per-view schema generation
""" """
# Used in _get_description_section()
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def __init__(self): def __init__(self):
self.instance_schemas = WeakKeyDictionary() self.instance_schemas = WeakKeyDictionary()
@ -62,6 +69,45 @@ class ViewInspector:
def view(self): def view(self):
self._view = None self._view = None
def get_description(self, path, method):
"""
Determine a path description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()),
view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if self.header_regex.match(line):
current_section, separator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
class DefaultSchema(ViewInspector): class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""

View File

@ -1,4 +1,5 @@
import warnings import warnings
from operator import attrgetter
from urllib.parse import urljoin from urllib.parse import urljoin
from django.core.validators import ( from django.core.validators import (
@ -8,7 +9,7 @@ from django.core.validators import (
from django.db import models from django.db import models
from django.utils.encoding import force_str from django.utils.encoding import force_str
from rest_framework import exceptions, serializers from rest_framework import exceptions, renderers, serializers
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty from rest_framework.fields import _UnvalidatedField, empty
@ -16,15 +17,14 @@ from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view from .utils import get_pk_description, is_list_view
# Generator
class SchemaGenerator(BaseSchemaGenerator): class SchemaGenerator(BaseSchemaGenerator):
def get_info(self): def get_info(self):
# Title and version are required by openapi specification 3.x
info = { info = {
'title': self.title, 'title': self.title or '',
'version': self.version, 'version': self.version or ''
} }
if self.description is not None: if self.description is not None:
@ -78,7 +78,9 @@ class SchemaGenerator(BaseSchemaGenerator):
class AutoSchema(ViewInspector): class AutoSchema(ViewInspector):
content_types = ['application/json'] request_media_types = []
response_media_types = []
method_mapping = { method_mapping = {
'get': 'Retrieve', 'get': 'Retrieve',
'post': 'Create', 'post': 'Create',
@ -91,6 +93,7 @@ class AutoSchema(ViewInspector):
operation = {} operation = {}
operation['operationId'] = self._get_operation_id(path, method) operation['operationId'] = self._get_operation_id(path, method)
operation['description'] = self.get_description(path, method)
parameters = [] parameters = []
parameters += self._get_path_parameters(path, method) parameters += self._get_path_parameters(path, method)
@ -265,9 +268,13 @@ class AutoSchema(ViewInspector):
'items': {}, 'items': {},
} }
if not isinstance(field.child, _UnvalidatedField): if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = { map_field = self._map_field(field.child)
"type": self._map_field(field.child).get('type') items = {
"type": map_field.get('type')
} }
if 'format' in map_field:
items['format'] = map_field.get('format')
mapping['items'] = items
return mapping return mapping
# DateField and DateTimeField type is string # DateField and DateTimeField type is string
@ -337,8 +344,17 @@ class AutoSchema(ViewInspector):
'type': 'integer' 'type': 'integer'
} }
self._map_min_max(field, content) self._map_min_max(field, content)
# 2147483647 is max for int32_size, so we use int64 for format
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
content['format'] = 'int64'
return content return content
if isinstance(field, serializers.FileField):
return {
'type': 'string',
'format': 'binary'
}
# Simplest cases, default to 'string' type: # Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = { FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean', serializers.BooleanField: 'boolean',
@ -434,9 +450,20 @@ class AutoSchema(ViewInspector):
pagination_class = getattr(self.view, 'pagination_class', None) pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class: if pagination_class:
return pagination_class() return pagination_class()
return None return None
def map_parsers(self, path, method):
return list(map(attrgetter('media_type'), self.view.parser_classes))
def map_renderers(self, path, method):
media_types = []
for renderer in self.view.renderer_classes:
# BrowsableAPIRenderer not relevant to OpenAPI spec
if renderer == renderers.BrowsableAPIRenderer:
continue
media_types.append(renderer.media_type)
return media_types
def _get_serializer(self, method, path): def _get_serializer(self, method, path):
view = self.view view = self.view
@ -456,6 +483,8 @@ class AutoSchema(ViewInspector):
if method not in ('PUT', 'PATCH', 'POST'): if method not in ('PUT', 'PATCH', 'POST'):
return {} return {}
self.request_media_types = self.map_parsers(path, method)
serializer = self._get_serializer(path, method) serializer = self._get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
@ -473,7 +502,7 @@ class AutoSchema(ViewInspector):
return { return {
'content': { 'content': {
ct: {'schema': content} ct: {'schema': content}
for ct in self.content_types for ct in self.request_media_types
} }
} }
@ -486,6 +515,8 @@ class AutoSchema(ViewInspector):
} }
} }
self.response_media_types = self.map_renderers(path, method)
item_schema = {} item_schema = {}
serializer = self._get_serializer(path, method) serializer = self._get_serializer(path, method)
@ -513,7 +544,7 @@ class AutoSchema(ViewInspector):
'200': { '200': {
'content': { 'content': {
ct: {'schema': response_schema} ct: {'schema': response_schema}
for ct in self.content_types for ct in self.response_media_types
}, },
# description is a mandatory property, # description is a mandatory property,
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject

View File

@ -5,6 +5,8 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import filters, generics, pagination, routers, serializers from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.parsers import JSONParser, MultiPartParser
from rest_framework.renderers import JSONRenderer
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
@ -48,6 +50,10 @@ class TestFieldMapping(TestCase):
(serializers.ListField(child=serializers.BooleanField()), {'items': {'type': 'boolean'}, 'type': 'array'}), (serializers.ListField(child=serializers.BooleanField()), {'items': {'type': 'boolean'}, 'type': 'array'}),
(serializers.ListField(child=serializers.FloatField()), {'items': {'type': 'number'}, 'type': 'array'}), (serializers.ListField(child=serializers.FloatField()), {'items': {'type': 'number'}, 'type': 'array'}),
(serializers.ListField(child=serializers.CharField()), {'items': {'type': 'string'}, 'type': 'array'}), (serializers.ListField(child=serializers.CharField()), {'items': {'type': 'string'}, 'type': 'array'}),
(serializers.ListField(child=serializers.IntegerField(max_value=4294967295)),
{'items': {'type': 'integer', 'format': 'int64'}, 'type': 'array'}),
(serializers.IntegerField(min_value=2147483648),
{'type': 'integer', 'minimum': 2147483648, 'format': 'int64'}),
] ]
for field, mapping in cases: for field, mapping in cases:
with self.subTest(field=field): with self.subTest(field=field):
@ -71,7 +77,7 @@ class TestOperationIntrospection(TestCase):
method = 'GET' method = 'GET'
view = create_view( view = create_view(
views.ExampleListView, views.DocStringExampleListView,
method, method,
create_request(path) create_request(path)
) )
@ -80,7 +86,8 @@ class TestOperationIntrospection(TestCase):
operation = inspector.get_operation(path, method) operation = inspector.get_operation(path, method)
assert operation == { assert operation == {
'operationId': 'listExamples', 'operationId': 'listDocStringExamples',
'description': 'A description of my GET operation.',
'parameters': [], 'parameters': [],
'responses': { 'responses': {
'200': { '200': {
@ -102,15 +109,18 @@ class TestOperationIntrospection(TestCase):
method = 'GET' method = 'GET'
view = create_view( view = create_view(
views.ExampleDetailView, views.DocStringExampleDetailView,
method, method,
create_request(path) create_request(path)
) )
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
parameters = inspector._get_path_parameters(path, method) operation = inspector.get_operation(path, method)
assert parameters == [{ assert operation == {
'operationId': 'RetrieveDocStringExampleDetail',
'description': 'A description of my GET operation.',
'parameters': [{
'description': '', 'description': '',
'in': 'path', 'in': 'path',
'name': 'id', 'name': 'id',
@ -118,7 +128,19 @@ class TestOperationIntrospection(TestCase):
'schema': { 'schema': {
'type': 'string', 'type': 'string',
}, },
}] }],
'responses': {
'200': {
'description': '',
'content': {
'application/json': {
'schema': {
},
},
},
},
},
}
def test_request_body(self): def test_request_body(self):
path = '/' path = '/'
@ -364,6 +386,77 @@ class TestOperationIntrospection(TestCase):
}, },
} }
def test_parser_mapping(self):
"""Test that view's parsers are mapped to OA media types"""
path = '/{id}/'
method = 'POST'
class View(generics.CreateAPIView):
serializer_class = views.ExampleSerializer
parser_classes = [JSONParser, MultiPartParser]
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
assert len(request_body['content'].keys()) == 2
assert 'multipart/form-data' in request_body['content']
assert 'application/json' in request_body['content']
def test_renderer_mapping(self):
"""Test that view's renderers are mapped to OA media types"""
path = '/{id}/'
method = 'GET'
class View(generics.CreateAPIView):
serializer_class = views.ExampleSerializer
renderer_classes = [JSONRenderer]
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
# TODO this should be changed once the multiple response
# schema support is there
success_response = responses['200']
assert len(success_response['content'].keys()) == 1
assert 'application/json' in success_response['content']
def test_serializer_filefield(self):
path = '/{id}/'
method = 'POST'
class ItemSerializer(serializers.Serializer):
attachment = serializers.FileField()
class View(generics.CreateAPIView):
serializer_class = ItemSerializer
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
mp_media = request_body['content']['multipart/form-data']
attachment = mp_media['schema']['properties']['attachment']
assert attachment['format'] == 'binary'
def test_retrieve_response_body_generation(self): def test_retrieve_response_body_generation(self):
""" """
Test that a list of properties is returned for retrieve item views. Test that a list of properties is returned for retrieve item views.
@ -611,3 +704,16 @@ class TestGenerator(TestCase):
assert schema['info']['title'] == 'My title' assert schema['info']['title'] == 'My title'
assert schema['info']['version'] == '1.2.3' assert schema['info']['version'] == '1.2.3'
assert schema['info']['description'] == 'My description' assert schema['info']['description'] == 'My description'
def test_schema_information_empty(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert schema['info']['title'] == ''
assert schema['info']['version'] == ''

View File

@ -29,6 +29,30 @@ class ExampleDetailView(APIView):
pass pass
class DocStringExampleListView(APIView):
"""
get: A description of my GET operation.
post: A description of my POST operation.
"""
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class DocStringExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
"""
A description of my GET operation.
"""
pass
# Generics. # Generics.
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
date = serializers.DateField() date = serializers.DateField()