mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-03 05:04:31 +03:00
Refactor schema generation to allow per-view customisation (#5354)
* Initial Refactor Step
* Add descriptor class
* call from generator
* proxy back to generator for implementation.
* Move `get_link` to descriptor
* Move `get_description` to descriptor
* Remove need for generator in get_description
* Move get_path_fields to descriptor
* Move `get_serializer_fields` to descriptor
* Move `get_pagination_fields` to descriptor
* Move `get_filter_fields` to descriptor
* Move `get_encoding` to descriptor.
* Pass just `url` from SchemaGenerator to descriptor
* Make `view` a property
Encapsulates check for a view instance.
* Adjust API Reference docs
* Add `ManualSchema` class
* Refactor to `ViewInspector` plus `AutoSchema`
The interface then is **just** `get_link()`
* Add `manual_fields` kwarg to AutoSchema
* Add schema decorator for FBVs
* Adjust comments
* Docs: Provide full params in example
Ref feedback b52e372f8f (r137254795)
* Add docstring for ViewInstpector.__get__ descriptor method.
Ref https://github.com/encode/django-rest-framework/pull/5354#discussion_r137265022
* Make `schemas` a package.
* Split generators, inspectors, views.
* Adjust imports
* Rename to EndpointEnumerator
* Adjust ManualSchema to take `fields`
… and `description`.
Allows `url` and `action` to remain dynamic
* Add package/module docstrings
This commit is contained in:
parent
5ea810d526
commit
d54df8c438
|
@ -10,7 +10,14 @@ API schemas are a useful tool that allow for a range of use cases, including
|
|||
generating reference documentation, or driving dynamic client libraries that
|
||||
can interact with your API.
|
||||
|
||||
## Representing schemas internally
|
||||
## Install Core API
|
||||
|
||||
You'll need to install the `coreapi` package in order to add schema support
|
||||
for REST framework.
|
||||
|
||||
pip install coreapi
|
||||
|
||||
## Internal schema representation
|
||||
|
||||
REST framework uses [Core API][coreapi] in order to model schema information in
|
||||
a format-independent representation. This information can then be rendered
|
||||
|
@ -68,9 +75,34 @@ has to be rendered into the actual bytes that are used in the response.
|
|||
REST framework includes a renderer class for handling this media type, which
|
||||
is available as `renderers.CoreJSONRenderer`.
|
||||
|
||||
### Alternate schema formats
|
||||
|
||||
Other schema formats such as [Open API][open-api] ("Swagger"),
|
||||
[JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can
|
||||
also be supported by implementing a custom renderer class.
|
||||
[JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can also
|
||||
be supported by implementing a custom renderer class that handles converting a
|
||||
`Document` instance into a bytestring representation.
|
||||
|
||||
If there is a Core API codec package that supports encoding into the format you
|
||||
want to use then implementing the renderer class can be done by using the codec.
|
||||
|
||||
#### Example
|
||||
|
||||
For example, the `openapi_codec` package provides support for encoding or decoding
|
||||
to the Open API ("Swagger") format:
|
||||
|
||||
from rest_framework import renderers
|
||||
from openapi_codec import OpenAPICodec
|
||||
|
||||
class SwaggerRenderer(renderers.BaseRenderer):
|
||||
media_type = 'application/openapi+json'
|
||||
format = 'swagger'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
codec = OpenAPICodec()
|
||||
return codec.dump(data)
|
||||
|
||||
|
||||
|
||||
|
||||
## Schemas vs Hypermedia
|
||||
|
||||
|
@ -89,18 +121,121 @@ document, detailing both the current state and the available interactions.
|
|||
Further information and support on building Hypermedia APIs with REST framework
|
||||
is planned for a future version.
|
||||
|
||||
|
||||
---
|
||||
|
||||
# Adding a schema
|
||||
|
||||
You'll need to install the `coreapi` package in order to add schema support
|
||||
for REST framework.
|
||||
|
||||
pip install coreapi
|
||||
# Creating a schema
|
||||
|
||||
REST framework includes functionality for auto-generating a schema,
|
||||
or allows you to specify one explicitly. There are a few different ways to
|
||||
add a schema to your API, depending on exactly what you need.
|
||||
or allows you to specify one explicitly.
|
||||
|
||||
## Manual Schema Specification
|
||||
|
||||
To manually specify a schema you create a Core API `Document`, similar to the
|
||||
example above.
|
||||
|
||||
schema = coreapi.Document(
|
||||
title='Flight Search API',
|
||||
content={
|
||||
...
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
## Automatic Schema Generation
|
||||
|
||||
Automatic schema generation is provided by the `SchemaGenerator` class.
|
||||
|
||||
`SchemaGenerator` processes a list of routed URL pattterns and compiles the
|
||||
appropriately structured Core API Document.
|
||||
|
||||
Basic usage is just to provide the title for your schema and call
|
||||
`get_schema()`:
|
||||
|
||||
generator = schemas.SchemaGenerator(title='Flight Search API')
|
||||
schema = generator.get_schema()
|
||||
|
||||
### Per-View Schema Customisation
|
||||
|
||||
By default, view introspection is performed by an `AutoSchema` instance
|
||||
accessible via the `schema` attribute on `APIView`. This provides the
|
||||
appropriate Core API `Link` object for the view, request method and path:
|
||||
|
||||
auto_schema = view.schema
|
||||
coreapi_link = auto_schema.get_link(...)
|
||||
|
||||
(In compiling the schema, `SchemaGenerator` calls `view.schema.get_link()` for
|
||||
each view, allowed method and path.)
|
||||
|
||||
To customise the `Link` generation you may:
|
||||
|
||||
* Instantiate `AutoSchema` on your view with the `manual_fields` kwarg:
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.schemas import AutoSchema
|
||||
|
||||
class CustomView(APIView):
|
||||
...
|
||||
schema = AutoSchema(
|
||||
manual_fields=[
|
||||
coreapi.Field("extra_field", ...),
|
||||
]
|
||||
)
|
||||
|
||||
This allows extension for the most common case without subclassing.
|
||||
|
||||
* Provide an `AutoSchema` subclass with more complex customisation:
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.schemas import AutoSchema
|
||||
|
||||
class CustomSchema(AutoSchema):
|
||||
def get_link(...):
|
||||
# Implemet custom introspection here (or in other sub-methods)
|
||||
|
||||
class CustomView(APIView):
|
||||
...
|
||||
schema = CustomSchema()
|
||||
|
||||
This provides complete control over view introspection.
|
||||
|
||||
* Instantiate `ManualSchema` on your view, providing the Core API `Fields` for
|
||||
the view explicitly:
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.schemas import ManualSchema
|
||||
|
||||
class CustomView(APIView):
|
||||
...
|
||||
schema = ManualSchema(fields=[
|
||||
coreapi.Field(
|
||||
"first_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
coreapi.Field(
|
||||
"second_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
])
|
||||
|
||||
This allows manually specifying the schema for some views whilst maintaining
|
||||
automatic generation elsewhere.
|
||||
|
||||
---
|
||||
|
||||
**Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and
|
||||
`ManualSchema` descriptors see the [API Reference below](#api-reference).
|
||||
|
||||
---
|
||||
|
||||
# Adding a schema view
|
||||
|
||||
There are a few different ways to add a schema view to your API, depending on
|
||||
exactly what you need.
|
||||
|
||||
## The get_schema_view shortcut
|
||||
|
||||
|
@ -342,38 +477,12 @@ A generic viewset with sections in the class docstring, using multi-line style.
|
|||
|
||||
---
|
||||
|
||||
# Alternate schema formats
|
||||
|
||||
In order to support an alternate schema format, you need to implement a custom renderer
|
||||
class that handles converting a `Document` instance into a bytestring representation.
|
||||
|
||||
If there is a Core API codec package that supports encoding into the format you
|
||||
want to use then implementing the renderer class can be done by using the codec.
|
||||
|
||||
## Example
|
||||
|
||||
For example, the `openapi_codec` package provides support for encoding or decoding
|
||||
to the Open API ("Swagger") format:
|
||||
|
||||
from rest_framework import renderers
|
||||
from openapi_codec import OpenAPICodec
|
||||
|
||||
class SwaggerRenderer(renderers.BaseRenderer):
|
||||
media_type = 'application/openapi+json'
|
||||
format = 'swagger'
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
codec = OpenAPICodec()
|
||||
return codec.dump(data)
|
||||
|
||||
---
|
||||
|
||||
# API Reference
|
||||
|
||||
## SchemaGenerator
|
||||
|
||||
A class that deals with introspecting your API views, which can be used to
|
||||
generate a schema.
|
||||
A class that walks a list of routed URL patterns, requests the schema for each view,
|
||||
and collates the resulting CoreAPI Document.
|
||||
|
||||
Typically you'll instantiate `SchemaGenerator` with a single argument, like so:
|
||||
|
||||
|
@ -406,39 +515,108 @@ Return a nested dictionary containing all the links that should be included in t
|
|||
This is a good point to override if you want to modify the resulting structure of the generated schema,
|
||||
as you can build a new dictionary with a different layout.
|
||||
|
||||
### get_link(self, path, method, view)
|
||||
|
||||
## AutoSchema
|
||||
|
||||
A class that deals with introspection of individual views for schema generation.
|
||||
|
||||
`AutoSchema` is attached to `APIView` via the `schema` attribute.
|
||||
|
||||
The `AutoSchema` constructor takes a single keyword argument `manual_fields`.
|
||||
|
||||
**`manual_fields`**: a `list` of `coreapi.Field` instances that will be added to
|
||||
the generated fields. Generated fields with a matching `name` will be overwritten.
|
||||
|
||||
class CustomView(APIView):
|
||||
schema = AutoSchema(manual_fields=[
|
||||
coreapi.Field(
|
||||
"my_extra_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
])
|
||||
|
||||
For more advanced customisation subclass `AutoSchema` to customise schema generation.
|
||||
|
||||
class CustomViewSchema(AutoSchema):
|
||||
"""
|
||||
Overrides `get_link()` to provide Custom Behavior X
|
||||
"""
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
link = super().get_link(path, method, base_url)
|
||||
# Do something to customize link here...
|
||||
return link
|
||||
|
||||
class MyView(APIView):
|
||||
schema = CustomViewSchema()
|
||||
|
||||
The following methods are available to override.
|
||||
|
||||
### get_link(self, path, method, base_url)
|
||||
|
||||
Returns a `coreapi.Link` instance corresponding to the given view.
|
||||
|
||||
This is the main entry point.
|
||||
You can override this if you need to provide custom behaviors for particular views.
|
||||
|
||||
### get_description(self, path, method, view)
|
||||
### get_description(self, path, method)
|
||||
|
||||
Returns a string to use as the link description. By default this is based on the
|
||||
view docstring as described in the "Schemas as Documentation" section above.
|
||||
|
||||
### get_encoding(self, path, method, view)
|
||||
### get_encoding(self, path, method)
|
||||
|
||||
Returns a string to indicate the encoding for any request body, when interacting
|
||||
with the given view. Eg. `'application/json'`. May return a blank string for views
|
||||
that do not expect a request body.
|
||||
|
||||
### get_path_fields(self, path, method, view):
|
||||
### get_path_fields(self, path, method):
|
||||
|
||||
Return a list of `coreapi.Link()` instances. One for each path parameter in the URL.
|
||||
|
||||
### get_serializer_fields(self, path, method, view)
|
||||
### get_serializer_fields(self, path, method)
|
||||
|
||||
Return a list of `coreapi.Link()` instances. One for each field in the serializer class used by the view.
|
||||
|
||||
### get_pagination_fields(self, path, method, view
|
||||
### get_pagination_fields(self, path, method)
|
||||
|
||||
Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method on any pagination class used by the view.
|
||||
|
||||
### get_filter_fields(self, path, method, view)
|
||||
### get_filter_fields(self, path, method)
|
||||
|
||||
Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view.
|
||||
|
||||
|
||||
## ManualSchema
|
||||
|
||||
Allows manually providing a list of `coreapi.Field` instances for the schema,
|
||||
plus an optional description.
|
||||
|
||||
class MyView(APIView):
|
||||
schema = ManualSchema(fields=[
|
||||
coreapi.Field(
|
||||
"first_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
coreapi.Field(
|
||||
"second_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
The `ManualSchema` constructor takes two arguments:
|
||||
|
||||
**`fields`**: A list of `coreapi.Field` instances. Required.
|
||||
|
||||
**`description`**: A string description. Optional.
|
||||
|
||||
---
|
||||
|
||||
## Core API
|
||||
|
|
|
@ -184,6 +184,28 @@ The available decorators are:
|
|||
|
||||
Each of these decorators takes a single argument which must be a list or tuple of classes.
|
||||
|
||||
|
||||
## View schema decorator
|
||||
|
||||
To override the default schema generation for function based views you may use
|
||||
the `@schema` decorator. This must come *after* (below) the `@api_view`
|
||||
decorator. For example:
|
||||
|
||||
from rest_framework.decorators import api_view, schema
|
||||
from rest_framework.schemas import AutoSchema
|
||||
|
||||
class CustomAutoSchema(AutoSchema):
|
||||
def get_link(self, path, method, base_url):
|
||||
# override view introspection here...
|
||||
|
||||
@api_view(['GET'])
|
||||
@schema(CustomAutoSchema())
|
||||
def view(request):
|
||||
return Response({"message": "Hello for today! See you tomorrow!"})
|
||||
|
||||
This decorator takes a single `AutoSchema` instance, an `AutoSchema` subclass
|
||||
instance or `ManualSchema` instance as described in the [Schemas documentation][schemas],
|
||||
|
||||
[cite]: http://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
|
||||
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
|
||||
[settings]: settings.md
|
||||
|
|
|
@ -72,6 +72,9 @@ def api_view(http_method_names=None, exclude_from_schema=False):
|
|||
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
|
||||
APIView.permission_classes)
|
||||
|
||||
WrappedAPIView.schema = getattr(func, 'schema',
|
||||
APIView.schema)
|
||||
|
||||
WrappedAPIView.exclude_from_schema = exclude_from_schema
|
||||
return WrappedAPIView.as_view()
|
||||
return decorator
|
||||
|
@ -112,6 +115,13 @@ def permission_classes(permission_classes):
|
|||
return decorator
|
||||
|
||||
|
||||
def schema(view_inspector):
|
||||
def decorator(func):
|
||||
func.schema = view_inspector
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def detail_route(methods=None, **kwargs):
|
||||
"""
|
||||
Used to mark a method on a ViewSet that should be routed for detail requests.
|
||||
|
|
|
@ -26,7 +26,8 @@ from rest_framework import views
|
|||
from rest_framework.compat import NoReverseMatch
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.schemas import SchemaGenerator, SchemaView
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
from rest_framework.schemas.views import SchemaView
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
|
|
43
rest_framework/schemas/__init__.py
Normal file
43
rest_framework/schemas/__init__.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
"""
|
||||
rest_framework.schemas
|
||||
|
||||
schemas:
|
||||
__init__.py
|
||||
generators.py # Top-down schema generation
|
||||
inspectors.py # Per-endpoint view introspection
|
||||
utils.py # Shared helper functions
|
||||
views.py # Houses `SchemaView`, `APIView` subclass.
|
||||
|
||||
We expose a minimal "public" API directly from `schemas`. This covers the
|
||||
basic use-cases:
|
||||
|
||||
from rest_framework.schemas import (
|
||||
AutoSchema,
|
||||
ManualSchema,
|
||||
get_schema_view,
|
||||
SchemaGenerator,
|
||||
)
|
||||
|
||||
Other access should target the submodules directly
|
||||
"""
|
||||
from .generators import SchemaGenerator
|
||||
from .inspectors import AutoSchema, ManualSchema # noqa
|
||||
|
||||
|
||||
def get_schema_view(
|
||||
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
|
||||
public=False, patterns=None, generator_class=SchemaGenerator):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
# Avoid import cycle on APIView
|
||||
from .views import SchemaView
|
||||
generator = generator_class(
|
||||
title=title, url=url, description=description,
|
||||
urlconf=urlconf, patterns=patterns,
|
||||
)
|
||||
return SchemaView.as_view(
|
||||
renderer_classes=renderer_classes,
|
||||
schema_generator=generator,
|
||||
public=public,
|
||||
)
|
|
@ -1,86 +1,26 @@
|
|||
import re
|
||||
"""
|
||||
generators.py # Top-down schema generation
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.db import models
|
||||
from django.http import Http404
|
||||
from django.utils import six
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework import exceptions, renderers, serializers
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.compat import (
|
||||
RegexURLPattern, RegexURLResolver, coreapi, coreschema, uritemplate,
|
||||
urlparse
|
||||
RegexURLPattern, RegexURLResolver, coreapi, coreschema
|
||||
)
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
from rest_framework.utils.model_meta import _get_pk
|
||||
from rest_framework.views import APIView
|
||||
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_text(field.label) if field.label else ''
|
||||
description = force_text(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties=OrderedDict([
|
||||
(key, field_to_schema(value))
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
]),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.String(),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices.keys())),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices.keys()),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
return coreschema.String(title=title, description=description)
|
||||
from .utils import is_list_view
|
||||
|
||||
|
||||
def common_path(paths):
|
||||
|
@ -104,6 +44,8 @@ def is_api_view(callback):
|
|||
"""
|
||||
Return `True` if the given view callback is a REST framework view/viewset.
|
||||
"""
|
||||
# Avoid import cycle on APIView
|
||||
from rest_framework.views import APIView
|
||||
cls = getattr(callback, 'cls', None)
|
||||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
@ -130,22 +72,6 @@ def is_custom_action(action):
|
|||
])
|
||||
|
||||
|
||||
def is_list_view(path, method, view):
|
||||
"""
|
||||
Return True if the given path/method appears to represent a list view.
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have an explicitly defined action, which we can inspect.
|
||||
return view.action == 'list'
|
||||
|
||||
if method.lower() != 'get':
|
||||
return False
|
||||
path_components = path.strip('/').split('/')
|
||||
if path_components and '{' in path_components[-1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def endpoint_ordering(endpoint):
|
||||
path, method, callback = endpoint
|
||||
method_priority = {
|
||||
|
@ -158,21 +84,7 @@ def endpoint_ordering(endpoint):
|
|||
return (path, method_priority)
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
||||
|
||||
|
||||
class EndpointInspector(object):
|
||||
class EndpointEnumerator(object):
|
||||
"""
|
||||
A class to determine the available API endpoints that a project exposes.
|
||||
"""
|
||||
|
@ -265,7 +177,7 @@ class SchemaGenerator(object):
|
|||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
endpoint_inspector_cls = EndpointInspector
|
||||
endpoint_inspector_cls = EndpointEnumerator
|
||||
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
|
@ -341,7 +253,7 @@ class SchemaGenerator(object):
|
|||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = self.get_link(path, method, view)
|
||||
link = view.schema.get_link(path, method, base_url=self.url)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
|
@ -433,197 +345,6 @@ class SchemaGenerator(object):
|
|||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
# Methods for generating each individual `Link` instance...
|
||||
|
||||
def get_link(self, path, method, view):
|
||||
"""
|
||||
Return a `coreapi.Link` instance for the given endpoint.
|
||||
"""
|
||||
fields = self.get_path_fields(path, method, view)
|
||||
fields += self.get_serializer_fields(path, method, view)
|
||||
fields += self.get_pagination_fields(path, method, view)
|
||||
fields += self.get_filter_fields(path, method, view)
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method, view)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method, view)
|
||||
|
||||
if self.url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=urlparse.urljoin(self.url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_description(self, path, method, view):
|
||||
"""
|
||||
Determine a link description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return formatting.dedent(smart_text(method_docstring))
|
||||
|
||||
description = view.get_view_description()
|
||||
lines = [line.strip() for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if header_regex.match(line):
|
||||
current_section, seperator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
header = getattr(view, 'action', method.lower())
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in self.coerce_method_names:
|
||||
if self.coerce_method_names[header] in sections:
|
||||
return sections[self.coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
def get_encoding(self, path, method, view):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = set((
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
))
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
def get_path_fields(self, path, method, view):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_text(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method, view):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
serializer = view.get_serializer()
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method, view):
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def get_filter_fields(self, path, method, view):
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
if not getattr(view, 'filter_backends', None):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(view)
|
||||
return fields
|
||||
|
||||
# Method for generating the link layout....
|
||||
|
||||
def get_keys(self, subpath, method, view):
|
||||
|
@ -669,45 +390,3 @@ class SchemaGenerator(object):
|
|||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
||||
|
||||
class SchemaView(APIView):
|
||||
_ignore_model_permissions = True
|
||||
exclude_from_schema = True
|
||||
renderer_classes = None
|
||||
schema_generator = None
|
||||
public = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SchemaView, self).__init__(*args, **kwargs)
|
||||
if self.renderer_classes is None:
|
||||
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
|
||||
self.renderer_classes = [
|
||||
renderers.CoreJSONRenderer,
|
||||
renderers.BrowsableAPIRenderer,
|
||||
]
|
||||
else:
|
||||
self.renderer_classes = [renderers.CoreJSONRenderer]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = self.schema_generator.get_schema(request, self.public)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
|
||||
def get_schema_view(
|
||||
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
|
||||
public=False, patterns=None, generator_class=SchemaGenerator):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
generator = generator_class(
|
||||
title=title, url=url, description=description,
|
||||
urlconf=urlconf, patterns=patterns,
|
||||
)
|
||||
return SchemaView.as_view(
|
||||
renderer_classes=renderer_classes,
|
||||
schema_generator=generator,
|
||||
public=public,
|
||||
)
|
399
rest_framework/schemas/inspectors.py
Normal file
399
rest_framework/schemas/inspectors.py
Normal file
|
@ -0,0 +1,399 @@
|
|||
"""
|
||||
inspectors.py # Per-endpoint view introspection
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework import serializers
|
||||
from rest_framework.compat import coreapi, coreschema, uritemplate, urlparse
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
|
||||
from .utils import is_list_view
|
||||
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
|
||||
def field_to_schema(field):
|
||||
title = force_text(field.label) if field.label else ''
|
||||
description = force_text(field.help_text) if field.help_text else ''
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = field_to_schema(field.child)
|
||||
return coreschema.Array(
|
||||
items=child_schema,
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
return coreschema.Object(
|
||||
properties=OrderedDict([
|
||||
(key, field_to_schema(value))
|
||||
for key, value
|
||||
in field.fields.items()
|
||||
]),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ManyRelatedField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.String(),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.RelatedField):
|
||||
return coreschema.String(title=title, description=description)
|
||||
elif isinstance(field, serializers.MultipleChoiceField):
|
||||
return coreschema.Array(
|
||||
items=coreschema.Enum(enum=list(field.choices.keys())),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.ChoiceField):
|
||||
return coreschema.Enum(
|
||||
enum=list(field.choices.keys()),
|
||||
title=title,
|
||||
description=description
|
||||
)
|
||||
elif isinstance(field, serializers.BooleanField):
|
||||
return coreschema.Boolean(title=title, description=description)
|
||||
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
|
||||
return coreschema.Number(title=title, description=description)
|
||||
elif isinstance(field, serializers.IntegerField):
|
||||
return coreschema.Integer(title=title, description=description)
|
||||
|
||||
if field.style.get('base_template') == 'textarea.html':
|
||||
return coreschema.String(
|
||||
title=title,
|
||||
description=description,
|
||||
format='textarea'
|
||||
)
|
||||
return coreschema.String(title=title, description=description)
|
||||
|
||||
|
||||
def get_pk_description(model, model_field):
|
||||
if isinstance(model_field, models.AutoField):
|
||||
value_type = _('unique integer value')
|
||||
elif isinstance(model_field, models.UUIDField):
|
||||
value_type = _('UUID string')
|
||||
else:
|
||||
value_type = _('unique value')
|
||||
|
||||
return _('A {value_type} identifying this {name}.').format(
|
||||
value_type=value_type,
|
||||
name=model._meta.verbose_name,
|
||||
)
|
||||
|
||||
|
||||
class ViewInspector(object):
|
||||
"""
|
||||
Descriptor class on APIView.
|
||||
|
||||
Provide subclass for per-view schema generation
|
||||
"""
|
||||
def __get__(self, instance, owner):
|
||||
"""
|
||||
Enables `ViewInspector` as a Python _Descriptor_.
|
||||
|
||||
This is how `view.schema` knows about `view`.
|
||||
|
||||
`__get__` is called when the descriptor is accessed on the owner.
|
||||
(That will be when view.schema is called in our case.)
|
||||
|
||||
`owner` is always the owner class. (An APIView, or subclass for us.)
|
||||
`instance` is the view instance or `None` if accessed from the class,
|
||||
rather than an instance.
|
||||
|
||||
See: https://docs.python.org/3/howto/descriptor.html for info on
|
||||
descriptor usage.
|
||||
"""
|
||||
self.view = instance
|
||||
return self
|
||||
|
||||
@property
|
||||
def view(self):
|
||||
"""View property."""
|
||||
assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)"
|
||||
return self._view
|
||||
|
||||
@view.setter
|
||||
def view(self, value):
|
||||
self._view = value
|
||||
|
||||
@view.deleter
|
||||
def view(self):
|
||||
self._view = None
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
"""
|
||||
Generate `coreapi.Link` for self.view, path and method.
|
||||
|
||||
This is the main _public_ access point.
|
||||
|
||||
Parameters:
|
||||
|
||||
* path: Route path for view from URLConf.
|
||||
* method: The HTTP request method.
|
||||
* base_url: The project "mount point" as given to SchemaGenerator
|
||||
"""
|
||||
raise NotImplementedError(".get_link() must be overridden.")
|
||||
|
||||
|
||||
class AutoSchema(ViewInspector):
|
||||
"""
|
||||
Default inspector for APIView
|
||||
|
||||
Responsible for per-view instrospection and schema generation.
|
||||
"""
|
||||
def __init__(self, manual_fields=None):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `manual_fields`: list of `coreapi.Field` instances that
|
||||
will be added to auto-generated fields, overwriting on `Field.name`
|
||||
"""
|
||||
|
||||
self._manual_fields = manual_fields
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
fields = self.get_path_fields(path, method)
|
||||
fields += self.get_serializer_fields(path, method)
|
||||
fields += self.get_pagination_fields(path, method)
|
||||
fields += self.get_filter_fields(path, method)
|
||||
|
||||
if self._manual_fields is not None:
|
||||
by_name = {f.name: f for f in fields}
|
||||
for f in self._manual_fields:
|
||||
by_name[f.name] = f
|
||||
fields = list(by_name.values())
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method)
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=urlparse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_description(self, path, method):
|
||||
"""
|
||||
Determine a link description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return formatting.dedent(smart_text(method_docstring))
|
||||
|
||||
description = view.get_view_description()
|
||||
lines = [line.strip() for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if header_regex.match(line):
|
||||
current_section, seperator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += '\n' + line
|
||||
|
||||
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
|
||||
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
header = getattr(view, 'action', method.lower())
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in coerce_method_names:
|
||||
if coerce_method_names[header] in sections:
|
||||
return sections[coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
def get_path_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
"""
|
||||
view = self.view
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
fields = []
|
||||
|
||||
for variable in uritemplate.variables(path):
|
||||
title = ''
|
||||
description = ''
|
||||
schema_cls = coreschema.String
|
||||
kwargs = {}
|
||||
if model is not None:
|
||||
# Attempt to infer a field description if possible.
|
||||
try:
|
||||
model_field = model._meta.get_field(variable)
|
||||
except:
|
||||
model_field = None
|
||||
|
||||
if model_field is not None and model_field.verbose_name:
|
||||
title = force_text(model_field.verbose_name)
|
||||
|
||||
if model_field is not None and model_field.help_text:
|
||||
description = force_text(model_field.help_text)
|
||||
elif model_field is not None and model_field.primary_key:
|
||||
description = get_pk_description(model, model_field)
|
||||
|
||||
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
|
||||
kwargs['pattern'] = view.lookup_value_regex
|
||||
elif isinstance(model_field, models.AutoField):
|
||||
schema_cls = coreschema.Integer
|
||||
|
||||
field = coreapi.Field(
|
||||
name=variable,
|
||||
location='path',
|
||||
required=True,
|
||||
schema=schema_cls(title=title, description=description, **kwargs)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
if method not in ('PUT', 'PATCH', 'POST'):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'get_serializer'):
|
||||
return []
|
||||
|
||||
serializer = view.get_serializer()
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
schema=coreschema.Array()
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in serializer.fields.values():
|
||||
if field.read_only or isinstance(field, serializers.HiddenField):
|
||||
continue
|
||||
|
||||
required = field.required and method != 'PATCH'
|
||||
field = coreapi.Field(
|
||||
name=field.field_name,
|
||||
location='form',
|
||||
required=required,
|
||||
schema=field_to_schema(field)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
pagination = getattr(view, 'pagination_class', None)
|
||||
if not pagination:
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def get_filter_fields(self, path, method):
|
||||
view = self.view
|
||||
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
if not getattr(view, 'filter_backends', None):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in view.filter_backends:
|
||||
fields += filter_backend().get_schema_fields(view)
|
||||
return fields
|
||||
|
||||
def get_encoding(self, path, method):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
view = self.view
|
||||
|
||||
# Core API supports the following request encodings over HTTP...
|
||||
supported_media_types = set((
|
||||
'application/json',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data',
|
||||
))
|
||||
parser_classes = getattr(view, 'parser_classes', [])
|
||||
for parser_class in parser_classes:
|
||||
media_type = getattr(parser_class, 'media_type', None)
|
||||
if media_type in supported_media_types:
|
||||
return media_type
|
||||
# Raw binary uploads are supported with "application/octet-stream"
|
||||
if media_type == '*/*':
|
||||
return 'application/octet-stream'
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ManualSchema(ViewInspector):
|
||||
"""
|
||||
Allows providing a list of coreapi.Fields,
|
||||
plus an optional description.
|
||||
"""
|
||||
def __init__(self, fields, description=''):
|
||||
"""
|
||||
Parameters:
|
||||
|
||||
* `fields`: list of `coreapi.Field` instances.
|
||||
* `descripton`: String description for view. Optional.
|
||||
"""
|
||||
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
|
||||
self._fields = fields
|
||||
self._description = description
|
||||
|
||||
def get_link(self, path, method, base_url):
|
||||
|
||||
if base_url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
return coreapi.Link(
|
||||
url=urlparse.urljoin(base_url, path),
|
||||
action=method.lower(),
|
||||
encoding=None,
|
||||
fields=self._fields,
|
||||
description=self._description
|
||||
)
|
||||
|
||||
return self._link
|
21
rest_framework/schemas/utils.py
Normal file
21
rest_framework/schemas/utils.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
"""
|
||||
utils.py # Shared helper functions
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
|
||||
|
||||
def is_list_view(path, method, view):
|
||||
"""
|
||||
Return True if the given path/method appears to represent a list view.
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have an explicitly defined action, which we can inspect.
|
||||
return view.action == 'list'
|
||||
|
||||
if method.lower() != 'get':
|
||||
return False
|
||||
path_components = path.strip('/').split('/')
|
||||
if path_components and '{' in path_components[-1]:
|
||||
return False
|
||||
return True
|
34
rest_framework/schemas/views.py
Normal file
34
rest_framework/schemas/views.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
"""
|
||||
views.py # Houses `SchemaView`, `APIView` subclass.
|
||||
|
||||
See schemas.__init__.py for package overview.
|
||||
"""
|
||||
from rest_framework import exceptions, renderers
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class SchemaView(APIView):
|
||||
_ignore_model_permissions = True
|
||||
exclude_from_schema = True
|
||||
renderer_classes = None
|
||||
schema_generator = None
|
||||
public = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SchemaView, self).__init__(*args, **kwargs)
|
||||
if self.renderer_classes is None:
|
||||
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
|
||||
self.renderer_classes = [
|
||||
renderers.CoreJSONRenderer,
|
||||
renderers.BrowsableAPIRenderer,
|
||||
]
|
||||
else:
|
||||
self.renderer_classes = [renderers.CoreJSONRenderer]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = self.schema_generator.get_schema(request, self.public)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
|
@ -19,6 +19,7 @@ from rest_framework import exceptions, status
|
|||
from rest_framework.compat import set_rollback
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas import AutoSchema
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
|
||||
|
@ -113,6 +114,7 @@ class APIView(View):
|
|||
|
||||
# Mark the view as being included or excluded from schema generation.
|
||||
exclude_from_schema = False
|
||||
schema = AutoSchema()
|
||||
|
||||
@classmethod
|
||||
def as_view(cls, **initkwargs):
|
||||
|
|
|
@ -6,12 +6,13 @@ from rest_framework import status
|
|||
from rest_framework.authentication import BasicAuthentication
|
||||
from rest_framework.decorators import (
|
||||
api_view, authentication_classes, parser_classes, permission_classes,
|
||||
renderer_classes, throttle_classes
|
||||
renderer_classes, schema, throttle_classes
|
||||
)
|
||||
from rest_framework.parsers import JSONParser
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.renderers import JSONRenderer
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.schemas import AutoSchema
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.throttling import UserRateThrottle
|
||||
from rest_framework.views import APIView
|
||||
|
@ -151,3 +152,17 @@ class DecoratorTestCase(TestCase):
|
|||
|
||||
response = view(request)
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
def test_schema(self):
|
||||
"""
|
||||
Checks CustomSchema class is set on view
|
||||
"""
|
||||
class CustomSchema(AutoSchema):
|
||||
pass
|
||||
|
||||
@api_view(['GET'])
|
||||
@schema(CustomSchema())
|
||||
def view(request):
|
||||
return Response({})
|
||||
|
||||
assert isinstance(view.cls.schema, CustomSchema)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import unittest
|
||||
|
||||
import pytest
|
||||
from django.conf.urls import include, url
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import Http404
|
||||
|
@ -10,7 +11,9 @@ from rest_framework.compat import coreapi, coreschema
|
|||
from rest_framework.decorators import detail_route, list_route
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.schemas import SchemaGenerator, get_schema_view
|
||||
from rest_framework.schemas import (
|
||||
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
|
||||
)
|
||||
from rest_framework.test import APIClient, APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
@ -496,3 +499,81 @@ class Test4605Regression(TestCase):
|
|||
'/auth/convert-token/'
|
||||
])
|
||||
assert prefix == '/'
|
||||
|
||||
|
||||
class TestDescriptor(TestCase):
|
||||
|
||||
def test_apiview_schema_descriptor(self):
|
||||
view = APIView()
|
||||
assert hasattr(view, 'schema')
|
||||
assert isinstance(view.schema, AutoSchema)
|
||||
|
||||
def test_get_link_requires_instance(self):
|
||||
descriptor = APIView.schema # Accessed from class
|
||||
with pytest.raises(AssertionError):
|
||||
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
|
||||
|
||||
def test_manual_fields(self):
|
||||
|
||||
class CustomView(APIView):
|
||||
schema = AutoSchema(manual_fields=[
|
||||
coreapi.Field(
|
||||
"my_extra_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
])
|
||||
|
||||
view = CustomView()
|
||||
link = view.schema.get_link('/a/url/{id}/', 'GET', '')
|
||||
fields = link.fields
|
||||
|
||||
assert len(fields) == 2
|
||||
assert "my_extra_field" in [f.name for f in fields]
|
||||
|
||||
def test_view_with_manual_schema(self):
|
||||
|
||||
path = '/example'
|
||||
method = 'get'
|
||||
base_url = None
|
||||
|
||||
fields = [
|
||||
coreapi.Field(
|
||||
"first_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
coreapi.Field(
|
||||
"second_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
coreapi.Field(
|
||||
"third_field",
|
||||
required=True,
|
||||
location="path",
|
||||
schema=coreschema.String()
|
||||
),
|
||||
]
|
||||
description = "A test endpoint"
|
||||
|
||||
class CustomView(APIView):
|
||||
"""
|
||||
ManualSchema takes list of fields for endpoint.
|
||||
- Provides url and action, which are always dynamic
|
||||
"""
|
||||
schema = ManualSchema(fields, description)
|
||||
|
||||
expected = coreapi.Link(
|
||||
url=path,
|
||||
action=method,
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
view = CustomView()
|
||||
link = view.schema.get_link(path, method, base_url)
|
||||
assert link == expected
|
||||
|
|
Loading…
Reference in New Issue
Block a user