This commit is contained in:
Carlton Gibson 2018-10-03 12:03:35 +00:00 committed by GitHub
commit 1903c3f41b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 663 additions and 449 deletions

View File

View File

@ -0,0 +1,46 @@
from django.core.management.base import BaseCommand
from rest_framework.compat import coreapi
from rest_framework.renderers import CoreJSONRenderer, OpenAPIRenderer
from rest_framework.settings import api_settings
class Command(BaseCommand):
help = "Generates configured API schema for project."
def add_arguments(self, parser):
# TODO
# SchemaGenerator allows passing:
#
# - title
# - url
# - description
# - urlconf
# - patterns
#
# Don't particularly want to pass these on the command-line.
# conf file?
#
# Other options to consider:
# - indent
# - ...
pass
def handle(self, *args, **options):
assert coreapi is not None, 'coreapi must be installed.'
generator_class = api_settings.DEFAULT_SCHEMA_GENERATOR_CLASS()
generator = generator_class()
schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer('openapi')
output = renderer.render(schema)
self.stdout.write(output)
def get_renderer(self, format):
return {
'corejson': CoreJSONRenderer(),
'openapi': OpenAPIRenderer()
}

View File

@ -9,6 +9,7 @@ REST framework also provides an HTML renderer that renders the browsable API.
from __future__ import unicode_literals
import base64
import urllib.parse as urlparse
from collections import OrderedDict
from django import forms
@ -24,7 +25,7 @@ from django.utils.html import mark_safe
from rest_framework import VERSION, exceptions, serializers, status
from rest_framework.compat import (
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi,
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema,
pygments_css
)
from rest_framework.exceptions import ParseError
@ -932,3 +933,95 @@ class CoreJSONRenderer(BaseRenderer):
indent = bool(renderer_context.get('indent', 0))
codec = coreapi.codecs.CoreJSONCodec()
return codec.dump(data, indent=indent)
class OpenAPIRenderer:
CLASS_TO_TYPENAME = {
coreschema.Object: 'object',
coreschema.Array: 'array',
coreschema.Number: 'number',
coreschema.Integer: 'integer',
coreschema.String: 'string',
coreschema.Boolean: 'boolean',
}
def __init__(self):
assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.'
def get_schema(self, instance):
schema = {}
if instance.__class__ in self.CLASS_TO_TYPENAME:
schema['type'] = self.CLASS_TO_TYPENAME[instance.__class__]
schema['title'] = instance.title,
schema['description'] = instance.description
if hasattr(instance, 'enum'):
schema['enum'] = instance.enum
return schema
def get_parameters(self, link):
parameters = []
for field in link.fields:
if field.location not in ['path', 'query']:
continue
parameter = {
'name': field.name,
'in': field.location,
}
if field.required:
parameter['required'] = True
if field.description:
parameter['description'] = field.description
if field.schema:
parameter['schema'] = self.get_schema(field.schema)
parameters.append(parameter)
return parameters
def get_operation(self, link, name, tag):
operation_id = "%s_%s" % (tag, name) if tag else name
parameters = self.get_parameters(link)
operation = {
'operationId': operation_id,
}
if link.title:
operation['summary'] = link.title
if link.description:
operation['description'] = link.description
if parameters:
operation['parameters'] = parameters
if tag:
operation['tags'] = [tag]
return operation
def get_paths(self, document):
paths = {}
tag = None
for name, link in document.links.items():
path = urlparse.urlparse(link.url).path
method = link.action.lower()
paths.setdefault(path, {})
paths[path][method] = self.get_operation(link, name, tag=tag)
for tag, section in document.data.items():
for name, link in section.links.items():
path = urlparse.urlparse(link.url).path
method = link.action.lower()
paths.setdefault(path, {})
paths[path][method] = self.get_operation(link, name, tag=tag)
return paths
def render(self, data, media_type=None, renderer_context=None):
return json.dumps({
'openapi': '3.0.0',
'info': {
'version': '',
'title': data.title,
'description': data.description
},
'servers': [{
'url': data.url
}],
'paths': self.get_paths(data)
}, indent=4)

View File

@ -241,35 +241,18 @@ class EndpointEnumerator(object):
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class SchemaGenerator(object):
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
class BaseSchemaGenerator(object):
endpoint_inspector_cls = EndpointEnumerator
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
# 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
if url and not url.endswith('/'):
url += '/'
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns
@ -279,36 +262,15 @@ class SchemaGenerator(object):
self.url = url
self.endpoints = None
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
def _initialise_endpoints(self):
if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
def get_links(self, request=None):
def _get_paths_and_endpoints(self, request):
"""
Return a dictionary containing all the links that should be
included in the API schema.
Generate (path, method, view) given (path, method, callback) for paths.
"""
links = LinkNode()
# Generate (path, method, view) given (path, method, callback).
paths = []
view_endpoints = []
for path, method, callback in self.endpoints:
@ -317,22 +279,48 @@ class SchemaGenerator(object):
paths.append(path)
view_endpoints.append((path, method, view))
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
return paths, view_endpoints
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
return links
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
# Methods used when we generate a view instance from the raw callback...
if request is not None:
view.request = clone_request(request, method)
return view
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_schema(self, request=None, public=False):
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
def determine_path_prefix(self, paths):
"""
@ -365,29 +353,6 @@ class SchemaGenerator(object):
prefixes.append('/' + prefix + '/')
return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def has_view_permissions(self, path, method, view):
"""
Return `True` if the incoming request has the correct view permissions.
@ -401,23 +366,77 @@ class SchemaGenerator(object):
return False
return True
def coerce_path(self, path, method, view):
class SchemaGenerator(BaseSchemaGenerator):
"""
Original CoreAPI version.
"""
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
def get_links(self, request=None):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
Return a dictionary containing all the links that should be
included in the API schema.
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
links = LinkNode()
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
self._initialise_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within

View File

@ -174,20 +174,6 @@ class ViewInspector(object):
def view(self):
self._view = None
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
raise NotImplementedError(".get_link() must be overridden.")
class AutoSchema(ViewInspector):
"""
@ -208,6 +194,17 @@ class AutoSchema(ViewInspector):
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
@ -501,3 +498,44 @@ class DefaultSchema(ViewInspector):
inspector = inspector_class()
inspector.view = instance
return inspector
class OpenAPIAutoSchema(ViewInspector):
def get_operation(self, path, method):
return {
'parameters': self.get_path_parameters(path, method),
}
def get_path_parameters(self, path, method):
"""
Return a list of parameters from templated path variables.
"""
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
parameters = []
for variable in uritemplate.variables(path):
description = ''
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
parameter = {
"name": variable,
"in": "path",
"required": True,
"description": description,
}
parameters.append(parameter)
return parameters

View File

@ -57,6 +57,7 @@ DEFAULTS = {
# Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
'DEFAULT_SCHEMA_GENERATOR_CLASS': 'rest_framework.schemas.generators.SchemaGenerator',
# Throttling
'DEFAULT_THROTTLE_RATES': {
@ -144,6 +145,7 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_CLASS',
'DEFAULT_FILTER_BACKENDS',
'DEFAULT_SCHEMA_CLASS',
'DEFAULT_SCHEMA_GENERATOR_CLASS',
'EXCEPTION_HANDLER',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',

View File

View File

@ -2,15 +2,11 @@ import unittest
import pytest
from django.conf.urls import include, url
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.test import TestCase, override_settings
from rest_framework import (
filters, generics, pagination, permissions, serializers
)
from rest_framework.compat import coreapi, coreschema, get_regex_pattern, path
from rest_framework.decorators import action, api_view, schema
from rest_framework import filters, generics, serializers
from rest_framework.compat import coreapi, coreschema, path
from rest_framework.decorators import action, api_view
from rest_framework.request import Request
from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.schemas import (
@ -24,7 +20,8 @@ from rest_framework.utils import formatting
from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from .models import BasicModel, ForeignKeySource
from . import views
from ..models import BasicModel, ForeignKeySource
factory = APIRequestFactory()
@ -34,87 +31,6 @@ class MockUser(object):
return True
class ExamplePagination(pagination.PageNumberPagination):
page_size = 100
page_size_query_param = 'page_size'
class EmptySerializer(serializers.Serializer):
pass
class ExampleSerializer(serializers.Serializer):
a = serializers.CharField(required=True, help_text='A field description')
b = serializers.CharField(required=False)
read_only = serializers.CharField(read_only=True)
hidden = serializers.HiddenField(default='hello')
class AnotherSerializerWithDictField(serializers.Serializer):
a = serializers.DictField()
class AnotherSerializerWithListFields(serializers.Serializer):
a = serializers.ListField(child=serializers.IntegerField())
b = serializers.ListSerializer(child=serializers.CharField())
class AnotherSerializer(serializers.Serializer):
c = serializers.CharField(required=True)
d = serializers.CharField(required=False)
class ExampleViewSet(ModelViewSet):
pagination_class = ExamplePagination
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filter_backends = [filters.OrderingFilter]
serializer_class = ExampleSerializer
@action(methods=['post'], detail=True, serializer_class=AnotherSerializer)
def custom_action(self, request, pk):
"""
A description of custom action.
"""
raise NotImplementedError
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField)
def custom_action_with_dict_field(self, request, pk):
"""
A custom action using a dict field in the serializer.
"""
raise NotImplementedError
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields)
def custom_action_with_list_fields(self, request, pk):
"""
A custom action using both list field and list serializer in the serializer.
"""
raise NotImplementedError
@action(detail=False)
def custom_list_action(self, request):
raise NotImplementedError
@action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer)
def custom_list_action_multiple_methods(self, request):
"""Custom description."""
raise NotImplementedError
@custom_list_action_multiple_methods.mapping.delete
def custom_list_action_multiple_methods_delete(self, request):
"""Deletion description."""
raise NotImplementedError
@action(detail=False, schema=None)
def excluded_action(self, request):
pass
def get_serializer(self, *args, **kwargs):
assert self.request
assert self.action
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
if coreapi:
schema_view = get_schema_view(title='Example API')
else:
@ -122,7 +38,7 @@ else:
pass
router = DefaultRouter()
router.register('example', ExampleViewSet, basename='example')
router.register('example', views.ExampleViewSet, basename='example')
urlpatterns = [
url(r'^$', schema_view),
url(r'^', include(router.urls))
@ -130,7 +46,7 @@ urlpatterns = [
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(ROOT_URLCONF='tests.test_schemas')
@override_settings(ROOT_URLCONF='tests.schemas.test_coreapi')
class TestRouterGeneratedSchema(TestCase):
def test_anonymous_request(self):
client = APIClient()
@ -299,61 +215,13 @@ class TestRouterGeneratedSchema(TestCase):
assert response.data == expected
class DenyAllUsingHttp404(permissions.BasePermission):
def has_permission(self, request, view):
raise Http404()
def has_object_permission(self, request, view, obj):
raise Http404()
class DenyAllUsingPermissionDenied(permissions.BasePermission):
def has_permission(self, request, view):
raise PermissionDenied()
def has_object_permission(self, request, view, obj):
raise PermissionDenied()
class Http404ExampleViewSet(ExampleViewSet):
permission_classes = [DenyAllUsingHttp404]
class PermissionDeniedExampleViewSet(ExampleViewSet):
permission_classes = [DenyAllUsingPermissionDenied]
class MethodLimitedViewSet(ExampleViewSet):
permission_classes = []
http_method_names = ['get', 'head', 'options']
class ExampleListView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class ExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
@unittest.skipUnless(coreapi, 'coreapi is not installed')
class TestSchemaGenerator(TestCase):
def setUp(self):
self.patterns = [
url(r'^example/?$', ExampleListView.as_view()),
url(r'^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
url(r'^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
url(r'^example/?$', views.ExampleListView.as_view()),
url(r'^example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@ -404,9 +272,9 @@ class TestSchemaGenerator(TestCase):
class TestSchemaGeneratorDjango2(TestCase):
def setUp(self):
self.patterns = [
path('example/', ExampleListView.as_view()),
path('example/<int:pk>/', ExampleDetailView.as_view()),
path('example/<int:pk>/sub/', ExampleDetailView.as_view()),
path('example/', views.ExampleListView.as_view()),
path('example/<int:pk>/', views.ExampleDetailView.as_view()),
path('example/<int:pk>/sub/', views.ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@ -456,9 +324,9 @@ class TestSchemaGeneratorDjango2(TestCase):
class TestSchemaGeneratorNotAtRoot(TestCase):
def setUp(self):
self.patterns = [
url(r'^api/v1/example/?$', ExampleListView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@ -509,7 +377,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
def setUp(self):
router = DefaultRouter()
router.register('example1', MethodLimitedViewSet, basename='example1')
router.register('example1', views.MethodLimitedViewSet, basename='example1')
self.patterns = [
url(r'^', include(router.urls))
]
@ -566,10 +434,10 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
def setUp(self):
router = DefaultRouter()
router.register('example1', Http404ExampleViewSet, basename='example1')
router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
router.register('example1', views.Http404ExampleViewSet, basename='example1')
router.register('example2', views.PermissionDeniedExampleViewSet, basename='example2')
self.patterns = [
url('^example/?$', ExampleListView.as_view()),
url('^example/?$', views.ExampleListView.as_view()),
url(r'^', include(router.urls))
]
@ -597,29 +465,25 @@ class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
assert schema == expected
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
fields = ('id', 'name', 'target')
class ForeignKeySourceView(generics.CreateAPIView):
queryset = ForeignKeySource.objects.all()
serializer_class = ForeignKeySourceSerializer
@unittest.skipUnless(coreapi, 'coreapi is not installed')
class TestSchemaGeneratorWithForeignKey(TestCase):
def setUp(self):
self.patterns = [
url(r'^example/?$', ForeignKeySourceView.as_view()),
]
def test_schema_for_regular_views(self):
"""
Ensure that AutoField foreign keys are output as Integer.
"""
generator = SchemaGenerator(title='Example API', patterns=self.patterns)
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
fields = ('id', 'name', 'target')
class ForeignKeySourceView(generics.CreateAPIView):
queryset = ForeignKeySource.objects.all()
serializer_class = ForeignKeySourceSerializer
patterns = [
url(r'^example/?$', ForeignKeySourceView.as_view()),
]
generator = SchemaGenerator(title='Example API', patterns=patterns)
schema = generator.get_schema()
expected = coreapi.Document(
@ -653,35 +517,8 @@ class Test4605Regression(TestCase):
assert prefix == '/'
class CustomViewInspector(AutoSchema):
"""A dummy AutoSchema subclass"""
pass
class TestAutoSchema(TestCase):
def test_apiview_schema_descriptor(self):
view = APIView()
assert hasattr(view, 'schema')
assert isinstance(view.schema, AutoSchema)
def test_set_custom_inspector_class_on_view(self):
class CustomView(APIView):
schema = CustomViewInspector()
view = CustomView()
assert isinstance(view.schema, CustomViewInspector)
def test_set_custom_inspector_class_via_settings(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
view = APIView()
assert isinstance(view.schema, CustomViewInspector)
def test_get_link_requires_instance(self):
descriptor = APIView.schema # Accessed from class
with pytest.raises(AssertionError):
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_update_fields(self):
"""
@ -902,158 +739,19 @@ def test_docstring_is_not_stripped_by_get_description():
assert descr == formatting.dedent(ExampleDocstringAPIView.__doc__[1:][:-1])
# Views for SchemaGenerationExclusionTests
class ExcludedAPIView(APIView):
schema = None
def get(self, request, *args, **kwargs):
pass
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass
@unittest.skipUnless(coreapi, 'coreapi is not installed')
class SchemaGenerationExclusionTests(TestCase):
def setUp(self):
self.patterns = [
url('^excluded-cbv/$', ExcludedAPIView.as_view()),
url('^excluded-fbv/$', excluded_fbv),
url('^included-fbv/$', included_fbv),
]
def test_schema_generator_excludes_correctly(self):
"""Schema should not include excluded views"""
generator = SchemaGenerator(title='Exclusions', patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
url='',
title='Exclusions',
content={
'included-fbv': {
'list': coreapi.Link(url='/included-fbv/', action='get')
}
}
)
assert len(schema.data) == 1
assert 'included-fbv' in schema.data
assert schema == expected
def test_endpoint_enumerator_excludes_correctly(self):
"""It is responsibility of EndpointEnumerator to exclude views"""
inspector = EndpointEnumerator(self.patterns)
endpoints = inspector.get_api_endpoints()
assert len(endpoints) == 1
path, method, callback = endpoints[0]
assert path == '/included-fbv/'
def test_should_include_endpoint_excludes_correctly(self):
"""This is the specific method that should handle the exclusion"""
inspector = EndpointEnumerator(self.patterns)
# Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback)
for pattern in self.patterns]
should_include = [
inspector.should_include_endpoint(*pair) for pair in pairs
]
expected = [False, False, True]
assert should_include == expected
def test_deprecations(self):
with pytest.warns(DeprecationWarning) as record:
@api_view(["GET"], exclude_from_schema=True)
def view(request):
pass
assert len(record) == 1
assert str(record[0].message) == (
"The `exclude_from_schema` argument to `api_view` is deprecated. "
"Use the `schema` decorator instead, passing `None`."
)
class OldFashionedExcludedView(APIView):
exclude_from_schema = True
def get(self, request, *args, **kwargs):
pass
patterns = [
url('^excluded-old-fashioned/$', OldFashionedExcludedView.as_view()),
]
inspector = EndpointEnumerator(patterns)
with pytest.warns(DeprecationWarning) as record:
inspector.get_api_endpoints()
assert len(record) == 1
assert str(record[0].message) == (
"The `OldFashionedExcludedView.exclude_from_schema` attribute is "
"deprecated. Set `schema = None` instead."
)
@api_view(["GET"])
def simple_fbv(request):
pass
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
fields = "__all__"
class NamingCollisionView(generics.RetrieveUpdateDestroyAPIView):
queryset = BasicModel.objects.all()
serializer_class = BasicModelSerializer
class BasicNamingCollisionView(generics.RetrieveAPIView):
queryset = BasicModel.objects.all()
class NamingCollisionViewSet(GenericViewSet):
"""
Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/
"""
permision_class = ()
@action(detail=False)
def detail(self, request):
return {}
@action(detail=False, url_path='detail/export')
def detail_export(self, request):
return {}
naming_collisions_router = SimpleRouter()
naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision")
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestURLNamingCollisions(TestCase):
"""
Ref: https://github.com/encode/django-rest-framework/issues/4704
"""
@api_view(["GET"])
def simple_fbv(request):
pass
def test_manually_routing_nested_routes(self):
patterns = [
url(r'^test', simple_fbv),
url(r'^test/list/', simple_fbv),
url(r'^test', self.simple_fbv),
url(r'^test/list/', self.simple_fbv),
]
generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
@ -1088,6 +786,15 @@ class TestURLNamingCollisions(TestCase):
assert loc[key].url == url
def test_manually_routing_generic_view(self):
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
fields = "__all__"
class NamingCollisionView(generics.RetrieveUpdateDestroyAPIView):
queryset = BasicModel.objects.all()
serializer_class = BasicModelSerializer
patterns = [
url(r'^test', NamingCollisionView.as_view()),
url(r'^test/retrieve/', NamingCollisionView.as_view()),
@ -1111,6 +818,23 @@ class TestURLNamingCollisions(TestCase):
self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0'))
def test_from_router(self):
class NamingCollisionViewSet(GenericViewSet):
"""
Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/
"""
permision_class = ()
@action(detail=False)
def detail(self, request):
return {}
@action(detail=False, url_path='detail/export')
def detail_export(self, request):
return {}
naming_collisions_router = SimpleRouter()
naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision")
patterns = [
url(r'from-router', include(naming_collisions_router.urls)),
]
@ -1143,6 +867,9 @@ class TestURLNamingCollisions(TestCase):
assert schema == expected
def test_url_under_same_key_not_replaced(self):
class BasicNamingCollisionView(generics.RetrieveAPIView):
queryset = BasicModel.objects.all()
patterns = [
url(r'example/(?P<pk>\d+)/$', BasicNamingCollisionView.as_view()),
url(r'example/(?P<slug>\w+)/$', BasicNamingCollisionView.as_view()),
@ -1157,8 +884,8 @@ class TestURLNamingCollisions(TestCase):
def test_url_under_same_key_not_replaced_another(self):
patterns = [
url(r'^test/list/', simple_fbv),
url(r'^test/(?P<pk>\d+)/list/', simple_fbv),
url(r'^test/list/', self.simple_fbv),
url(r'^test/(?P<pk>\d+)/list/', self.simple_fbv),
]
generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)

View File

@ -0,0 +1,112 @@
import unittest
import pytest
from django.conf.urls import url
from django.test import TestCase
from rest_framework.compat import coreapi, get_regex_pattern
from rest_framework.decorators import api_view, schema
from rest_framework.schemas.generators import (
EndpointEnumerator, SchemaGenerator
)
from rest_framework.views import APIView
class EndpointExclusionTests(TestCase):
class ExcludedAPIView(APIView):
schema = None
def get(self, request, *args, **kwargs):
pass
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass
def setUp(self):
self.patterns = [
url('^excluded-cbv/$', self.ExcludedAPIView.as_view()),
url('^excluded-fbv/$', self.excluded_fbv),
url('^included-fbv/$', self.included_fbv),
]
@unittest.skipUnless(coreapi, 'coreapi is not installed')
def test_schema_generator_excludes_correctly(self):
"""Schema should not include excluded views"""
generator = SchemaGenerator(title='Exclusions', patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
url='',
title='Exclusions',
content={
'included-fbv': {
'list': coreapi.Link(url='/included-fbv/', action='get')
}
}
)
assert len(schema.data) == 1
assert 'included-fbv' in schema.data
assert schema == expected
def test_endpoint_enumerator_excludes_correctly(self):
"""It is responsibility of EndpointEnumerator to exclude views"""
inspector = EndpointEnumerator(self.patterns)
endpoints = inspector.get_api_endpoints()
assert len(endpoints) == 1
path, method, callback = endpoints[0]
assert path == '/included-fbv/'
def test_should_include_endpoint_excludes_correctly(self):
"""This is the specific method that should handle the exclusion"""
inspector = EndpointEnumerator(self.patterns)
# Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback)
for pattern in self.patterns]
should_include = [
inspector.should_include_endpoint(*pair) for pair in pairs
]
expected = [False, False, True]
assert should_include == expected
def test_deprecations(self):
with pytest.warns(DeprecationWarning) as record:
@api_view(["GET"], exclude_from_schema=True)
def view(request):
pass
assert len(record) == 1
assert str(record[0].message) == (
"The `exclude_from_schema` argument to `api_view` is deprecated. "
"Use the `schema` decorator instead, passing `None`."
)
class OldFashionedExcludedView(APIView):
exclude_from_schema = True
def get(self, request, *args, **kwargs):
pass
patterns = [
url('^excluded-old-fashioned/$', OldFashionedExcludedView.as_view()),
]
inspector = EndpointEnumerator(patterns)
with pytest.warns(DeprecationWarning) as record:
inspector.get_api_endpoints()
assert len(record) == 1
assert str(record[0].message) == (
"The `OldFashionedExcludedView.exclude_from_schema` attribute is "
"deprecated. Set `schema = None` instead."
)

View File

@ -0,0 +1,38 @@
import pytest
from django.test import TestCase, override_settings
from rest_framework.schemas.inspectors import AutoSchema, ViewInspector
from rest_framework.views import APIView
class CustomViewInspector(ViewInspector):
"""A dummy ViewInspector subclass"""
pass
class TestViewInspector(TestCase):
"""
Tests for the descriptor behaviour of ViewInspector
(and subclasses.)
"""
def test_apiview_schema_descriptor(self):
view = APIView()
assert hasattr(view, 'schema')
assert isinstance(view.schema, AutoSchema)
def test_set_custom_inspector_class_on_view(self):
class CustomView(APIView):
schema = CustomViewInspector()
view = CustomView()
assert isinstance(view.schema, CustomViewInspector)
def test_set_custom_inspector_class_via_settings(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_view_inspector_descriptor.CustomViewInspector'}):
view = APIView()
assert isinstance(view.schema, CustomViewInspector)
def test_get_link_requires_instance(self):
descriptor = APIView.schema # Accessed from class
with pytest.raises(AssertionError):
descriptor.get_link(None, None, None)

139
tests/schemas/views.py Normal file
View File

@ -0,0 +1,139 @@
from django.core.exceptions import PermissionDenied
from django.http import Http404
from rest_framework import filters, pagination, permissions, serializers
from rest_framework.decorators import action
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet
# Simple APIViews:
class ExampleListView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class ExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
# Classes for ExampleViewSet
class ExamplePagination(pagination.PageNumberPagination):
page_size = 100
page_size_query_param = 'page_size'
class EmptySerializer(serializers.Serializer):
pass
class ExampleSerializer(serializers.Serializer):
a = serializers.CharField(required=True, help_text='A field description')
b = serializers.CharField(required=False)
read_only = serializers.CharField(read_only=True)
hidden = serializers.HiddenField(default='hello')
class AnotherSerializerWithDictField(serializers.Serializer):
a = serializers.DictField()
class AnotherSerializerWithListFields(serializers.Serializer):
a = serializers.ListField(child=serializers.IntegerField())
b = serializers.ListSerializer(child=serializers.CharField())
class AnotherSerializer(serializers.Serializer):
c = serializers.CharField(required=True)
d = serializers.CharField(required=False)
class ExampleViewSet(ModelViewSet):
pagination_class = ExamplePagination
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filter_backends = [filters.OrderingFilter]
serializer_class = ExampleSerializer
@action(methods=['post'], detail=True, serializer_class=AnotherSerializer)
def custom_action(self, request, pk):
"""
A description of custom action.
"""
raise NotImplementedError
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField)
def custom_action_with_dict_field(self, request, pk):
"""
A custom action using a dict field in the serializer.
"""
raise NotImplementedError
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields)
def custom_action_with_list_fields(self, request, pk):
"""
A custom action using both list field and list serializer in the serializer.
"""
raise NotImplementedError
@action(detail=False)
def custom_list_action(self, request):
raise NotImplementedError
@action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer)
def custom_list_action_multiple_methods(self, request):
"""Custom description."""
raise NotImplementedError
@custom_list_action_multiple_methods.mapping.delete
def custom_list_action_multiple_methods_delete(self, request):
"""Deletion description."""
raise NotImplementedError
@action(detail=False, schema=None)
def excluded_action(self, request):
pass
def get_serializer(self, *args, **kwargs):
assert self.request
assert self.action
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
# ExampleViewSet subclasses
class DenyAllUsingHttp404(permissions.BasePermission):
def has_permission(self, request, view):
raise Http404()
def has_object_permission(self, request, view, obj):
raise Http404()
class DenyAllUsingPermissionDenied(permissions.BasePermission):
def has_permission(self, request, view):
raise PermissionDenied()
def has_object_permission(self, request, view, obj):
raise PermissionDenied()
class Http404ExampleViewSet(ExampleViewSet):
permission_classes = [DenyAllUsingHttp404]
class PermissionDeniedExampleViewSet(ExampleViewSet):
permission_classes = [DenyAllUsingPermissionDenied]
class MethodLimitedViewSet(ExampleViewSet):
permission_classes = []
http_method_names = ['get', 'head', 'options']