This commit is contained in:
Asif Saif Uddin (Auvi) 2019-11-07 18:12:10 +06:00
commit 49b774e2ad
9 changed files with 235 additions and 74 deletions

View File

@ -94,5 +94,5 @@ As with any other `TemplateResponse`, this method is called to render the serial
You won't typically need to call `.render()` yourself, as it's handled by Django's standard response cycle.
[cite]: https://docs.djangoproject.com/en/stable/stable/template-response/
[cite]: https://docs.djangoproject.com/en/stable/ref/template-response/
[statuscodes]: status-codes.md

View File

@ -73,7 +73,7 @@ The `get_schema_view()` helper takes the following keyword arguments:
* `title`: May be used to provide a descriptive title for the schema definition.
* `description`: Longer descriptive text.
* `version`: The version of the API. Defaults to `0.1.0`.
* `version`: The version of the API.
* `url`: May be used to pass a canonical base URL for the schema.
schema_view = get_schema_view(

View File

@ -101,7 +101,7 @@ Default: `'rest_framework.negotiation.DefaultContentNegotiation'`
A view inspector class that will be used for schema generation.
Default: `'rest_framework.schemas.AutoSchema'`
Default: `'rest_framework.schemas.openapi.AutoSchema'`
---

View File

@ -1,26 +1,18 @@
import re
import warnings
from collections import Counter, OrderedDict
from urllib import parse
from django.db import models
from django.utils.encoding import force_str, smart_text
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
# Used in _get_description_section()
# TODO: ???: move up to base.
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
# Generator #
def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
@ -397,44 +389,6 @@ class AutoSchema(ViewInspector):
description=description
)
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any

View File

@ -151,7 +151,7 @@ class BaseSchemaGenerator(object):
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=''):
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
if url and not url.endswith('/'):
url += '/'

View File

@ -3,9 +3,13 @@ inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview.
"""
import re
from weakref import WeakKeyDictionary
from django.utils.encoding import smart_text
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
class ViewInspector:
@ -15,6 +19,9 @@ class ViewInspector:
Provide subclass for per-view schema generation
"""
# Used in _get_description_section()
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def __init__(self):
self.instance_schemas = WeakKeyDictionary()
@ -62,6 +69,45 @@ class ViewInspector:
def view(self):
self._view = None
def get_description(self, path, method):
"""
Determine a path description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()),
view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if self.header_regex.match(line):
current_section, separator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""

View File

@ -1,4 +1,5 @@
import warnings
from operator import attrgetter
from urllib.parse import urljoin
from django.core.validators import (
@ -8,7 +9,7 @@ from django.core.validators import (
from django.db import models
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework import exceptions, renderers, serializers
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
@ -16,15 +17,14 @@ from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
# Generator
class SchemaGenerator(BaseSchemaGenerator):
def get_info(self):
# Title and version are required by openapi specification 3.x
info = {
'title': self.title,
'version': self.version,
'title': self.title or '',
'version': self.version or ''
}
if self.description is not None:
@ -78,7 +78,9 @@ class SchemaGenerator(BaseSchemaGenerator):
class AutoSchema(ViewInspector):
content_types = ['application/json']
request_media_types = []
response_media_types = []
method_mapping = {
'get': 'Retrieve',
'post': 'Create',
@ -91,6 +93,7 @@ class AutoSchema(ViewInspector):
operation = {}
operation['operationId'] = self._get_operation_id(path, method)
operation['description'] = self.get_description(path, method)
parameters = []
parameters += self._get_path_parameters(path, method)
@ -265,9 +268,13 @@ class AutoSchema(ViewInspector):
'items': {},
}
if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = {
"type": self._map_field(field.child).get('type')
map_field = self._map_field(field.child)
items = {
"type": map_field.get('type')
}
if 'format' in map_field:
items['format'] = map_field.get('format')
mapping['items'] = items
return mapping
# DateField and DateTimeField type is string
@ -337,8 +344,17 @@ class AutoSchema(ViewInspector):
'type': 'integer'
}
self._map_min_max(field, content)
# 2147483647 is max for int32_size, so we use int64 for format
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
content['format'] = 'int64'
return content
if isinstance(field, serializers.FileField):
return {
'type': 'string',
'format': 'binary'
}
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
@ -434,9 +450,20 @@ class AutoSchema(ViewInspector):
pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class:
return pagination_class()
return None
def map_parsers(self, path, method):
return list(map(attrgetter('media_type'), self.view.parser_classes))
def map_renderers(self, path, method):
media_types = []
for renderer in self.view.renderer_classes:
# BrowsableAPIRenderer not relevant to OpenAPI spec
if renderer == renderers.BrowsableAPIRenderer:
continue
media_types.append(renderer.media_type)
return media_types
def _get_serializer(self, method, path):
view = self.view
@ -456,6 +483,8 @@ class AutoSchema(ViewInspector):
if method not in ('PUT', 'PATCH', 'POST'):
return {}
self.request_media_types = self.map_parsers(path, method)
serializer = self._get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
@ -473,7 +502,7 @@ class AutoSchema(ViewInspector):
return {
'content': {
ct: {'schema': content}
for ct in self.content_types
for ct in self.request_media_types
}
}
@ -486,6 +515,8 @@ class AutoSchema(ViewInspector):
}
}
self.response_media_types = self.map_renderers(path, method)
item_schema = {}
serializer = self._get_serializer(path, method)
@ -513,7 +544,7 @@ class AutoSchema(ViewInspector):
'200': {
'content': {
ct: {'schema': response_schema}
for ct in self.content_types
for ct in self.response_media_types
},
# description is a mandatory property,
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject

View File

@ -5,6 +5,8 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.compat import uritemplate
from rest_framework.parsers import JSONParser, MultiPartParser
from rest_framework.renderers import JSONRenderer
from rest_framework.request import Request
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
@ -48,6 +50,10 @@ class TestFieldMapping(TestCase):
(serializers.ListField(child=serializers.BooleanField()), {'items': {'type': 'boolean'}, 'type': 'array'}),
(serializers.ListField(child=serializers.FloatField()), {'items': {'type': 'number'}, 'type': 'array'}),
(serializers.ListField(child=serializers.CharField()), {'items': {'type': 'string'}, 'type': 'array'}),
(serializers.ListField(child=serializers.IntegerField(max_value=4294967295)),
{'items': {'type': 'integer', 'format': 'int64'}, 'type': 'array'}),
(serializers.IntegerField(min_value=2147483648),
{'type': 'integer', 'minimum': 2147483648, 'format': 'int64'}),
]
for field, mapping in cases:
with self.subTest(field=field):
@ -71,7 +77,7 @@ class TestOperationIntrospection(TestCase):
method = 'GET'
view = create_view(
views.ExampleListView,
views.DocStringExampleListView,
method,
create_request(path)
)
@ -80,7 +86,8 @@ class TestOperationIntrospection(TestCase):
operation = inspector.get_operation(path, method)
assert operation == {
'operationId': 'listExamples',
'operationId': 'listDocStringExamples',
'description': 'A description of my GET operation.',
'parameters': [],
'responses': {
'200': {
@ -102,23 +109,38 @@ class TestOperationIntrospection(TestCase):
method = 'GET'
view = create_view(
views.ExampleDetailView,
views.DocStringExampleDetailView,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
parameters = inspector._get_path_parameters(path, method)
assert parameters == [{
'description': '',
'in': 'path',
'name': 'id',
'required': True,
'schema': {
'type': 'string',
operation = inspector.get_operation(path, method)
assert operation == {
'operationId': 'RetrieveDocStringExampleDetail',
'description': 'A description of my GET operation.',
'parameters': [{
'description': '',
'in': 'path',
'name': 'id',
'required': True,
'schema': {
'type': 'string',
},
}],
'responses': {
'200': {
'description': '',
'content': {
'application/json': {
'schema': {
},
},
},
},
},
}]
}
def test_request_body(self):
path = '/'
@ -364,6 +386,77 @@ class TestOperationIntrospection(TestCase):
},
}
def test_parser_mapping(self):
"""Test that view's parsers are mapped to OA media types"""
path = '/{id}/'
method = 'POST'
class View(generics.CreateAPIView):
serializer_class = views.ExampleSerializer
parser_classes = [JSONParser, MultiPartParser]
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
assert len(request_body['content'].keys()) == 2
assert 'multipart/form-data' in request_body['content']
assert 'application/json' in request_body['content']
def test_renderer_mapping(self):
"""Test that view's renderers are mapped to OA media types"""
path = '/{id}/'
method = 'GET'
class View(generics.CreateAPIView):
serializer_class = views.ExampleSerializer
renderer_classes = [JSONRenderer]
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
# TODO this should be changed once the multiple response
# schema support is there
success_response = responses['200']
assert len(success_response['content'].keys()) == 1
assert 'application/json' in success_response['content']
def test_serializer_filefield(self):
path = '/{id}/'
method = 'POST'
class ItemSerializer(serializers.Serializer):
attachment = serializers.FileField()
class View(generics.CreateAPIView):
serializer_class = ItemSerializer
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
mp_media = request_body['content']['multipart/form-data']
attachment = mp_media['schema']['properties']['attachment']
assert attachment['format'] == 'binary'
def test_retrieve_response_body_generation(self):
"""
Test that a list of properties is returned for retrieve item views.
@ -611,3 +704,16 @@ class TestGenerator(TestCase):
assert schema['info']['title'] == 'My title'
assert schema['info']['version'] == '1.2.3'
assert schema['info']['description'] == 'My description'
def test_schema_information_empty(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert schema['info']['title'] == ''
assert schema['info']['version'] == ''

View File

@ -29,6 +29,30 @@ class ExampleDetailView(APIView):
pass
class DocStringExampleListView(APIView):
"""
get: A description of my GET operation.
post: A description of my POST operation.
"""
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class DocStringExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
"""
A description of my GET operation.
"""
pass
# Generics.
class ExampleSerializer(serializers.Serializer):
date = serializers.DateField()