mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-01 11:00:13 +03:00
Merge 9590e11101
into c9d2bbcead
This commit is contained in:
commit
1903c3f41b
0
rest_framework/management/__init__.py
Normal file
0
rest_framework/management/__init__.py
Normal file
0
rest_framework/management/commands/__init__.py
Normal file
0
rest_framework/management/commands/__init__.py
Normal file
46
rest_framework/management/commands/generate_schema.py
Normal file
46
rest_framework/management/commands/generate_schema.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
from django.core.management.base import BaseCommand
|
||||
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.renderers import CoreJSONRenderer, OpenAPIRenderer
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Generates configured API schema for project."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
# TODO
|
||||
# SchemaGenerator allows passing:
|
||||
#
|
||||
# - title
|
||||
# - url
|
||||
# - description
|
||||
# - urlconf
|
||||
# - patterns
|
||||
#
|
||||
# Don't particularly want to pass these on the command-line.
|
||||
# conf file?
|
||||
#
|
||||
# Other options to consider:
|
||||
# - indent
|
||||
# - ...
|
||||
pass
|
||||
|
||||
def handle(self, *args, **options):
|
||||
assert coreapi is not None, 'coreapi must be installed.'
|
||||
|
||||
generator_class = api_settings.DEFAULT_SCHEMA_GENERATOR_CLASS()
|
||||
generator = generator_class()
|
||||
|
||||
schema = generator.get_schema(request=None, public=True)
|
||||
|
||||
renderer = self.get_renderer('openapi')
|
||||
output = renderer.render(schema)
|
||||
|
||||
self.stdout.write(output)
|
||||
|
||||
def get_renderer(self, format):
|
||||
return {
|
||||
'corejson': CoreJSONRenderer(),
|
||||
'openapi': OpenAPIRenderer()
|
||||
}
|
|
@ -9,6 +9,7 @@ REST framework also provides an HTML renderer that renders the browsable API.
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import base64
|
||||
import urllib.parse as urlparse
|
||||
from collections import OrderedDict
|
||||
|
||||
from django import forms
|
||||
|
@ -24,7 +25,7 @@ from django.utils.html import mark_safe
|
|||
|
||||
from rest_framework import VERSION, exceptions, serializers, status
|
||||
from rest_framework.compat import (
|
||||
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi,
|
||||
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema,
|
||||
pygments_css
|
||||
)
|
||||
from rest_framework.exceptions import ParseError
|
||||
|
@ -932,3 +933,95 @@ class CoreJSONRenderer(BaseRenderer):
|
|||
indent = bool(renderer_context.get('indent', 0))
|
||||
codec = coreapi.codecs.CoreJSONCodec()
|
||||
return codec.dump(data, indent=indent)
|
||||
|
||||
|
||||
class OpenAPIRenderer:
|
||||
CLASS_TO_TYPENAME = {
|
||||
coreschema.Object: 'object',
|
||||
coreschema.Array: 'array',
|
||||
coreschema.Number: 'number',
|
||||
coreschema.Integer: 'integer',
|
||||
coreschema.String: 'string',
|
||||
coreschema.Boolean: 'boolean',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.'
|
||||
|
||||
def get_schema(self, instance):
|
||||
schema = {}
|
||||
if instance.__class__ in self.CLASS_TO_TYPENAME:
|
||||
schema['type'] = self.CLASS_TO_TYPENAME[instance.__class__]
|
||||
schema['title'] = instance.title,
|
||||
schema['description'] = instance.description
|
||||
if hasattr(instance, 'enum'):
|
||||
schema['enum'] = instance.enum
|
||||
return schema
|
||||
|
||||
def get_parameters(self, link):
|
||||
parameters = []
|
||||
for field in link.fields:
|
||||
if field.location not in ['path', 'query']:
|
||||
continue
|
||||
parameter = {
|
||||
'name': field.name,
|
||||
'in': field.location,
|
||||
}
|
||||
if field.required:
|
||||
parameter['required'] = True
|
||||
if field.description:
|
||||
parameter['description'] = field.description
|
||||
if field.schema:
|
||||
parameter['schema'] = self.get_schema(field.schema)
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
|
||||
def get_operation(self, link, name, tag):
|
||||
operation_id = "%s_%s" % (tag, name) if tag else name
|
||||
parameters = self.get_parameters(link)
|
||||
|
||||
operation = {
|
||||
'operationId': operation_id,
|
||||
}
|
||||
if link.title:
|
||||
operation['summary'] = link.title
|
||||
if link.description:
|
||||
operation['description'] = link.description
|
||||
if parameters:
|
||||
operation['parameters'] = parameters
|
||||
if tag:
|
||||
operation['tags'] = [tag]
|
||||
return operation
|
||||
|
||||
def get_paths(self, document):
|
||||
paths = {}
|
||||
|
||||
tag = None
|
||||
for name, link in document.links.items():
|
||||
path = urlparse.urlparse(link.url).path
|
||||
method = link.action.lower()
|
||||
paths.setdefault(path, {})
|
||||
paths[path][method] = self.get_operation(link, name, tag=tag)
|
||||
|
||||
for tag, section in document.data.items():
|
||||
for name, link in section.links.items():
|
||||
path = urlparse.urlparse(link.url).path
|
||||
method = link.action.lower()
|
||||
paths.setdefault(path, {})
|
||||
paths[path][method] = self.get_operation(link, name, tag=tag)
|
||||
|
||||
return paths
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
return json.dumps({
|
||||
'openapi': '3.0.0',
|
||||
'info': {
|
||||
'version': '',
|
||||
'title': data.title,
|
||||
'description': data.description
|
||||
},
|
||||
'servers': [{
|
||||
'url': data.url
|
||||
}],
|
||||
'paths': self.get_paths(data)
|
||||
}, indent=4)
|
||||
|
|
|
@ -241,35 +241,18 @@ class EndpointEnumerator(object):
|
|||
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
class BaseSchemaGenerator(object):
|
||||
endpoint_inspector_cls = EndpointEnumerator
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
|
@ -279,36 +262,15 @@ class SchemaGenerator(object):
|
|||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
def _initialise_endpoints(self):
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
def get_links(self, request=None):
|
||||
def _get_paths_and_endpoints(self, request):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
Generate (path, method, view) given (path, method, callback) for paths.
|
||||
"""
|
||||
links = LinkNode()
|
||||
|
||||
# Generate (path, method, view) given (path, method, callback).
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
|
@ -317,22 +279,48 @@ class SchemaGenerator(object):
|
|||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
return paths, view_endpoints
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
return links
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
# Methods used when we generate a view instance from the raw callback...
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
|
@ -365,29 +353,6 @@ class SchemaGenerator(object):
|
|||
prefixes.append('/' + prefix + '/')
|
||||
return common_path(prefixes)
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls(**getattr(callback, 'initkwargs', {}))
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
|
@ -401,23 +366,77 @@ class SchemaGenerator(object):
|
|||
return False
|
||||
return True
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
"""
|
||||
Original CoreAPI version.
|
||||
"""
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
assert coreschema, '`coreschema` must be installed for schema support.'
|
||||
|
||||
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
links = LinkNode()
|
||||
|
||||
paths, view_endpoints = self._get_paths_and_endpoints(request)
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
if not paths:
|
||||
return None
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
|
||||
return links
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
self._initialise_endpoints()
|
||||
|
||||
links = self.get_links(None if public else request)
|
||||
if not links:
|
||||
return None
|
||||
|
||||
url = self.url
|
||||
if not url and request is not None:
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
distribute_links(links)
|
||||
return coreapi.Document(
|
||||
title=self.title, description=self.description,
|
||||
url=url, content=links
|
||||
)
|
||||
|
||||
# Method for generating the link layout....
|
||||
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
|
|
|
@ -174,20 +174,6 @@ class ViewInspector(object):
|
|||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
raise NotImplementedError(".get_link() must be overridden.")
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
|
@ -208,6 +194,17 @@ class AutoSchema(ViewInspector):
|
|||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
|
@ -501,3 +498,44 @@ class DefaultSchema(ViewInspector):
|
|||
inspector = inspector_class()
|
||||
inspector.view = instance
|
||||
return inspector
|
||||
|
||||
|
||||
class OpenAPIAutoSchema(ViewInspector):
|
||||
|
||||
def get_operation(self, path, method):
|
||||
return {
|
||||
'parameters': self.get_path_parameters(path, method),
|
||||
}
|
||||
|
||||
def get_path_parameters(self, path, method):
|
||||
"""
|
||||
Return a list of parameters from templated path variables.
|
||||
"""
|
||||
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
|
||||
|
||||
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
|
||||
parameters = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
description = ''
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except Exception:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
parameter = {
|
||||
"name": variable,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"description": description,
|
||||
}
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
|
|
@ -57,6 +57,7 @@ DEFAULTS = {
|
|||
|
||||
# Schema
|
||||
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
|
||||
'DEFAULT_SCHEMA_GENERATOR_CLASS': 'rest_framework.schemas.generators.SchemaGenerator',
|
||||
|
||||
# Throttling
|
||||
'DEFAULT_THROTTLE_RATES': {
|
||||
|
@ -144,6 +145,7 @@ IMPORT_STRINGS = (
|
|||
'DEFAULT_PAGINATION_CLASS',
|
||||
'DEFAULT_FILTER_BACKENDS',
|
||||
'DEFAULT_SCHEMA_CLASS',
|
||||
'DEFAULT_SCHEMA_GENERATOR_CLASS',
|
||||
'EXCEPTION_HANDLER',
|
||||
'TEST_REQUEST_RENDERER_CLASSES',
|
||||
'UNAUTHENTICATED_USER',
|
||||
|
|
0
tests/schemas/__init__.py
Normal file
0
tests/schemas/__init__.py
Normal file
|
@ -2,15 +2,11 @@ import unittest
|
|||
|
||||
import pytest
|
||||
from django.conf.urls import include, url
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from rest_framework import (
|
||||
filters, generics, pagination, permissions, serializers
|
||||
)
|
||||
from rest_framework.compat import coreapi, coreschema, get_regex_pattern, path
|
||||
from rest_framework.decorators import action, api_view, schema
|
||||
from rest_framework import filters, generics, serializers
|
||||
from rest_framework.compat import coreapi, coreschema, path
|
||||
from rest_framework.decorators import action, api_view
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.routers import DefaultRouter, SimpleRouter
|
||||
from rest_framework.schemas import (
|
||||
|
@ -24,7 +20,8 @@ from rest_framework.utils import formatting
|
|||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||
|
||||
from .models import BasicModel, ForeignKeySource
|
||||
from . import views
|
||||
from ..models import BasicModel, ForeignKeySource
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
@ -34,87 +31,6 @@ class MockUser(object):
|
|||
return True
|
||||
|
||||
|
||||
class ExamplePagination(pagination.PageNumberPagination):
|
||||
page_size = 100
|
||||
page_size_query_param = 'page_size'
|
||||
|
||||
|
||||
class EmptySerializer(serializers.Serializer):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleSerializer(serializers.Serializer):
|
||||
a = serializers.CharField(required=True, help_text='A field description')
|
||||
b = serializers.CharField(required=False)
|
||||
read_only = serializers.CharField(read_only=True)
|
||||
hidden = serializers.HiddenField(default='hello')
|
||||
|
||||
|
||||
class AnotherSerializerWithDictField(serializers.Serializer):
|
||||
a = serializers.DictField()
|
||||
|
||||
|
||||
class AnotherSerializerWithListFields(serializers.Serializer):
|
||||
a = serializers.ListField(child=serializers.IntegerField())
|
||||
b = serializers.ListSerializer(child=serializers.CharField())
|
||||
|
||||
|
||||
class AnotherSerializer(serializers.Serializer):
|
||||
c = serializers.CharField(required=True)
|
||||
d = serializers.CharField(required=False)
|
||||
|
||||
|
||||
class ExampleViewSet(ModelViewSet):
|
||||
pagination_class = ExamplePagination
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
serializer_class = ExampleSerializer
|
||||
|
||||
@action(methods=['post'], detail=True, serializer_class=AnotherSerializer)
|
||||
def custom_action(self, request, pk):
|
||||
"""
|
||||
A description of custom action.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField)
|
||||
def custom_action_with_dict_field(self, request, pk):
|
||||
"""
|
||||
A custom action using a dict field in the serializer.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields)
|
||||
def custom_action_with_list_fields(self, request, pk):
|
||||
"""
|
||||
A custom action using both list field and list serializer in the serializer.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(detail=False)
|
||||
def custom_list_action(self, request):
|
||||
raise NotImplementedError
|
||||
|
||||
@action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer)
|
||||
def custom_list_action_multiple_methods(self, request):
|
||||
"""Custom description."""
|
||||
raise NotImplementedError
|
||||
|
||||
@custom_list_action_multiple_methods.mapping.delete
|
||||
def custom_list_action_multiple_methods_delete(self, request):
|
||||
"""Deletion description."""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(detail=False, schema=None)
|
||||
def excluded_action(self, request):
|
||||
pass
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
assert self.request
|
||||
assert self.action
|
||||
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
||||
|
||||
|
||||
if coreapi:
|
||||
schema_view = get_schema_view(title='Example API')
|
||||
else:
|
||||
|
@ -122,7 +38,7 @@ else:
|
|||
pass
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register('example', ExampleViewSet, basename='example')
|
||||
router.register('example', views.ExampleViewSet, basename='example')
|
||||
urlpatterns = [
|
||||
url(r'^$', schema_view),
|
||||
url(r'^', include(router.urls))
|
||||
|
@ -130,7 +46,7 @@ urlpatterns = [
|
|||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_schemas')
|
||||
@override_settings(ROOT_URLCONF='tests.schemas.test_coreapi')
|
||||
class TestRouterGeneratedSchema(TestCase):
|
||||
def test_anonymous_request(self):
|
||||
client = APIClient()
|
||||
|
@ -299,61 +215,13 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
assert response.data == expected
|
||||
|
||||
|
||||
class DenyAllUsingHttp404(permissions.BasePermission):
|
||||
|
||||
def has_permission(self, request, view):
|
||||
raise Http404()
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
raise Http404()
|
||||
|
||||
|
||||
class DenyAllUsingPermissionDenied(permissions.BasePermission):
|
||||
|
||||
def has_permission(self, request, view):
|
||||
raise PermissionDenied()
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
raise PermissionDenied()
|
||||
|
||||
|
||||
class Http404ExampleViewSet(ExampleViewSet):
|
||||
permission_classes = [DenyAllUsingHttp404]
|
||||
|
||||
|
||||
class PermissionDeniedExampleViewSet(ExampleViewSet):
|
||||
permission_classes = [DenyAllUsingPermissionDenied]
|
||||
|
||||
|
||||
class MethodLimitedViewSet(ExampleViewSet):
|
||||
permission_classes = []
|
||||
http_method_names = ['get', 'head', 'options']
|
||||
|
||||
|
||||
class ExampleListView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
class TestSchemaGenerator(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url(r'^example/?$', ExampleListView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
url(r'^example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
|
||||
url(r'^example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -404,9 +272,9 @@ class TestSchemaGenerator(TestCase):
|
|||
class TestSchemaGeneratorDjango2(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
path('example/', ExampleListView.as_view()),
|
||||
path('example/<int:pk>/', ExampleDetailView.as_view()),
|
||||
path('example/<int:pk>/sub/', ExampleDetailView.as_view()),
|
||||
path('example/', views.ExampleListView.as_view()),
|
||||
path('example/<int:pk>/', views.ExampleDetailView.as_view()),
|
||||
path('example/<int:pk>/sub/', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -456,9 +324,9 @@ class TestSchemaGeneratorDjango2(TestCase):
|
|||
class TestSchemaGeneratorNotAtRoot(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url(r'^api/v1/example/?$', ExampleListView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
|
||||
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
|
@ -509,7 +377,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
|
|||
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
|
||||
def setUp(self):
|
||||
router = DefaultRouter()
|
||||
router.register('example1', MethodLimitedViewSet, basename='example1')
|
||||
router.register('example1', views.MethodLimitedViewSet, basename='example1')
|
||||
self.patterns = [
|
||||
url(r'^', include(router.urls))
|
||||
]
|
||||
|
@ -566,10 +434,10 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
|
|||
class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
|
||||
def setUp(self):
|
||||
router = DefaultRouter()
|
||||
router.register('example1', Http404ExampleViewSet, basename='example1')
|
||||
router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
|
||||
router.register('example1', views.Http404ExampleViewSet, basename='example1')
|
||||
router.register('example2', views.PermissionDeniedExampleViewSet, basename='example2')
|
||||
self.patterns = [
|
||||
url('^example/?$', ExampleListView.as_view()),
|
||||
url('^example/?$', views.ExampleListView.as_view()),
|
||||
url(r'^', include(router.urls))
|
||||
]
|
||||
|
||||
|
@ -597,29 +465,25 @@ class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
|
|||
assert schema == expected
|
||||
|
||||
|
||||
class ForeignKeySourceSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ForeignKeySource
|
||||
fields = ('id', 'name', 'target')
|
||||
|
||||
|
||||
class ForeignKeySourceView(generics.CreateAPIView):
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
serializer_class = ForeignKeySourceSerializer
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
class TestSchemaGeneratorWithForeignKey(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url(r'^example/?$', ForeignKeySourceView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
"""
|
||||
Ensure that AutoField foreign keys are output as Integer.
|
||||
"""
|
||||
generator = SchemaGenerator(title='Example API', patterns=self.patterns)
|
||||
class ForeignKeySourceSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ForeignKeySource
|
||||
fields = ('id', 'name', 'target')
|
||||
|
||||
class ForeignKeySourceView(generics.CreateAPIView):
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
serializer_class = ForeignKeySourceSerializer
|
||||
|
||||
patterns = [
|
||||
url(r'^example/?$', ForeignKeySourceView.as_view()),
|
||||
]
|
||||
generator = SchemaGenerator(title='Example API', patterns=patterns)
|
||||
schema = generator.get_schema()
|
||||
|
||||
expected = coreapi.Document(
|
||||
|
@ -653,35 +517,8 @@ class Test4605Regression(TestCase):
|
|||
assert prefix == '/'
|
||||
|
||||
|
||||
class CustomViewInspector(AutoSchema):
|
||||
"""A dummy AutoSchema subclass"""
|
||||
pass
|
||||
|
||||
|
||||
class TestAutoSchema(TestCase):
|
||||
|
||||
def test_apiview_schema_descriptor(self):
|
||||
view = APIView()
|
||||
assert hasattr(view, 'schema')
|
||||
assert isinstance(view.schema, AutoSchema)
|
||||
|
||||
def test_set_custom_inspector_class_on_view(self):
|
||||
class CustomView(APIView):
|
||||
schema = CustomViewInspector()
|
||||
|
||||
view = CustomView()
|
||||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
def test_set_custom_inspector_class_via_settings(self):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
|
||||
view = APIView()
|
||||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
def test_get_link_requires_instance(self):
|
||||
descriptor = APIView.schema # Accessed from class
|
||||
with pytest.raises(AssertionError):
|
||||
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
def test_update_fields(self):
|
||||
"""
|
||||
|
@ -902,158 +739,19 @@ def test_docstring_is_not_stripped_by_get_description():
|
|||
assert descr == formatting.dedent(ExampleDocstringAPIView.__doc__[1:][:-1])
|
||||
|
||||
|
||||
# Views for SchemaGenerationExclusionTests
|
||||
class ExcludedAPIView(APIView):
|
||||
schema = None
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
@schema(None)
|
||||
def excluded_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
def included_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
class SchemaGenerationExclusionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url('^excluded-cbv/$', ExcludedAPIView.as_view()),
|
||||
url('^excluded-fbv/$', excluded_fbv),
|
||||
url('^included-fbv/$', included_fbv),
|
||||
]
|
||||
|
||||
def test_schema_generator_excludes_correctly(self):
|
||||
"""Schema should not include excluded views"""
|
||||
generator = SchemaGenerator(title='Exclusions', patterns=self.patterns)
|
||||
schema = generator.get_schema()
|
||||
expected = coreapi.Document(
|
||||
url='',
|
||||
title='Exclusions',
|
||||
content={
|
||||
'included-fbv': {
|
||||
'list': coreapi.Link(url='/included-fbv/', action='get')
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert len(schema.data) == 1
|
||||
assert 'included-fbv' in schema.data
|
||||
assert schema == expected
|
||||
|
||||
def test_endpoint_enumerator_excludes_correctly(self):
|
||||
"""It is responsibility of EndpointEnumerator to exclude views"""
|
||||
inspector = EndpointEnumerator(self.patterns)
|
||||
endpoints = inspector.get_api_endpoints()
|
||||
|
||||
assert len(endpoints) == 1
|
||||
path, method, callback = endpoints[0]
|
||||
assert path == '/included-fbv/'
|
||||
|
||||
def test_should_include_endpoint_excludes_correctly(self):
|
||||
"""This is the specific method that should handle the exclusion"""
|
||||
inspector = EndpointEnumerator(self.patterns)
|
||||
|
||||
# Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
|
||||
pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback)
|
||||
for pattern in self.patterns]
|
||||
|
||||
should_include = [
|
||||
inspector.should_include_endpoint(*pair) for pair in pairs
|
||||
]
|
||||
|
||||
expected = [False, False, True]
|
||||
|
||||
assert should_include == expected
|
||||
|
||||
def test_deprecations(self):
|
||||
with pytest.warns(DeprecationWarning) as record:
|
||||
@api_view(["GET"], exclude_from_schema=True)
|
||||
def view(request):
|
||||
pass
|
||||
|
||||
assert len(record) == 1
|
||||
assert str(record[0].message) == (
|
||||
"The `exclude_from_schema` argument to `api_view` is deprecated. "
|
||||
"Use the `schema` decorator instead, passing `None`."
|
||||
)
|
||||
|
||||
class OldFashionedExcludedView(APIView):
|
||||
exclude_from_schema = True
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
patterns = [
|
||||
url('^excluded-old-fashioned/$', OldFashionedExcludedView.as_view()),
|
||||
]
|
||||
|
||||
inspector = EndpointEnumerator(patterns)
|
||||
with pytest.warns(DeprecationWarning) as record:
|
||||
inspector.get_api_endpoints()
|
||||
|
||||
assert len(record) == 1
|
||||
assert str(record[0].message) == (
|
||||
"The `OldFashionedExcludedView.exclude_from_schema` attribute is "
|
||||
"deprecated. Set `schema = None` instead."
|
||||
)
|
||||
|
||||
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
|
||||
class BasicModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class NamingCollisionView(generics.RetrieveUpdateDestroyAPIView):
|
||||
queryset = BasicModel.objects.all()
|
||||
serializer_class = BasicModelSerializer
|
||||
|
||||
|
||||
class BasicNamingCollisionView(generics.RetrieveAPIView):
|
||||
queryset = BasicModel.objects.all()
|
||||
|
||||
|
||||
class NamingCollisionViewSet(GenericViewSet):
|
||||
"""
|
||||
Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/
|
||||
"""
|
||||
permision_class = ()
|
||||
|
||||
@action(detail=False)
|
||||
def detail(self, request):
|
||||
return {}
|
||||
|
||||
@action(detail=False, url_path='detail/export')
|
||||
def detail_export(self, request):
|
||||
return {}
|
||||
|
||||
|
||||
naming_collisions_router = SimpleRouter()
|
||||
naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
|
||||
class TestURLNamingCollisions(TestCase):
|
||||
"""
|
||||
Ref: https://github.com/encode/django-rest-framework/issues/4704
|
||||
"""
|
||||
@api_view(["GET"])
|
||||
def simple_fbv(request):
|
||||
pass
|
||||
|
||||
def test_manually_routing_nested_routes(self):
|
||||
patterns = [
|
||||
url(r'^test', simple_fbv),
|
||||
url(r'^test/list/', simple_fbv),
|
||||
url(r'^test', self.simple_fbv),
|
||||
url(r'^test/list/', self.simple_fbv),
|
||||
]
|
||||
|
||||
generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
|
||||
|
@ -1088,6 +786,15 @@ class TestURLNamingCollisions(TestCase):
|
|||
assert loc[key].url == url
|
||||
|
||||
def test_manually_routing_generic_view(self):
|
||||
class BasicModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
fields = "__all__"
|
||||
|
||||
class NamingCollisionView(generics.RetrieveUpdateDestroyAPIView):
|
||||
queryset = BasicModel.objects.all()
|
||||
serializer_class = BasicModelSerializer
|
||||
|
||||
patterns = [
|
||||
url(r'^test', NamingCollisionView.as_view()),
|
||||
url(r'^test/retrieve/', NamingCollisionView.as_view()),
|
||||
|
@ -1111,6 +818,23 @@ class TestURLNamingCollisions(TestCase):
|
|||
self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0'))
|
||||
|
||||
def test_from_router(self):
|
||||
class NamingCollisionViewSet(GenericViewSet):
|
||||
"""
|
||||
Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/
|
||||
"""
|
||||
permision_class = ()
|
||||
|
||||
@action(detail=False)
|
||||
def detail(self, request):
|
||||
return {}
|
||||
|
||||
@action(detail=False, url_path='detail/export')
|
||||
def detail_export(self, request):
|
||||
return {}
|
||||
|
||||
naming_collisions_router = SimpleRouter()
|
||||
naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision")
|
||||
|
||||
patterns = [
|
||||
url(r'from-router', include(naming_collisions_router.urls)),
|
||||
]
|
||||
|
@ -1143,6 +867,9 @@ class TestURLNamingCollisions(TestCase):
|
|||
assert schema == expected
|
||||
|
||||
def test_url_under_same_key_not_replaced(self):
|
||||
class BasicNamingCollisionView(generics.RetrieveAPIView):
|
||||
queryset = BasicModel.objects.all()
|
||||
|
||||
patterns = [
|
||||
url(r'example/(?P<pk>\d+)/$', BasicNamingCollisionView.as_view()),
|
||||
url(r'example/(?P<slug>\w+)/$', BasicNamingCollisionView.as_view()),
|
||||
|
@ -1157,8 +884,8 @@ class TestURLNamingCollisions(TestCase):
|
|||
def test_url_under_same_key_not_replaced_another(self):
|
||||
|
||||
patterns = [
|
||||
url(r'^test/list/', simple_fbv),
|
||||
url(r'^test/(?P<pk>\d+)/list/', simple_fbv),
|
||||
url(r'^test/list/', self.simple_fbv),
|
||||
url(r'^test/(?P<pk>\d+)/list/', self.simple_fbv),
|
||||
]
|
||||
|
||||
generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
|
112
tests/schemas/test_endpoint_enumerator.py
Normal file
112
tests/schemas/test_endpoint_enumerator.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import unittest
|
||||
|
||||
import pytest
|
||||
from django.conf.urls import url
|
||||
from django.test import TestCase
|
||||
|
||||
from rest_framework.compat import coreapi, get_regex_pattern
|
||||
from rest_framework.decorators import api_view, schema
|
||||
from rest_framework.schemas.generators import (
|
||||
EndpointEnumerator, SchemaGenerator
|
||||
)
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class EndpointExclusionTests(TestCase):
|
||||
class ExcludedAPIView(APIView):
|
||||
schema = None
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@api_view(['GET'])
|
||||
@schema(None)
|
||||
def excluded_fbv(request):
|
||||
pass
|
||||
|
||||
@api_view(['GET'])
|
||||
def included_fbv(request):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url('^excluded-cbv/$', self.ExcludedAPIView.as_view()),
|
||||
url('^excluded-fbv/$', self.excluded_fbv),
|
||||
url('^included-fbv/$', self.included_fbv),
|
||||
]
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
def test_schema_generator_excludes_correctly(self):
|
||||
"""Schema should not include excluded views"""
|
||||
generator = SchemaGenerator(title='Exclusions', patterns=self.patterns)
|
||||
schema = generator.get_schema()
|
||||
expected = coreapi.Document(
|
||||
url='',
|
||||
title='Exclusions',
|
||||
content={
|
||||
'included-fbv': {
|
||||
'list': coreapi.Link(url='/included-fbv/', action='get')
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert len(schema.data) == 1
|
||||
assert 'included-fbv' in schema.data
|
||||
assert schema == expected
|
||||
|
||||
def test_endpoint_enumerator_excludes_correctly(self):
|
||||
"""It is responsibility of EndpointEnumerator to exclude views"""
|
||||
inspector = EndpointEnumerator(self.patterns)
|
||||
endpoints = inspector.get_api_endpoints()
|
||||
|
||||
assert len(endpoints) == 1
|
||||
path, method, callback = endpoints[0]
|
||||
assert path == '/included-fbv/'
|
||||
|
||||
def test_should_include_endpoint_excludes_correctly(self):
|
||||
"""This is the specific method that should handle the exclusion"""
|
||||
inspector = EndpointEnumerator(self.patterns)
|
||||
|
||||
# Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
|
||||
pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback)
|
||||
for pattern in self.patterns]
|
||||
|
||||
should_include = [
|
||||
inspector.should_include_endpoint(*pair) for pair in pairs
|
||||
]
|
||||
|
||||
expected = [False, False, True]
|
||||
|
||||
assert should_include == expected
|
||||
|
||||
def test_deprecations(self):
|
||||
with pytest.warns(DeprecationWarning) as record:
|
||||
@api_view(["GET"], exclude_from_schema=True)
|
||||
def view(request):
|
||||
pass
|
||||
|
||||
assert len(record) == 1
|
||||
assert str(record[0].message) == (
|
||||
"The `exclude_from_schema` argument to `api_view` is deprecated. "
|
||||
"Use the `schema` decorator instead, passing `None`."
|
||||
)
|
||||
|
||||
class OldFashionedExcludedView(APIView):
|
||||
exclude_from_schema = True
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
patterns = [
|
||||
url('^excluded-old-fashioned/$', OldFashionedExcludedView.as_view()),
|
||||
]
|
||||
|
||||
inspector = EndpointEnumerator(patterns)
|
||||
with pytest.warns(DeprecationWarning) as record:
|
||||
inspector.get_api_endpoints()
|
||||
|
||||
assert len(record) == 1
|
||||
assert str(record[0].message) == (
|
||||
"The `OldFashionedExcludedView.exclude_from_schema` attribute is "
|
||||
"deprecated. Set `schema = None` instead."
|
||||
)
|
38
tests/schemas/test_view_inspector_descriptor.py
Normal file
38
tests/schemas/test_view_inspector_descriptor.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from rest_framework.schemas.inspectors import AutoSchema, ViewInspector
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class CustomViewInspector(ViewInspector):
|
||||
"""A dummy ViewInspector subclass"""
|
||||
pass
|
||||
|
||||
|
||||
class TestViewInspector(TestCase):
|
||||
"""
|
||||
Tests for the descriptor behaviour of ViewInspector
|
||||
(and subclasses.)
|
||||
"""
|
||||
def test_apiview_schema_descriptor(self):
|
||||
view = APIView()
|
||||
assert hasattr(view, 'schema')
|
||||
assert isinstance(view.schema, AutoSchema)
|
||||
|
||||
def test_set_custom_inspector_class_on_view(self):
|
||||
class CustomView(APIView):
|
||||
schema = CustomViewInspector()
|
||||
|
||||
view = CustomView()
|
||||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
def test_set_custom_inspector_class_via_settings(self):
|
||||
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_view_inspector_descriptor.CustomViewInspector'}):
|
||||
view = APIView()
|
||||
assert isinstance(view.schema, CustomViewInspector)
|
||||
|
||||
def test_get_link_requires_instance(self):
|
||||
descriptor = APIView.schema # Accessed from class
|
||||
with pytest.raises(AssertionError):
|
||||
descriptor.get_link(None, None, None)
|
139
tests/schemas/views.py
Normal file
139
tests/schemas/views.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
|
||||
from rest_framework import filters, pagination, permissions, serializers
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
|
||||
# Simple APIViews:
|
||||
class ExampleListView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# Classes for ExampleViewSet
|
||||
class ExamplePagination(pagination.PageNumberPagination):
|
||||
page_size = 100
|
||||
page_size_query_param = 'page_size'
|
||||
|
||||
|
||||
class EmptySerializer(serializers.Serializer):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleSerializer(serializers.Serializer):
|
||||
a = serializers.CharField(required=True, help_text='A field description')
|
||||
b = serializers.CharField(required=False)
|
||||
read_only = serializers.CharField(read_only=True)
|
||||
hidden = serializers.HiddenField(default='hello')
|
||||
|
||||
|
||||
class AnotherSerializerWithDictField(serializers.Serializer):
|
||||
a = serializers.DictField()
|
||||
|
||||
|
||||
class AnotherSerializerWithListFields(serializers.Serializer):
|
||||
a = serializers.ListField(child=serializers.IntegerField())
|
||||
b = serializers.ListSerializer(child=serializers.CharField())
|
||||
|
||||
|
||||
class AnotherSerializer(serializers.Serializer):
|
||||
c = serializers.CharField(required=True)
|
||||
d = serializers.CharField(required=False)
|
||||
|
||||
|
||||
class ExampleViewSet(ModelViewSet):
|
||||
pagination_class = ExamplePagination
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
serializer_class = ExampleSerializer
|
||||
|
||||
@action(methods=['post'], detail=True, serializer_class=AnotherSerializer)
|
||||
def custom_action(self, request, pk):
|
||||
"""
|
||||
A description of custom action.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField)
|
||||
def custom_action_with_dict_field(self, request, pk):
|
||||
"""
|
||||
A custom action using a dict field in the serializer.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields)
|
||||
def custom_action_with_list_fields(self, request, pk):
|
||||
"""
|
||||
A custom action using both list field and list serializer in the serializer.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(detail=False)
|
||||
def custom_list_action(self, request):
|
||||
raise NotImplementedError
|
||||
|
||||
@action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer)
|
||||
def custom_list_action_multiple_methods(self, request):
|
||||
"""Custom description."""
|
||||
raise NotImplementedError
|
||||
|
||||
@custom_list_action_multiple_methods.mapping.delete
|
||||
def custom_list_action_multiple_methods_delete(self, request):
|
||||
"""Deletion description."""
|
||||
raise NotImplementedError
|
||||
|
||||
@action(detail=False, schema=None)
|
||||
def excluded_action(self, request):
|
||||
pass
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
assert self.request
|
||||
assert self.action
|
||||
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
||||
|
||||
|
||||
# ExampleViewSet subclasses
|
||||
class DenyAllUsingHttp404(permissions.BasePermission):
|
||||
|
||||
def has_permission(self, request, view):
|
||||
raise Http404()
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
raise Http404()
|
||||
|
||||
|
||||
class DenyAllUsingPermissionDenied(permissions.BasePermission):
|
||||
|
||||
def has_permission(self, request, view):
|
||||
raise PermissionDenied()
|
||||
|
||||
def has_object_permission(self, request, view, obj):
|
||||
raise PermissionDenied()
|
||||
|
||||
|
||||
class Http404ExampleViewSet(ExampleViewSet):
|
||||
permission_classes = [DenyAllUsingHttp404]
|
||||
|
||||
|
||||
class PermissionDeniedExampleViewSet(ExampleViewSet):
|
||||
permission_classes = [DenyAllUsingPermissionDenied]
|
||||
|
||||
|
||||
class MethodLimitedViewSet(ExampleViewSet):
|
||||
permission_classes = []
|
||||
http_method_names = ['get', 'head', 'options']
|
Loading…
Reference in New Issue
Block a user