Refactor schema generation to allow per-view customisation (#5354)

* Initial Refactor Step

* Add descriptor class
* call from generator
* proxy back to generator for implementation.

* Move `get_link` to descriptor

* Move `get_description` to descriptor

* Remove need for generator in get_description

* Move get_path_fields to descriptor

* Move `get_serializer_fields` to descriptor

* Move `get_pagination_fields` to descriptor

* Move `get_filter_fields` to descriptor

* Move `get_encoding` to descriptor.

* Pass just `url` from SchemaGenerator to descriptor

* Make `view` a property

Encapsulates check for a view instance.

* Adjust API Reference docs

* Add `ManualSchema` class

* Refactor to `ViewInspector` plus `AutoSchema`

The interface then is **just** `get_link()`

* Add `manual_fields` kwarg to AutoSchema

* Add schema decorator for FBVs

* Adjust comments

* Docs: Provide full params in example

Ref feedback b52e372f8f (r137254795)

* Add docstring for ViewInstpector.__get__ descriptor method.

Ref https://github.com/encode/django-rest-framework/pull/5354#discussion_r137265022

* Make `schemas` a package.

* Split generators, inspectors, views.

* Adjust imports

* Rename to EndpointEnumerator

* Adjust ManualSchema to take `fields`

… and `description`.

Allows `url` and `action` to remain dynamic

* Add package/module docstrings
This commit is contained in:
Carlton Gibson 2017-09-14 10:46:34 +02:00 committed by Tom Christie
parent 5ea810d526
commit d54df8c438
12 changed files with 868 additions and 383 deletions

View File

@ -10,7 +10,14 @@ API schemas are a useful tool that allow for a range of use cases, including
generating reference documentation, or driving dynamic client libraries that generating reference documentation, or driving dynamic client libraries that
can interact with your API. can interact with your API.
## Representing schemas internally ## Install Core API
You'll need to install the `coreapi` package in order to add schema support
for REST framework.
pip install coreapi
## Internal schema representation
REST framework uses [Core API][coreapi] in order to model schema information in REST framework uses [Core API][coreapi] in order to model schema information in
a format-independent representation. This information can then be rendered a format-independent representation. This information can then be rendered
@ -68,9 +75,34 @@ has to be rendered into the actual bytes that are used in the response.
REST framework includes a renderer class for handling this media type, which REST framework includes a renderer class for handling this media type, which
is available as `renderers.CoreJSONRenderer`. is available as `renderers.CoreJSONRenderer`.
### Alternate schema formats
Other schema formats such as [Open API][open-api] ("Swagger"), Other schema formats such as [Open API][open-api] ("Swagger"),
[JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can [JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can also
also be supported by implementing a custom renderer class. be supported by implementing a custom renderer class that handles converting a
`Document` instance into a bytestring representation.
If there is a Core API codec package that supports encoding into the format you
want to use then implementing the renderer class can be done by using the codec.
#### Example
For example, the `openapi_codec` package provides support for encoding or decoding
to the Open API ("Swagger") format:
from rest_framework import renderers
from openapi_codec import OpenAPICodec
class SwaggerRenderer(renderers.BaseRenderer):
media_type = 'application/openapi+json'
format = 'swagger'
def render(self, data, media_type=None, renderer_context=None):
codec = OpenAPICodec()
return codec.dump(data)
## Schemas vs Hypermedia ## Schemas vs Hypermedia
@ -89,18 +121,121 @@ document, detailing both the current state and the available interactions.
Further information and support on building Hypermedia APIs with REST framework Further information and support on building Hypermedia APIs with REST framework
is planned for a future version. is planned for a future version.
--- ---
# Adding a schema # Creating a schema
You'll need to install the `coreapi` package in order to add schema support
for REST framework.
pip install coreapi
REST framework includes functionality for auto-generating a schema, REST framework includes functionality for auto-generating a schema,
or allows you to specify one explicitly. There are a few different ways to or allows you to specify one explicitly.
add a schema to your API, depending on exactly what you need.
## Manual Schema Specification
To manually specify a schema you create a Core API `Document`, similar to the
example above.
schema = coreapi.Document(
title='Flight Search API',
content={
...
}
)
## Automatic Schema Generation
Automatic schema generation is provided by the `SchemaGenerator` class.
`SchemaGenerator` processes a list of routed URL pattterns and compiles the
appropriately structured Core API Document.
Basic usage is just to provide the title for your schema and call
`get_schema()`:
generator = schemas.SchemaGenerator(title='Flight Search API')
schema = generator.get_schema()
### Per-View Schema Customisation
By default, view introspection is performed by an `AutoSchema` instance
accessible via the `schema` attribute on `APIView`. This provides the
appropriate Core API `Link` object for the view, request method and path:
auto_schema = view.schema
coreapi_link = auto_schema.get_link(...)
(In compiling the schema, `SchemaGenerator` calls `view.schema.get_link()` for
each view, allowed method and path.)
To customise the `Link` generation you may:
* Instantiate `AutoSchema` on your view with the `manual_fields` kwarg:
from rest_framework.views import APIView
from rest_framework.schemas import AutoSchema
class CustomView(APIView):
...
schema = AutoSchema(
manual_fields=[
coreapi.Field("extra_field", ...),
]
)
This allows extension for the most common case without subclassing.
* Provide an `AutoSchema` subclass with more complex customisation:
from rest_framework.views import APIView
from rest_framework.schemas import AutoSchema
class CustomSchema(AutoSchema):
def get_link(...):
# Implemet custom introspection here (or in other sub-methods)
class CustomView(APIView):
...
schema = CustomSchema()
This provides complete control over view introspection.
* Instantiate `ManualSchema` on your view, providing the Core API `Fields` for
the view explicitly:
from rest_framework.views import APIView
from rest_framework.schemas import ManualSchema
class CustomView(APIView):
...
schema = ManualSchema(fields=[
coreapi.Field(
"first_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"second_field",
required=True,
location="path",
schema=coreschema.String()
),
])
This allows manually specifying the schema for some views whilst maintaining
automatic generation elsewhere.
---
**Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and
`ManualSchema` descriptors see the [API Reference below](#api-reference).
---
# Adding a schema view
There are a few different ways to add a schema view to your API, depending on
exactly what you need.
## The get_schema_view shortcut ## The get_schema_view shortcut
@ -342,38 +477,12 @@ A generic viewset with sections in the class docstring, using multi-line style.
--- ---
# Alternate schema formats
In order to support an alternate schema format, you need to implement a custom renderer
class that handles converting a `Document` instance into a bytestring representation.
If there is a Core API codec package that supports encoding into the format you
want to use then implementing the renderer class can be done by using the codec.
## Example
For example, the `openapi_codec` package provides support for encoding or decoding
to the Open API ("Swagger") format:
from rest_framework import renderers
from openapi_codec import OpenAPICodec
class SwaggerRenderer(renderers.BaseRenderer):
media_type = 'application/openapi+json'
format = 'swagger'
def render(self, data, media_type=None, renderer_context=None):
codec = OpenAPICodec()
return codec.dump(data)
---
# API Reference # API Reference
## SchemaGenerator ## SchemaGenerator
A class that deals with introspecting your API views, which can be used to A class that walks a list of routed URL patterns, requests the schema for each view,
generate a schema. and collates the resulting CoreAPI Document.
Typically you'll instantiate `SchemaGenerator` with a single argument, like so: Typically you'll instantiate `SchemaGenerator` with a single argument, like so:
@ -406,39 +515,108 @@ Return a nested dictionary containing all the links that should be included in t
This is a good point to override if you want to modify the resulting structure of the generated schema, This is a good point to override if you want to modify the resulting structure of the generated schema,
as you can build a new dictionary with a different layout. as you can build a new dictionary with a different layout.
### get_link(self, path, method, view)
## AutoSchema
A class that deals with introspection of individual views for schema generation.
`AutoSchema` is attached to `APIView` via the `schema` attribute.
The `AutoSchema` constructor takes a single keyword argument `manual_fields`.
**`manual_fields`**: a `list` of `coreapi.Field` instances that will be added to
the generated fields. Generated fields with a matching `name` will be overwritten.
class CustomView(APIView):
schema = AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
])
For more advanced customisation subclass `AutoSchema` to customise schema generation.
class CustomViewSchema(AutoSchema):
"""
Overrides `get_link()` to provide Custom Behavior X
"""
def get_link(self, path, method, base_url):
link = super().get_link(path, method, base_url)
# Do something to customize link here...
return link
class MyView(APIView):
schema = CustomViewSchema()
The following methods are available to override.
### get_link(self, path, method, base_url)
Returns a `coreapi.Link` instance corresponding to the given view. Returns a `coreapi.Link` instance corresponding to the given view.
This is the main entry point.
You can override this if you need to provide custom behaviors for particular views. You can override this if you need to provide custom behaviors for particular views.
### get_description(self, path, method, view) ### get_description(self, path, method)
Returns a string to use as the link description. By default this is based on the Returns a string to use as the link description. By default this is based on the
view docstring as described in the "Schemas as Documentation" section above. view docstring as described in the "Schemas as Documentation" section above.
### get_encoding(self, path, method, view) ### get_encoding(self, path, method)
Returns a string to indicate the encoding for any request body, when interacting Returns a string to indicate the encoding for any request body, when interacting
with the given view. Eg. `'application/json'`. May return a blank string for views with the given view. Eg. `'application/json'`. May return a blank string for views
that do not expect a request body. that do not expect a request body.
### get_path_fields(self, path, method, view): ### get_path_fields(self, path, method):
Return a list of `coreapi.Link()` instances. One for each path parameter in the URL. Return a list of `coreapi.Link()` instances. One for each path parameter in the URL.
### get_serializer_fields(self, path, method, view) ### get_serializer_fields(self, path, method)
Return a list of `coreapi.Link()` instances. One for each field in the serializer class used by the view. Return a list of `coreapi.Link()` instances. One for each field in the serializer class used by the view.
### get_pagination_fields(self, path, method, view ### get_pagination_fields(self, path, method)
Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method on any pagination class used by the view. Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method on any pagination class used by the view.
### get_filter_fields(self, path, method, view) ### get_filter_fields(self, path, method)
Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view. Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view.
## ManualSchema
Allows manually providing a list of `coreapi.Field` instances for the schema,
plus an optional description.
class MyView(APIView):
schema = ManualSchema(fields=[
coreapi.Field(
"first_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"second_field",
required=True,
location="path",
schema=coreschema.String()
),
]
)
The `ManualSchema` constructor takes two arguments:
**`fields`**: A list of `coreapi.Field` instances. Required.
**`description`**: A string description. Optional.
--- ---
## Core API ## Core API

View File

@ -184,6 +184,28 @@ The available decorators are:
Each of these decorators takes a single argument which must be a list or tuple of classes. Each of these decorators takes a single argument which must be a list or tuple of classes.
## View schema decorator
To override the default schema generation for function based views you may use
the `@schema` decorator. This must come *after* (below) the `@api_view`
decorator. For example:
from rest_framework.decorators import api_view, schema
from rest_framework.schemas import AutoSchema
class CustomAutoSchema(AutoSchema):
def get_link(self, path, method, base_url):
# override view introspection here...
@api_view(['GET'])
@schema(CustomAutoSchema())
def view(request):
return Response({"message": "Hello for today! See you tomorrow!"})
This decorator takes a single `AutoSchema` instance, an `AutoSchema` subclass
instance or `ManualSchema` instance as described in the [Schemas documentation][schemas],
[cite]: http://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html [cite]: http://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
[settings]: settings.md [settings]: settings.md

View File

@ -72,6 +72,9 @@ def api_view(http_method_names=None, exclude_from_schema=False):
WrappedAPIView.permission_classes = getattr(func, 'permission_classes', WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
APIView.permission_classes) APIView.permission_classes)
WrappedAPIView.schema = getattr(func, 'schema',
APIView.schema)
WrappedAPIView.exclude_from_schema = exclude_from_schema WrappedAPIView.exclude_from_schema = exclude_from_schema
return WrappedAPIView.as_view() return WrappedAPIView.as_view()
return decorator return decorator
@ -112,6 +115,13 @@ def permission_classes(permission_classes):
return decorator return decorator
def schema(view_inspector):
def decorator(func):
func.schema = view_inspector
return func
return decorator
def detail_route(methods=None, **kwargs): def detail_route(methods=None, **kwargs):
""" """
Used to mark a method on a ViewSet that should be routed for detail requests. Used to mark a method on a ViewSet that should be routed for detail requests.

View File

@ -26,7 +26,8 @@ from rest_framework import views
from rest_framework.compat import NoReverseMatch from rest_framework.compat import NoReverseMatch
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator, SchemaView from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns

View File

@ -0,0 +1,43 @@
"""
rest_framework.schemas
schemas:
__init__.py
generators.py # Top-down schema generation
inspectors.py # Per-endpoint view introspection
utils.py # Shared helper functions
views.py # Houses `SchemaView`, `APIView` subclass.
We expose a minimal "public" API directly from `schemas`. This covers the
basic use-cases:
from rest_framework.schemas import (
AutoSchema,
ManualSchema,
get_schema_view,
SchemaGenerator,
)
Other access should target the submodules directly
"""
from .generators import SchemaGenerator
from .inspectors import AutoSchema, ManualSchema # noqa
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=SchemaGenerator):
"""
Return a schema view.
"""
# Avoid import cycle on APIView
from .views import SchemaView
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
)
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,
public=public,
)

View File

@ -1,86 +1,26 @@
import re """
generators.py # Top-down schema generation
See schemas.__init__.py for package overview.
"""
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module from importlib import import_module
from django.conf import settings from django.conf import settings
from django.contrib.admindocs.views import simplify_regex from django.contrib.admindocs.views import simplify_regex
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import models
from django.http import Http404 from django.http import Http404
from django.utils import six from django.utils import six
from django.utils.encoding import force_text, smart_text
from django.utils.translation import ugettext_lazy as _
from rest_framework import exceptions, renderers, serializers from rest_framework import exceptions
from rest_framework.compat import ( from rest_framework.compat import (
RegexURLPattern, RegexURLResolver, coreapi, coreschema, uritemplate, RegexURLPattern, RegexURLResolver, coreapi, coreschema
urlparse
) )
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from rest_framework.utils.model_meta import _get_pk from rest_framework.utils.model_meta import _get_pk
from rest_framework.views import APIView
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') from .utils import is_list_view
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
return coreschema.Array(
items=coreschema.String(),
title=title,
description=description
)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices.keys())),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices.keys()),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
def common_path(paths): def common_path(paths):
@ -104,6 +44,8 @@ def is_api_view(callback):
""" """
Return `True` if the given view callback is a REST framework view/viewset. Return `True` if the given view callback is a REST framework view/viewset.
""" """
# Avoid import cycle on APIView
from rest_framework.views import APIView
cls = getattr(callback, 'cls', None) cls = getattr(callback, 'cls', None)
return (cls is not None) and issubclass(cls, APIView) return (cls is not None) and issubclass(cls, APIView)
@ -130,22 +72,6 @@ def is_custom_action(action):
]) ])
def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
if method.lower() != 'get':
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
def endpoint_ordering(endpoint): def endpoint_ordering(endpoint):
path, method, callback = endpoint path, method, callback = endpoint
method_priority = { method_priority = {
@ -158,21 +84,7 @@ def endpoint_ordering(endpoint):
return (path, method_priority) return (path, method_priority)
def get_pk_description(model, model_field): class EndpointEnumerator(object):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)
class EndpointInspector(object):
""" """
A class to determine the available API endpoints that a project exposes. A class to determine the available API endpoints that a project exposes.
""" """
@ -265,7 +177,7 @@ class SchemaGenerator(object):
'patch': 'partial_update', 'patch': 'partial_update',
'delete': 'destroy', 'delete': 'destroy',
} }
endpoint_inspector_cls = EndpointInspector endpoint_inspector_cls = EndpointEnumerator
# Map the method names we use for viewset actions onto external schema names. # Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation. # These give us names that are more suitable for the external representation.
@ -341,7 +253,7 @@ class SchemaGenerator(object):
for path, method, view in view_endpoints: for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view): if not self.has_view_permissions(path, method, view):
continue continue
link = self.get_link(path, method, view) link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):] subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view) keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link) insert_into(links, keys, link)
@ -433,197 +345,6 @@ class SchemaGenerator(object):
field_name = 'id' field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name) return path.replace('{pk}', '{%s}' % field_name)
# Methods for generating each individual `Link` instance...
def get_link(self, path, method, view):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
fields = self.get_path_fields(path, method, view)
fields += self.get_serializer_fields(path, method, view)
fields += self.get_pagination_fields(path, method, view)
fields += self.get_filter_fields(path, method, view)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method, view)
else:
encoding = None
description = self.get_description(path, method, view)
if self.url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=urlparse.urljoin(self.url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method, view):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return formatting.dedent(smart_text(method_docstring))
description = view.get_view_description()
lines = [line.strip() for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
header = getattr(view, 'action', method.lower())
if header in sections:
return sections[header].strip()
if header in self.coerce_method_names:
if self.coerce_method_names[header] in sections:
return sections[self.coerce_method_names[header]].strip()
return sections[''].strip()
def get_encoding(self, path, method, view):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
# Core API supports the following request encodings over HTTP...
supported_media_types = set((
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
))
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
def get_path_fields(self, path, method, view):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method, view):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method, view):
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def get_filter_fields(self, path, method, view):
if not is_list_view(path, method, view):
return []
if not getattr(view, 'filter_backends', None):
return []
fields = []
for filter_backend in view.filter_backends:
fields += filter_backend().get_schema_fields(view)
return fields
# Method for generating the link layout.... # Method for generating the link layout....
def get_keys(self, subpath, method, view): def get_keys(self, subpath, method, view):
@ -669,45 +390,3 @@ class SchemaGenerator(object):
# Default action, eg "/users/", "/users/{pk}/" # Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action] return named_path_components + [action]
class SchemaView(APIView):
_ignore_model_permissions = True
exclude_from_schema = True
renderer_classes = None
schema_generator = None
public = False
def __init__(self, *args, **kwargs):
super(SchemaView, self).__init__(*args, **kwargs)
if self.renderer_classes is None:
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes = [
renderers.CoreJSONRenderer,
renderers.BrowsableAPIRenderer,
]
else:
self.renderer_classes = [renderers.CoreJSONRenderer]
def get(self, request, *args, **kwargs):
schema = self.schema_generator.get_schema(request, self.public)
if schema is None:
raise exceptions.PermissionDenied()
return Response(schema)
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=SchemaGenerator):
"""
Return a schema view.
"""
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
)
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,
public=public,
)

View File

@ -0,0 +1,399 @@
"""
inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview.
"""
import re
from collections import OrderedDict
from django.db import models
from django.utils.encoding import force_text, smart_text
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from rest_framework.compat import coreapi, coreschema, uritemplate, urlparse
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .utils import is_list_view
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
return coreschema.Array(
items=coreschema.String(),
title=title,
description=description
)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices.keys())),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices.keys()),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)
class ViewInspector(object):
"""
Descriptor class on APIView.
Provide subclass for per-view schema generation
"""
def __get__(self, instance, owner):
"""
Enables `ViewInspector` as a Python _Descriptor_.
This is how `view.schema` knows about `view`.
`__get__` is called when the descriptor is accessed on the owner.
(That will be when view.schema is called in our case.)
`owner` is always the owner class. (An APIView, or subclass for us.)
`instance` is the view instance or `None` if accessed from the class,
rather than an instance.
See: https://docs.python.org/3/howto/descriptor.html for info on
descriptor usage.
"""
self.view = instance
return self
@property
def view(self):
"""View property."""
assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)"
return self._view
@view.setter
def view(self, value):
self._view = value
@view.deleter
def view(self):
self._view = None
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
raise NotImplementedError(".get_link() must be overridden.")
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view instrospection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
if self._manual_fields is not None:
by_name = {f.name: f for f in fields}
for f in self._manual_fields:
by_name[f.name] = f
fields = list(by_name.values())
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=urlparse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return formatting.dedent(smart_text(method_docstring))
description = view.get_view_description()
lines = [line.strip() for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
header = getattr(view, 'action', method.lower())
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def get_filter_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
if not getattr(view, 'filter_backends', None):
return []
fields = []
for filter_backend in view.filter_backends:
fields += filter_backend().get_schema_fields(view)
return fields
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = set((
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
))
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description=''):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `descripton`: String description for view. Optional.
"""
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=urlparse.urljoin(base_url, path),
action=method.lower(),
encoding=None,
fields=self._fields,
description=self._description
)
return self._link

View File

@ -0,0 +1,21 @@
"""
utils.py # Shared helper functions
See schemas.__init__.py for package overview.
"""
def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
if method.lower() != 'get':
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True

View File

@ -0,0 +1,34 @@
"""
views.py # Houses `SchemaView`, `APIView` subclass.
See schemas.__init__.py for package overview.
"""
from rest_framework import exceptions, renderers
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
class SchemaView(APIView):
_ignore_model_permissions = True
exclude_from_schema = True
renderer_classes = None
schema_generator = None
public = False
def __init__(self, *args, **kwargs):
super(SchemaView, self).__init__(*args, **kwargs)
if self.renderer_classes is None:
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes = [
renderers.CoreJSONRenderer,
renderers.BrowsableAPIRenderer,
]
else:
self.renderer_classes = [renderers.CoreJSONRenderer]
def get(self, request, *args, **kwargs):
schema = self.schema_generator.get_schema(request, self.public)
if schema is None:
raise exceptions.PermissionDenied()
return Response(schema)

View File

@ -19,6 +19,7 @@ from rest_framework import exceptions, status
from rest_framework.compat import set_rollback from rest_framework.compat import set_rollback
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.schemas import AutoSchema
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting from rest_framework.utils import formatting
@ -113,6 +114,7 @@ class APIView(View):
# Mark the view as being included or excluded from schema generation. # Mark the view as being included or excluded from schema generation.
exclude_from_schema = False exclude_from_schema = False
schema = AutoSchema()
@classmethod @classmethod
def as_view(cls, **initkwargs): def as_view(cls, **initkwargs):

View File

@ -6,12 +6,13 @@ from rest_framework import status
from rest_framework.authentication import BasicAuthentication from rest_framework.authentication import BasicAuthentication
from rest_framework.decorators import ( from rest_framework.decorators import (
api_view, authentication_classes, parser_classes, permission_classes, api_view, authentication_classes, parser_classes, permission_classes,
renderer_classes, throttle_classes renderer_classes, schema, throttle_classes
) )
from rest_framework.parsers import JSONParser from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.schemas import AutoSchema
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView from rest_framework.views import APIView
@ -151,3 +152,17 @@ class DecoratorTestCase(TestCase):
response = view(request) response = view(request)
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
def test_schema(self):
"""
Checks CustomSchema class is set on view
"""
class CustomSchema(AutoSchema):
pass
@api_view(['GET'])
@schema(CustomSchema())
def view(request):
return Response({})
assert isinstance(view.cls.schema, CustomSchema)

View File

@ -1,5 +1,6 @@
import unittest import unittest
import pytest
from django.conf.urls import include, url from django.conf.urls import include, url
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
@ -10,7 +11,9 @@ from rest_framework.compat import coreapi, coreschema
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator, get_schema_view from rest_framework.schemas import (
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
)
from rest_framework.test import APIClient, APIRequestFactory from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
@ -496,3 +499,81 @@ class Test4605Regression(TestCase):
'/auth/convert-token/' '/auth/convert-token/'
]) ])
assert prefix == '/' assert prefix == '/'
class TestDescriptor(TestCase):
def test_apiview_schema_descriptor(self):
view = APIView()
assert hasattr(view, 'schema')
assert isinstance(view.schema, AutoSchema)
def test_get_link_requires_instance(self):
descriptor = APIView.schema # Accessed from class
with pytest.raises(AssertionError):
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
def test_manual_fields(self):
class CustomView(APIView):
schema = AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
])
view = CustomView()
link = view.schema.get_link('/a/url/{id}/', 'GET', '')
fields = link.fields
assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]
def test_view_with_manual_schema(self):
path = '/example'
method = 'get'
base_url = None
fields = [
coreapi.Field(
"first_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"second_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"third_field",
required=True,
location="path",
schema=coreschema.String()
),
]
description = "A test endpoint"
class CustomView(APIView):
"""
ManualSchema takes list of fields for endpoint.
- Provides url and action, which are always dynamic
"""
schema = ManualSchema(fields, description)
expected = coreapi.Link(
url=path,
action=method,
fields=fields,
description=description
)
view = CustomView()
link = view.schema.get_link(path, method, base_url)
assert link == expected