Resolve conflicts with latest master branch

This commit is contained in:
fakepoet 2019-05-14 14:41:12 +08:00
commit 1fefc4e74a
32 changed files with 1726 additions and 781 deletions

View File

@ -5,9 +5,6 @@ matrix:
fast_finish: true fast_finish: true
include: include:
- { python: "3.4", env: DJANGO=1.11 }
- { python: "3.4", env: DJANGO=2.0 }
- { python: "3.5", env: DJANGO=1.11 } - { python: "3.5", env: DJANGO=1.11 }
- { python: "3.5", env: DJANGO=2.0 } - { python: "3.5", env: DJANGO=2.0 }
- { python: "3.5", env: DJANGO=2.1 } - { python: "3.5", env: DJANGO=2.1 }

View File

@ -53,7 +53,7 @@ There is a live example API for testing purposes, [available here][sandbox].
# Requirements # Requirements
* Python (3.4, 3.5, 3.6, 3.7) * Python (3.5, 3.6, 3.7)
* Django (1.11, 2.0, 2.1, 2.2) * Django (1.11, 2.0, 2.1, 2.2)
We **highly recommend** and only officially support the latest patch release of We **highly recommend** and only officially support the latest patch release of

View File

@ -448,9 +448,10 @@ Requires either the `Pillow` package or `PIL` package. The `Pillow` package is
A field class that validates a list of objects. A field class that validates a list of objects.
**Signature**: `ListField(child=<A_FIELD_INSTANCE>, min_length=None, max_length=None)` **Signature**: `ListField(child=<A_FIELD_INSTANCE>, allow_empty=True, min_length=None, max_length=None)`
- `child` - A field instance that should be used for validating the objects in the list. If this argument is not provided then objects in the list will not be validated. - `child` - A field instance that should be used for validating the objects in the list. If this argument is not provided then objects in the list will not be validated.
- `allow_empty` - Designates if empty lists are allowed.
- `min_length` - Validates that the list contains no fewer than this number of elements. - `min_length` - Validates that the list contains no fewer than this number of elements.
- `max_length` - Validates that the list contains no more than this number of elements. - `max_length` - Validates that the list contains no more than this number of elements.
@ -471,9 +472,10 @@ We can now reuse our custom `StringListField` class throughout our application,
A field class that validates a dictionary of objects. The keys in `DictField` are always assumed to be string values. A field class that validates a dictionary of objects. The keys in `DictField` are always assumed to be string values.
**Signature**: `DictField(child=<A_FIELD_INSTANCE>)` **Signature**: `DictField(child=<A_FIELD_INSTANCE>, allow_empty=True)`
- `child` - A field instance that should be used for validating the values in the dictionary. If this argument is not provided then values in the mapping will not be validated. - `child` - A field instance that should be used for validating the values in the dictionary. If this argument is not provided then values in the mapping will not be validated.
- `allow_empty` - Designates if empty dictionaries are allowed.
For example, to create a field that validates a mapping of strings to strings, you would write something like this: For example, to create a field that validates a mapping of strings to strings, you would write something like this:
@ -488,9 +490,10 @@ You can also use the declarative style, as with `ListField`. For example:
A preconfigured `DictField` that is compatible with Django's postgres `HStoreField`. A preconfigured `DictField` that is compatible with Django's postgres `HStoreField`.
**Signature**: `HStoreField(child=<A_FIELD_INSTANCE>)` **Signature**: `HStoreField(child=<A_FIELD_INSTANCE>, allow_empty=True)`
- `child` - A field instance that is used for validating the values in the dictionary. The default child field accepts both empty strings and null values. - `child` - A field instance that is used for validating the values in the dictionary. The default child field accepts both empty strings and null values.
- `allow_empty` - Designates if empty dictionaries are allowed.
Note that the child field **must** be an instance of `CharField`, as the hstore extension stores values as strings. Note that the child field **must** be an instance of `CharField`, as the hstore extension stores values as strings.

View File

@ -576,6 +576,8 @@ If you explicitly specify a relational field pointing to a
``ManyToManyField`` with a through model, be sure to set ``read_only`` ``ManyToManyField`` with a through model, be sure to set ``read_only``
to ``True``. to ``True``.
If you wish to represent [extra fields on a through model][django-intermediary-manytomany] then you may serialize the through model as [a nested object][dealing-with-nested-objects].
--- ---
# Third Party Packages # Third Party Packages
@ -596,3 +598,5 @@ The [rest-framework-generic-relations][drf-nested-relations] library provides re
[generic-relations]: https://docs.djangoproject.com/en/stable/ref/contrib/contenttypes/#id1 [generic-relations]: https://docs.djangoproject.com/en/stable/ref/contrib/contenttypes/#id1
[drf-nested-routers]: https://github.com/alanjds/drf-nested-routers [drf-nested-routers]: https://github.com/alanjds/drf-nested-routers
[drf-nested-relations]: https://github.com/Ian-Foote/rest-framework-generic-relations [drf-nested-relations]: https://github.com/Ian-Foote/rest-framework-generic-relations
[django-intermediary-manytomany]: https://docs.djangoproject.com/en/2.2/topics/db/models/#intermediary-manytomany
[dealing-with-nested-objects]: https://www.django-rest-framework.org/api-guide/serializers/#dealing-with-nested-objects

View File

@ -264,6 +264,7 @@ To submit new content, [open an issue][drf-create-issue] or [create a pull reque
* [djangorest-alchemy][djangorest-alchemy] - SQLAlchemy support for REST framework. * [djangorest-alchemy][djangorest-alchemy] - SQLAlchemy support for REST framework.
* [djangorestframework-datatables][djangorestframework-datatables] - Seamless integration between Django REST framework and [Datatables](https://datatables.net). * [djangorestframework-datatables][djangorestframework-datatables] - Seamless integration between Django REST framework and [Datatables](https://datatables.net).
* [django-rest-framework-condition][django-rest-framework-condition] - Decorators for managing HTTP cache headers for Django REST framework (ETag and Last-modified). * [django-rest-framework-condition][django-rest-framework-condition] - Decorators for managing HTTP cache headers for Django REST framework (ETag and Last-modified).
* [django-rest-witchcraft][django-rest-witchcraft] - Provides DRF integration with SQLAlchemy with SQLAlchemy model serializers/viewsets and a bunch of other goodies
[cite]: http://www.software-ecosystems.com/Software_Ecosystems/Ecosystems.html [cite]: http://www.software-ecosystems.com/Software_Ecosystems/Ecosystems.html
[cookiecutter]: https://github.com/jpadilla/cookiecutter-django-rest-framework [cookiecutter]: https://github.com/jpadilla/cookiecutter-django-rest-framework
@ -338,3 +339,4 @@ To submit new content, [open an issue][drf-create-issue] or [create a pull reque
[djangorest-alchemy]: https://github.com/dealertrack/djangorest-alchemy [djangorest-alchemy]: https://github.com/dealertrack/djangorest-alchemy
[djangorestframework-datatables]: https://github.com/izimobil/django-rest-framework-datatables [djangorestframework-datatables]: https://github.com/izimobil/django-rest-framework-datatables
[django-rest-framework-condition]: https://github.com/jozo/django-rest-framework-condition [django-rest-framework-condition]: https://github.com/jozo/django-rest-framework-condition
[django-rest-witchcraft]: https://github.com/shosca/django-rest-witchcraft

View File

@ -84,7 +84,7 @@ continued development by **[signing up for a paid plan][funding]**.
REST framework requires the following: REST framework requires the following:
* Python (3.4, 3.5, 3.6, 3.7) * Python (3.5, 3.6, 3.7)
* Django (1.11, 2.0, 2.1, 2.2) * Django (1.11, 2.0, 2.1, 2.2)
We **highly recommend** and only officially support the latest patch release of We **highly recommend** and only officially support the latest patch release of
@ -93,7 +93,7 @@ each Python and Django series.
The following packages are optional: The following packages are optional:
* [coreapi][coreapi] (1.32.0+) - Schema generation support. * [coreapi][coreapi] (1.32.0+) - Schema generation support.
* [Markdown][markdown] (2.1.0+) - Markdown support for the browsable API. * [Markdown][markdown] (2.6.0+) - Markdown support for the browsable API.
* [django-filter][django-filter] (1.0.1+) - Filtering support. * [django-filter][django-filter] (1.0.1+) - Filtering support.
* [django-crispy-forms][django-crispy-forms] - Improved HTML display for filtering. * [django-crispy-forms][django-crispy-forms] - Improved HTML display for filtering.
* [django-guardian][django-guardian] (1.1.1+) - Object level permissions support. * [django-guardian][django-guardian] (1.1.1+) - Object level permissions support.

View File

@ -1,4 +1,4 @@
# Pytest for running the tests. # Pytest for running the tests.
pytest==4.3.0 pytest>=4.5.0,<4.6
pytest-django==3.4.8 pytest-django>=3.4.8,<3.5
pytest-cov==2.6.1 pytest-cov>=2.7.1

View File

@ -143,19 +143,12 @@ if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch'] View.http_method_names = View.http_method_names + ['patch']
# Markdown is optional # Markdown is optional (version 2.6+ required)
try: try:
import markdown import markdown
if markdown.version <= '2.2': HEADERID_EXT_PATH = 'markdown.extensions.toc'
HEADERID_EXT_PATH = 'headerid' LEVEL_PARAM = 'baselevel'
LEVEL_PARAM = 'level'
elif markdown.version < '2.6':
HEADERID_EXT_PATH = 'markdown.extensions.headerid'
LEVEL_PARAM = 'level'
else:
HEADERID_EXT_PATH = 'markdown.extensions.toc'
LEVEL_PARAM = 'baselevel'
def apply_markdown(text): def apply_markdown(text):
""" """

View File

@ -614,7 +614,7 @@ class Field:
for item in self._args for item in self._args
] ]
kwargs = { kwargs = {
key: (copy.deepcopy(value) if (key not in ('validators', 'regex')) else value) key: (copy.deepcopy(value, memo) if (key not in ('validators', 'regex')) else value)
for key, value in self._kwargs.items() for key, value in self._kwargs.items()
} }
return self.__class__(*args, **kwargs) return self.__class__(*args, **kwargs)
@ -1663,11 +1663,13 @@ class DictField(Field):
child = _UnvalidatedField() child = _UnvalidatedField()
initial = {} initial = {}
default_error_messages = { default_error_messages = {
'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".') 'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".'),
'empty': _('This dictionary may not be empty.'),
} }
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child)) self.child = kwargs.pop('child', copy.deepcopy(self.child))
self.allow_empty = kwargs.pop('allow_empty', True)
assert not inspect.isclass(self.child), '`child` has not been instantiated.' assert not inspect.isclass(self.child), '`child` has not been instantiated.'
assert self.child.source is None, ( assert self.child.source is None, (
@ -1693,6 +1695,9 @@ class DictField(Field):
data = html.parse_html_dict(data) data = html.parse_html_dict(data)
if not isinstance(data, dict): if not isinstance(data, dict):
self.fail('not_a_dict', input_type=type(data).__name__) self.fail('not_a_dict', input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail('empty')
return self.run_child_validation(data) return self.run_child_validation(data)
def to_representation(self, value): def to_representation(self, value):

View File

@ -37,6 +37,9 @@ class BaseFilterBackend:
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [] return []
def get_schema_operation_parameters(self, view):
return []
class SearchFilter(BaseFilterBackend): class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search. # The URL query parameter used for the search.
@ -156,6 +159,19 @@ class SearchFilter(BaseFilterBackend):
) )
] ]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.search_param,
'required': False,
'in': 'query',
'description': force_text(self.search_description),
'schema': {
'type': 'string',
},
},
]
class OrderingFilter(BaseFilterBackend): class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering. # The URL query parameter used for the ordering.
@ -287,6 +303,19 @@ class OrderingFilter(BaseFilterBackend):
) )
] ]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.ordering_param,
'required': False,
'in': 'query',
'description': force_text(self.ordering_description),
'schema': {
'type': 'string',
},
},
]
class DjangoObjectPermissionsFilter(BaseFilterBackend): class DjangoObjectPermissionsFilter(BaseFilterBackend):
""" """

View File

@ -1,41 +1,56 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from rest_framework.compat import coreapi from rest_framework import renderers
from rest_framework.renderers import ( from rest_framework.schemas import coreapi
CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer from rest_framework.schemas.openapi import SchemaGenerator
)
from rest_framework.schemas.generators import SchemaGenerator OPENAPI_MODE = 'openapi'
COREAPI_MODE = 'coreapi'
class Command(BaseCommand): class Command(BaseCommand):
help = "Generates configured API schema for project." help = "Generates configured API schema for project."
def get_mode(self):
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('--title', dest="title", default=None, type=str) parser.add_argument('--title', dest="title", default='', type=str)
parser.add_argument('--url', dest="url", default=None, type=str) parser.add_argument('--url', dest="url", default=None, type=str)
parser.add_argument('--description', dest="description", default=None, type=str) parser.add_argument('--description', dest="description", default=None, type=str)
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) if self.get_mode() == COREAPI_MODE:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
else:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
def handle(self, *args, **options): def handle(self, *args, **options):
assert coreapi is not None, 'coreapi must be installed.' generator_class = self.get_generator_class()
generator = generator_class(
generator = SchemaGenerator(
url=options['url'], url=options['url'],
title=options['title'], title=options['title'],
description=options['description'] description=options['description']
) )
schema = generator.get_schema(request=None, public=True) schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer(options['format']) renderer = self.get_renderer(options['format'])
output = renderer.render(schema, renderer_context={}) output = renderer.render(schema, renderer_context={})
self.stdout.write(output.decode()) self.stdout.write(output.decode())
def get_renderer(self, format): def get_renderer(self, format):
renderer_cls = { if self.get_mode() == COREAPI_MODE:
'corejson': CoreJSONRenderer, renderer_cls = {
'openapi': OpenAPIRenderer, 'corejson': renderers.CoreJSONRenderer,
'openapi-json': JSONOpenAPIRenderer, 'openapi': renderers.CoreAPIOpenAPIRenderer,
}[format] 'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer,
}[format]
return renderer_cls()
renderer_cls = {
'openapi': renderers.OpenAPIRenderer,
'openapi-json': renderers.JSONOpenAPIRenderer,
}[format]
return renderer_cls() return renderer_cls()
def get_generator_class(self):
if self.get_mode() == COREAPI_MODE:
return coreapi.SchemaGenerator
return SchemaGenerator

View File

@ -148,6 +148,9 @@ class BasePagination:
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
return [] return []
def get_schema_operation_parameters(self, view):
return []
class PageNumberPagination(BasePagination): class PageNumberPagination(BasePagination):
""" """
@ -301,6 +304,32 @@ class PageNumberPagination(BasePagination):
) )
return fields return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.page_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_query_description),
'schema': {
'type': 'integer',
},
},
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_size_query_description),
'schema': {
'type': 'integer',
},
},
)
return parameters
class LimitOffsetPagination(BasePagination): class LimitOffsetPagination(BasePagination):
""" """
@ -430,6 +459,15 @@ class LimitOffsetPagination(BasePagination):
context = self.get_html_context() context = self.get_html_context()
return template.render(context) return template.render(context)
def get_count(self, queryset):
"""
Determine an object count, supporting either querysets or regular lists.
"""
try:
return queryset.count()
except (AttributeError, TypeError):
return len(queryset)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
@ -454,14 +492,28 @@ class LimitOffsetPagination(BasePagination):
) )
] ]
def get_count(self, queryset): def get_schema_operation_parameters(self, view):
""" parameters = [
Determine an object count, supporting either querysets or regular lists. {
""" 'name': self.limit_query_param,
try: 'required': False,
return queryset.count() 'in': 'query',
except (AttributeError, TypeError): 'description': force_text(self.limit_query_description),
return len(queryset) 'schema': {
'type': 'integer',
},
},
{
'name': self.offset_query_param,
'required': False,
'in': 'query',
'description': force_text(self.offset_query_description),
'schema': {
'type': 'integer',
},
},
]
return parameters
class CursorPagination(BasePagination): class CursorPagination(BasePagination):
@ -816,3 +868,29 @@ class CursorPagination(BasePagination):
) )
) )
return fields return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.cursor_query_param,
'required': False,
'in': 'query',
'description': force_text(self.cursor_query_description),
'schema': {
'type': 'integer',
},
}
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_text(self.page_size_query_description),
'schema': {
'type': 'integer',
},
}
)
return parameters

View File

@ -1013,28 +1013,49 @@ class _BaseOpenAPIRenderer:
} }
class OpenAPIRenderer(_BaseOpenAPIRenderer): class CoreAPIOpenAPIRenderer(_BaseOpenAPIRenderer):
media_type = 'application/vnd.oai.openapi' media_type = 'application/vnd.oai.openapi'
charset = None charset = None
format = 'openapi' format = 'openapi'
def __init__(self): def __init__(self):
assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.' assert coreapi, 'Using CoreAPIOpenAPIRenderer, but `coreapi` is not installed.'
assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.' assert yaml, 'Using CoreAPIOpenAPIRenderer, but `pyyaml` is not installed.'
def render(self, data, media_type=None, renderer_context=None): def render(self, data, media_type=None, renderer_context=None):
structure = self.get_structure(data) structure = self.get_structure(data)
return yaml.dump(structure, default_flow_style=False).encode() return yaml.dump(structure, default_flow_style=False).encode()
class JSONOpenAPIRenderer(_BaseOpenAPIRenderer): class CoreAPIJSONOpenAPIRenderer(_BaseOpenAPIRenderer):
media_type = 'application/vnd.oai.openapi+json' media_type = 'application/vnd.oai.openapi+json'
charset = None charset = None
format = 'openapi-json' format = 'openapi-json'
def __init__(self): def __init__(self):
assert coreapi, 'Using JSONOpenAPIRenderer, but `coreapi` is not installed.' assert coreapi, 'Using CoreAPIJSONOpenAPIRenderer, but `coreapi` is not installed.'
def render(self, data, media_type=None, renderer_context=None): def render(self, data, media_type=None, renderer_context=None):
structure = self.get_structure(data) structure = self.get_structure(data)
return json.dumps(structure, indent=4).encode() return json.dumps(structure, indent=4).encode('utf-8')
class OpenAPIRenderer(BaseRenderer):
media_type = 'application/vnd.oai.openapi'
charset = None
format = 'openapi'
def __init__(self):
assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.'
def render(self, data, media_type=None, renderer_context=None):
return yaml.dump(data, default_flow_style=False).encode('utf-8')
class JSONOpenAPIRenderer(BaseRenderer):
media_type = 'application/vnd.oai.openapi+json'
charset = None
format = 'openapi-json'
def render(self, data, media_type=None, renderer_context=None):
return json.dumps(data, indent=2).encode('utf-8')

View File

@ -22,24 +22,32 @@ Other access should target the submodules directly
""" """
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from .generators import SchemaGenerator from . import coreapi, openapi
from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa from .inspectors import DefaultSchema # noqa
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
def get_schema_view( def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None, title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=SchemaGenerator, public=False, patterns=None, generator_class=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
""" """
Return a schema view. Return a schema view.
""" """
# Avoid import cycle on APIView if generator_class is None:
from .views import SchemaView if coreapi.is_enabled():
generator_class = coreapi.SchemaGenerator
else:
generator_class = openapi.SchemaGenerator
generator = generator_class( generator = generator_class(
title=title, url=url, description=description, title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns, urlconf=urlconf, patterns=patterns,
) )
# Avoid import cycle on APIView
from .views import SchemaView
return SchemaView.as_view( return SchemaView.as_view(
renderer_classes=renderer_classes, renderer_classes=renderer_classes,
schema_generator=generator, schema_generator=generator,

View File

@ -0,0 +1,619 @@
import re
import warnings
from collections import Counter, OrderedDict
from urllib import parse
from django.db import models
from django.utils.encoding import force_text, smart_text
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
# Used in _get_description_section()
# TODO: ???: move up to base.
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
# Generator #
# TODO: Pull some of this into base.
def is_custom_action(action):
return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
}
def distribute_links(obj):
for key, value in obj.items():
distribute_links(value)
for preferred_key, link in obj.links:
key = obj.get_available_key(preferred_key)
obj[key] = link
INSERT_INTO_COLLISION_FMT = """
Schema Naming Collision.
coreapi.Link for URL path {value_url} cannot be inserted into schema.
Position conflicts with coreapi.Link for URL path {target_url}.
Attempted to insert link with keys: {keys}.
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
to customise schema structure.
"""
class LinkNode(OrderedDict):
def __init__(self):
self.links = []
self.methods_counter = Counter()
super(LinkNode, self).__init__()
def get_available_key(self, preferred_key):
if preferred_key not in self:
return preferred_key
while True:
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val)
if key not in self:
return key
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
"""
for key in keys[:-1]:
if key not in target:
target[key] = LinkNode()
target = target[key]
try:
if len(keys) == 1:
target[keys[-1]] = LinkNode()
target = target[keys[-1]]
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url,
target_url=target.url,
keys=keys
)
raise ValueError(msg)
class SchemaGenerator(BaseSchemaGenerator):
"""
Original CoreAPI version.
"""
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = LinkNode()
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
self._initialise_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]
# View Inspectors #
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.DictField):
return coreschema.Object(
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
items=related_field_schema,
title=title,
description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
return schema_cls(title=title, description=description)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view introspection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super(AutoSchema, self).__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
Override to adjust behaviour for your view.
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_filter_fields(self, path, method):
if not self._allows_filters(path, method):
return []
fields = []
for filter_backend in self.view.filter_backends:
fields += filter_backend().get_schema_fields(self.view)
return fields
def get_manual_fields(self, path, method):
return self._manual_fields
@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
fields = list(by_name.values())
return fields
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = {
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
}
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description='', encoding=None):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `description`: String description for view. Optional.
"""
super(ManualSchema, self).__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
description=self._description
)
def is_enabled():
"""Is CoreAPI Mode enabled?"""
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)

View File

@ -4,7 +4,6 @@ generators.py # Top-down schema generation
See schemas.__init__.py for package overview. See schemas.__init__.py for package overview.
""" """
import re import re
from collections import Counter, OrderedDict
from importlib import import_module from importlib import import_module
from django.conf import settings from django.conf import settings
@ -13,15 +12,11 @@ from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import ( from rest_framework.compat import URLPattern, URLResolver, get_original_route
URLPattern, URLResolver, coreapi, coreschema, get_original_route
)
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import _get_pk from rest_framework.utils.model_meta import _get_pk
from .utils import is_list_view
def common_path(paths): def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths] split_paths = [path.strip('/').split('/') for path in paths]
@ -50,81 +45,6 @@ def is_api_view(callback):
return (cls is not None) and issubclass(cls, APIView) return (cls is not None) and issubclass(cls, APIView)
INSERT_INTO_COLLISION_FMT = """
Schema Naming Collision.
coreapi.Link for URL path {value_url} cannot be inserted into schema.
Position conflicts with coreapi.Link for URL path {target_url}.
Attempted to insert link with keys: {keys}.
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
to customise schema structure.
"""
class LinkNode(OrderedDict):
def __init__(self):
self.links = []
self.methods_counter = Counter()
super().__init__()
def get_available_key(self, preferred_key):
if preferred_key not in self:
return preferred_key
while True:
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val)
if key not in self:
return key
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
"""
for key in keys[:-1]:
if key not in target:
target[key] = LinkNode()
target = target[key]
try:
if len(keys) == 1:
target[keys[-1]] = LinkNode()
target = target[keys[-1]]
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url,
target_url=target.url,
keys=keys
)
raise ValueError(msg)
def distribute_links(obj):
for key, value in obj.items():
distribute_links(value)
for preferred_key, link in obj.links:
key = obj.get_available_key(preferred_key)
obj[key] = link
def is_custom_action(action):
return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
}
def endpoint_ordering(endpoint): def endpoint_ordering(endpoint):
path, method, callback = endpoint path, method, callback = endpoint
method_priority = { method_priority = {
@ -193,6 +113,10 @@ class EndpointEnumerator:
""" """
Given a URL conf regex, return a URI template string. Given a URL conf regex, return a URI template string.
""" """
# ???: Would it be feasible to adjust this such that we generate the
# path, plus the kwargs, plus the type from the convertor, such that we
# could feed that straight into the parameter schema object?
path = simplify_regex(path_regex) path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format # Strip Django 2.0 convertors as they are incompatible with uritemplate format
@ -231,35 +155,18 @@ class EndpointEnumerator:
return [method for method in methods if method not in ('OPTIONS', 'HEAD')] return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class SchemaGenerator: class BaseSchemaGenerator(object):
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
endpoint_inspector_cls = EndpointEnumerator endpoint_inspector_cls = EndpointEnumerator
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
# 'pk' isn't great as an externally exposed name for an identifier, # 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas. # so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'. # Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
if url and not url.endswith('/'): if url and not url.endswith('/'):
url += '/' url += '/'
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns self.patterns = patterns
@ -269,36 +176,15 @@ class SchemaGenerator:
self.url = url self.url = url
self.endpoints = None self.endpoints = None
def get_schema(self, request=None, public=False): def _initialise_endpoints(self):
"""
Generate a `coreapi.Document` representing the API schema.
"""
if self.endpoints is None: if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints() self.endpoints = inspector.get_api_endpoints()
links = self.get_links(None if public else request) def _get_paths_and_endpoints(self, request):
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
def get_links(self, request=None):
""" """
Return a dictionary containing all the links that should be Generate (path, method, view) given (path, method, callback) for paths.
included in the API schema.
""" """
links = LinkNode()
# Generate (path, method, view) given (path, method, callback).
paths = [] paths = []
view_endpoints = [] view_endpoints = []
for path, method, callback in self.endpoints: for path, method, callback in self.endpoints:
@ -307,22 +193,48 @@ class SchemaGenerator:
paths.append(path) paths.append(path)
view_endpoints.append((path, method, view)) view_endpoints.append((path, method, view))
# Only generate the path prefix for paths that will be included return paths, view_endpoints
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints: def create_view(self, callback, method, request=None):
if not self.has_view_permissions(path, method, view): """
continue Given a callback, return an actual view instance.
link = view.schema.get_link(path, method, base_url=self.url) """
subpath = path[len(prefix):] view = callback.cls(**getattr(callback, 'initkwargs', {}))
keys = self.get_keys(subpath, method, view) view.args = ()
insert_into(links, keys, link) view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
return links actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
# Methods used when we generate a view instance from the raw callback... if request is not None:
view.request = clone_request(request, method)
return view
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_schema(self, request=None, public=False):
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
def determine_path_prefix(self, paths): def determine_path_prefix(self, paths):
""" """
@ -355,29 +267,6 @@ class SchemaGenerator:
prefixes.append('/' + prefix + '/') prefixes.append('/' + prefix + '/')
return common_path(prefixes) return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def has_view_permissions(self, path, method, view): def has_view_permissions(self, path, method, view):
""" """
Return `True` if the incoming request has the correct view permissions. Return `True` if the incoming request has the correct view permissions.
@ -390,64 +279,3 @@ class SchemaGenerator:
except (exceptions.APIException, Http404, PermissionDenied): except (exceptions.APIException, Http404, PermissionDenied):
return False return False
return True return True
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]

View File

@ -3,125 +3,9 @@ inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview. See schemas.__init__.py for package overview.
""" """
import re
import warnings
from collections import OrderedDict
from urllib import parse
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
from django.db import models
from django.utils.encoding import force_text, smart_text
from django.utils.translation import gettext_lazy as _
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .utils import is_list_view
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.DictField):
return coreschema.Object(
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
items=related_field_schema,
title=title,
description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
return schema_cls(title=title, description=description)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)
class ViewInspector: class ViewInspector:
@ -178,320 +62,6 @@ class ViewInspector:
def view(self): def view(self):
self._view = None self._view = None
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
raise NotImplementedError(".get_link() must be overridden.")
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view introspection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super().__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
Override to adjust behaviour for your view.
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_filter_fields(self, path, method):
if not self._allows_filters(path, method):
return []
fields = []
for filter_backend in self.view.filter_backends:
fields += filter_backend().get_schema_fields(self.view)
return fields
def get_manual_fields(self, path, method):
return self._manual_fields
@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
return list(by_name.values())
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = {
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
}
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description='', encoding=None):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `description`: String description for view. Optional.
"""
super().__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
description=self._description
)
class DefaultSchema(ViewInspector): class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""

View File

@ -0,0 +1,377 @@
import warnings
from django.db import models
from django.utils.encoding import force_text
from rest_framework import exceptions, serializers
from rest_framework.compat import uritemplate
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
# Generator
class SchemaGenerator(BaseSchemaGenerator):
def get_info(self):
info = {
'title': self.title,
'version': 'TODO',
}
if self.description is not None:
info['description'] = self.description
return info
def get_paths(self, request=None):
result = {}
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
operation = view.schema.get_operation(path, method)
subpath = '/' + path[len(prefix):]
result.setdefault(subpath, {})
result[subpath][method.lower()] = operation
return result
def get_schema(self, request=None, public=False):
"""
Generate a OpenAPI schema.
"""
self._initialise_endpoints()
paths = self.get_paths(None if public else request)
if not paths:
return None
schema = {
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
}
return schema
# View Inspectors
class AutoSchema(ViewInspector):
content_types = ['application/json']
method_mapping = {
'get': 'Retrieve',
'post': 'Create',
'put': 'Update',
'patch': 'PartialUpdate',
'delete': 'Destroy',
}
def get_operation(self, path, method):
operation = {}
operation['operationId'] = self._get_operation_id(path, method)
parameters = []
parameters += self._get_path_parameters(path, method)
parameters += self._get_pagination_parameters(path, method)
parameters += self._get_filter_parameters(path, method)
operation['parameters'] = parameters
request_body = self._get_request_body(path, method)
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
return operation
def _get_operation_id(self, path, method):
"""
Compute an operation ID from the model, serializer or view name.
"""
method_name = getattr(self.view, 'action', method.lower())
if is_list_view(path, method, self.view):
action = 'List'
elif method_name not in self.method_mapping:
action = method_name
else:
action = self.method_mapping[method.lower()]
# Try to deduce the ID from the view's model
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
if model is not None:
name = model.__name__
# Try with the serializer class name
elif hasattr(self.view, 'get_serializer_class'):
name = self.view.get_serializer_class().__name__
if name.endswith('Serializer'):
name = name[:-10]
# Fallback to the view name
else:
name = self.view.__class__.__name__
if name.endswith('APIView'):
name = name[:-7]
elif name.endswith('View'):
name = name[:-4]
if name.endswith(action): # ListView, UpdateAPIView, ThingDelete ...
name = name[:-len(action)]
if action == 'List' and not name.endswith('s'): # ListThings instead of ListThing
name += 's'
return action + name
def _get_path_parameters(self, path, method):
"""
Return a list of parameters from templated path variables.
"""
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
parameters = []
for variable in uritemplate.variables(path):
description = ''
if model is not None: # TODO: test this.
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
parameter = {
"name": variable,
"in": "path",
"required": True,
"description": description,
'schema': {
'type': 'string', # TODO: integer, pattern, ...
},
}
parameters.append(parameter)
return parameters
def _get_filter_parameters(self, path, method):
if not self._allows_filters(path, method):
return []
parameters = []
for filter_backend in self.view.filter_backends:
parameters += filter_backend().get_schema_operation_parameters(self.view)
return parameters
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def _get_pagination_parameters(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_operation_parameters(view)
def _map_field(self, field):
# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self._map_serializer(field.child)
}
if isinstance(field, serializers.Serializer):
data = self._map_serializer(field)
data['type'] = 'object'
return data
# Related fields.
if isinstance(field, serializers.ManyRelatedField):
return {
'type': 'array',
'items': self._map_field(field.child_relation)
}
if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
return {'type': 'integer'}
# ChoiceFields (single and multiple).
# Q:
# - Is 'type' required?
# - can we determine the TYPE of a choicefield?
if isinstance(field, serializers.MultipleChoiceField):
return {
'type': 'array',
'items': {
'enum': list(field.choices)
},
}
if isinstance(field, serializers.ChoiceField):
return {
'enum': list(field.choices),
}
# ListField.
if isinstance(field, serializers.ListField):
return {
'type': 'array',
}
# DateField and DateTimeField type is string
if isinstance(field, serializers.DateField):
return {
'type': 'string',
'format': 'date',
}
if isinstance(field, serializers.DateTimeField):
return {
'type': 'string',
'format': 'date-time',
}
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
serializers.DecimalField: 'number',
serializers.FloatField: 'number',
serializers.IntegerField: 'integer',
serializers.JSONField: 'object',
serializers.DictField: 'object',
}
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
def _map_serializer(self, serializer):
# Assuming we have a valid serializer instance.
# TODO:
# - field is Nested or List serializer.
# - Handle read_only/write_only for request/response differences.
# - could do this with readOnly/writeOnly and then filter dict.
required = []
properties = {}
for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField):
continue
if field.required:
required.append(field.field_name)
schema = self._map_field(field)
if field.read_only:
schema['readOnly'] = True
if field.write_only:
schema['writeOnly'] = True
if field.allow_null:
schema['nullable'] = True
properties[field.field_name] = schema
return {
'required': required,
'properties': properties,
}
def _get_request_body(self, path, method):
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return {}
if not hasattr(view, 'get_serializer'):
return {}
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if not isinstance(serializer, serializers.Serializer):
return {}
content = self._map_serializer(serializer)
# No required fields for PATCH
if method == 'PATCH':
del content['required']
# No read_only fields for request.
for name, schema in content['properties'].copy().items():
if 'readOnly' in schema:
del content['properties'][name]
return {
'content': {
ct: {'schema': content}
for ct in self.content_types
}
}
def _get_responses(self, path, method):
# TODO: Handle multiple codes.
content = {}
view = self.view
if hasattr(view, 'get_serializer'):
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.Serializer):
content = self._map_serializer(serializer)
# No write_only fields for response.
for name, schema in content['properties'].copy().items():
if 'writeOnly' in schema:
del content['properties'][name]
content['required'] = [f for f in content['required'] if f != name]
return {
'200': {
'content': {
ct: {'schema': content}
for ct in self.content_types
}
}
}

View File

@ -3,6 +3,9 @@ utils.py # Shared helper functions
See schemas.__init__.py for package overview. See schemas.__init__.py for package overview.
""" """
from django.db import models
from django.utils.translation import ugettext_lazy as _
from rest_framework.mixins import RetrieveModelMixin from rest_framework.mixins import RetrieveModelMixin
@ -22,3 +25,17 @@ def is_list_view(path, method, view):
if path_components and '{' in path_components[-1]: if path_components and '{' in path_components[-1]:
return False return False
return True return True
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)

View File

@ -5,6 +5,7 @@ See schemas.__init__.py for package overview.
""" """
from rest_framework import exceptions, renderers from rest_framework import exceptions, renderers
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.schemas import coreapi
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.views import APIView from rest_framework.views import APIView
@ -19,10 +20,16 @@ class SchemaView(APIView):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.renderer_classes is None: if self.renderer_classes is None:
self.renderer_classes = [ if coreapi.is_enabled():
renderers.OpenAPIRenderer, self.renderer_classes = [
renderers.CoreJSONRenderer renderers.CoreAPIOpenAPIRenderer,
] renderers.CoreJSONRenderer
]
else:
self.renderer_classes = [
renderers.OpenAPIRenderer,
renderers.JSONOpenAPIRenderer,
]
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes += [renderers.BrowsableAPIRenderer] self.renderer_classes += [renderers.BrowsableAPIRenderer]

View File

@ -52,7 +52,7 @@ DEFAULTS = {
'DEFAULT_FILTER_BACKENDS': (), 'DEFAULT_FILTER_BACKENDS': (),
# Schema # Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
# Throttling # Throttling
'DEFAULT_THROTTLE_RATES': { 'DEFAULT_THROTTLE_RATES': {

View File

@ -8,7 +8,7 @@ from io import open
from setuptools import find_packages, setup from setuptools import find_packages, setup
CURRENT_PYTHON = sys.version_info[:2] CURRENT_PYTHON = sys.version_info[:2]
REQUIRED_PYTHON = (3, 4) REQUIRED_PYTHON = (3, 5)
# This check and everything above must remain compatible with Python 2.7. # This check and everything above must remain compatible with Python 2.7.
if CURRENT_PYTHON < REQUIRED_PYTHON: if CURRENT_PYTHON < REQUIRED_PYTHON:
@ -79,7 +79,7 @@ setup(
packages=find_packages(exclude=['tests*']), packages=find_packages(exclude=['tests*']),
include_package_data=True, include_package_data=True,
install_requires=[], install_requires=[],
python_requires=">=3.4", python_requires=">=3.5",
zip_safe=False, zip_safe=False,
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', 'Development Status :: 5 - Production/Stable',
@ -94,7 +94,6 @@ setup(
'Operating System :: OS Independent', 'Operating System :: OS Independent',
'Programming Language :: Python', 'Programming Language :: Python',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',

View File

View File

@ -16,15 +16,16 @@ from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.schemas import ( from rest_framework.schemas import (
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
) )
from rest_framework.schemas.coreapi import field_to_schema
from rest_framework.schemas.generators import EndpointEnumerator from rest_framework.schemas.generators import EndpointEnumerator
from rest_framework.schemas.inspectors import field_to_schema
from rest_framework.schemas.utils import is_list_view from rest_framework.schemas.utils import is_list_view
from rest_framework.test import APIClient, APIRequestFactory from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.utils import formatting from rest_framework.utils import formatting
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework.viewsets import GenericViewSet, ModelViewSet
from .models import BasicModel, ForeignKeySource, ManyToManySource from . import views
from ..models import BasicModel, ForeignKeySource, ManyToManySource
factory = APIRequestFactory() factory = APIRequestFactory()
@ -133,11 +134,12 @@ class ExampleViewSet(ModelViewSet):
pass pass
if coreapi: with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
schema_view = get_schema_view(title='Example API') if coreapi:
else: schema_view = get_schema_view(title='Example API')
def schema_view(request): else:
pass def schema_view(request):
pass
router = DefaultRouter() router = DefaultRouter()
router.register('example', ExampleViewSet, basename='example') router.register('example', ExampleViewSet, basename='example')
@ -148,7 +150,7 @@ urlpatterns = [
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(ROOT_URLCONF='tests.test_schemas') @override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestRouterGeneratedSchema(TestCase): class TestRouterGeneratedSchema(TestCase):
def test_anonymous_request(self): def test_anonymous_request(self):
client = APIClient() client = APIClient()
@ -400,12 +402,13 @@ class ExampleDetailView(APIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGenerator(TestCase): class TestSchemaGenerator(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
url(r'^example/?$', ExampleListView.as_view()), url(r'^example/?$', views.ExampleListView.as_view()),
url(r'^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()), url(r'^example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()), url(r'^example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
] ]
def test_schema_for_regular_views(self): def test_schema_for_regular_views(self):
@ -453,12 +456,13 @@ class TestSchemaGenerator(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@unittest.skipUnless(path, 'needs Django 2') @unittest.skipUnless(path, 'needs Django 2')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorDjango2(TestCase): class TestSchemaGeneratorDjango2(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
path('example/', ExampleListView.as_view()), path('example/', views.ExampleListView.as_view()),
path('example/<int:pk>/', ExampleDetailView.as_view()), path('example/<int:pk>/', views.ExampleDetailView.as_view()),
path('example/<int:pk>/sub/', ExampleDetailView.as_view()), path('example/<int:pk>/sub/', views.ExampleDetailView.as_view()),
] ]
def test_schema_for_regular_views(self): def test_schema_for_regular_views(self):
@ -505,12 +509,13 @@ class TestSchemaGeneratorDjango2(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorNotAtRoot(TestCase): class TestSchemaGeneratorNotAtRoot(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
url(r'^api/v1/example/?$', ExampleListView.as_view()), url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()), url(r'^api/v1/example/(?P<pk>\d+)/?$', views.ExampleDetailView.as_view()),
url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()), url(r'^api/v1/example/(?P<pk>\d+)/sub/?$', views.ExampleDetailView.as_view()),
] ]
def test_schema_for_regular_views(self): def test_schema_for_regular_views(self):
@ -558,6 +563,7 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
def setUp(self): def setUp(self):
router = DefaultRouter() router = DefaultRouter()
@ -622,13 +628,14 @@ class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithRestrictedViewSets(TestCase): class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
def setUp(self): def setUp(self):
router = DefaultRouter() router = DefaultRouter()
router.register('example1', Http404ExampleViewSet, basename='example1') router.register('example1', Http404ExampleViewSet, basename='example1')
router.register('example2', PermissionDeniedExampleViewSet, basename='example2') router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
self.patterns = [ self.patterns = [
url('^example/?$', ExampleListView.as_view()), url('^example/?$', views.ExampleListView.as_view()),
url(r'^', include(router.urls)) url(r'^', include(router.urls))
] ]
@ -668,6 +675,7 @@ class ForeignKeySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithForeignKey(TestCase): class TestSchemaGeneratorWithForeignKey(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
@ -713,6 +721,7 @@ class ManyToManySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorWithManyToMany(TestCase): class TestSchemaGeneratorWithManyToMany(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
@ -747,6 +756,7 @@ class TestSchemaGeneratorWithManyToMany(TestCase):
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class Test4605Regression(TestCase): class Test4605Regression(TestCase):
def test_4605_regression(self): def test_4605_regression(self):
generator = SchemaGenerator() generator = SchemaGenerator()
@ -762,6 +772,7 @@ class CustomViewInspector(AutoSchema):
pass pass
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestAutoSchema(TestCase): class TestAutoSchema(TestCase):
def test_apiview_schema_descriptor(self): def test_apiview_schema_descriptor(self):
@ -777,7 +788,7 @@ class TestAutoSchema(TestCase):
assert isinstance(view.schema, CustomViewInspector) assert isinstance(view.schema, CustomViewInspector)
def test_set_custom_inspector_class_via_settings(self): def test_set_custom_inspector_class_via_settings(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}):
view = APIView() view = APIView()
assert isinstance(view.schema, CustomViewInspector) assert isinstance(view.schema, CustomViewInspector)
@ -971,6 +982,7 @@ class TestAutoSchema(TestCase):
self.assertEqual(field_to_schema(case[0]), case[1]) self.assertEqual(field_to_schema(case[0]), case[1])
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_docstring_is_not_stripped_by_get_description(): def test_docstring_is_not_stripped_by_get_description():
class ExampleDocstringAPIView(APIView): class ExampleDocstringAPIView(APIView):
""" """
@ -1007,25 +1019,25 @@ def test_docstring_is_not_stripped_by_get_description():
# Views for SchemaGenerationExclusionTests # Views for SchemaGenerationExclusionTests
class ExcludedAPIView(APIView): with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
schema = None class ExcludedAPIView(APIView):
schema = None
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
pass
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass pass
@api_view(['GET'])
@schema(None)
def excluded_fbv(request):
pass
@api_view(['GET'])
def included_fbv(request):
pass
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class SchemaGenerationExclusionTests(TestCase): class SchemaGenerationExclusionTests(TestCase):
def setUp(self): def setUp(self):
self.patterns = [ self.patterns = [
@ -1078,11 +1090,6 @@ class SchemaGenerationExclusionTests(TestCase):
assert should_include == expected assert should_include == expected
@api_view(["GET"])
def simple_fbv(request):
pass
class BasicModelSerializer(serializers.ModelSerializer): class BasicModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = BasicModel model = BasicModel
@ -1118,11 +1125,16 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestURLNamingCollisions(TestCase): class TestURLNamingCollisions(TestCase):
""" """
Ref: https://github.com/encode/django-rest-framework/issues/4704 Ref: https://github.com/encode/django-rest-framework/issues/4704
""" """
def test_manually_routing_nested_routes(self): def test_manually_routing_nested_routes(self):
@api_view(["GET"])
def simple_fbv(request):
pass
patterns = [ patterns = [
url(r'^test', simple_fbv), url(r'^test', simple_fbv),
url(r'^test/list/', simple_fbv), url(r'^test/list/', simple_fbv),
@ -1229,6 +1241,10 @@ class TestURLNamingCollisions(TestCase):
def test_url_under_same_key_not_replaced_another(self): def test_url_under_same_key_not_replaced_another(self):
@api_view(["GET"])
def simple_fbv(request):
pass
patterns = [ patterns = [
url(r'^test/list/', simple_fbv), url(r'^test/list/', simple_fbv),
url(r'^test/(?P<pk>\d+)/list/', simple_fbv), url(r'^test/(?P<pk>\d+)/list/', simple_fbv),
@ -1303,10 +1319,8 @@ def test_head_and_options_methods_are_excluded():
assert inspector.get_allowed_methods(callback) == ["GET"] assert inspector.get_allowed_methods(callback) == ["GET"]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') class MockAPIView(APIView):
class TestAutoSchemaAllowsFilters: filter_backends = [filters.OrderingFilter]
class MockAPIView(APIView):
filter_backends = [filters.OrderingFilter]
def _test(self, method): def _test(self, method):
view = self.MockAPIView() view = self.MockAPIView()

View File

@ -0,0 +1,20 @@
import pytest
from django.test import TestCase, override_settings
from rest_framework import renderers
from rest_framework.schemas import coreapi, get_schema_view, openapi
class GetSchemaViewTests(TestCase):
"""For the get_schema_view() helper."""
def test_openapi(self):
schema_view = get_schema_view(title="With OpenAPI")
assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator)
assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes
@pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed')
def test_coreapi(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
schema_view = get_schema_view(title="With CoreAPI")
assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator)
assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes

View File

@ -6,7 +6,8 @@ from django.core.management import call_command
from django.test import TestCase from django.test import TestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from rest_framework.compat import coreapi from rest_framework.compat import uritemplate, yaml
from rest_framework.management.commands import generateschema
from rest_framework.utils import formatting, json from rest_framework.utils import formatting, json
from rest_framework.views import APIView from rest_framework.views import APIView
@ -21,15 +22,43 @@ urlpatterns = [
] ]
@override_settings(ROOT_URLCONF='tests.test_generateschema') @override_settings(ROOT_URLCONF=__name__)
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not uritemplate, reason='uritemplate is not installed')
class GenerateSchemaTests(TestCase): class GenerateSchemaTests(TestCase):
"""Tests for management command generateschema.""" """Tests for management command generateschema."""
def setUp(self): def setUp(self):
self.out = io.StringIO() self.out = io.StringIO()
def test_command_detects_schema_generation_mode(self):
"""Switching between CoreAPI & OpenAPI"""
command = generateschema.Command()
assert command.get_mode() == generateschema.OPENAPI_MODE
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
assert command.get_mode() == generateschema.COREAPI_MODE
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
def test_renders_default_schema_with_custom_title_url_and_description(self): def test_renders_default_schema_with_custom_title_url_and_description(self):
call_command('generateschema',
'--title=SampleAPI',
'--url=http://api.sample.com',
'--description=Sample description',
stdout=self.out)
# Check valid YAML was output.
schema = yaml.load(self.out.getvalue())
assert schema['openapi'] == '3.0.2'
def test_renders_openapi_json_schema(self):
call_command('generateschema',
'--format=openapi-json',
stdout=self.out)
# Check valid JSON was output.
out_json = json.loads(self.out.getvalue())
assert out_json['openapi'] == '3.0.2'
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info: expected_out = """info:
description: Sample description description: Sample description
title: SampleAPI title: SampleAPI
@ -50,7 +79,8 @@ class GenerateSchemaTests(TestCase):
self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self): @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_coreapi_renders_openapi_json_schema(self):
expected_out = { expected_out = {
"openapi": "3.0.0", "openapi": "3.0.0",
"info": { "info": {
@ -78,6 +108,7 @@ class GenerateSchemaTests(TestCase):
self.assertDictEqual(out_json, expected_out) self.assertDictEqual(out_json, expected_out)
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_renders_corejson_schema(self): def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema', call_command('generateschema',

View File

@ -0,0 +1,245 @@
import pytest
from django.conf.urls import url
from django.test import RequestFactory, TestCase, override_settings
from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.compat import uritemplate
from rest_framework.request import Request
from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
from . import views
def create_request(path):
factory = RequestFactory()
request = Request(factory.get(path))
return request
def create_view(view_cls, method, request):
generator = SchemaGenerator()
view = generator.create_view(view_cls.as_view(), method, request)
return view
class TestBasics(TestCase):
def dummy_view(request):
pass
def test_filters(self):
classes = [filters.SearchFilter, filters.OrderingFilter]
for c in classes:
f = c()
assert f.get_schema_operation_parameters(self.dummy_view)
def test_pagination(self):
classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination]
for c in classes:
f = c()
assert f.get_schema_operation_parameters(self.dummy_view)
@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
class TestOperationIntrospection(TestCase):
def test_path_without_parameters(self):
path = '/example/'
method = 'GET'
view = create_view(
views.ExampleListView,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
operation = inspector.get_operation(path, method)
assert operation == {
'operationId': 'ListExamples',
'parameters': [],
'responses': {'200': {'content': {'application/json': {'schema': {}}}}},
}
def test_path_with_id_parameter(self):
path = '/example/{id}/'
method = 'GET'
view = create_view(
views.ExampleDetailView,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
parameters = inspector._get_path_parameters(path, method)
assert parameters == [{
'description': '',
'in': 'path',
'name': 'id',
'required': True,
'schema': {
'type': 'string',
},
}]
def test_request_body(self):
path = '/'
method = 'POST'
class Serializer(serializers.Serializer):
text = serializers.CharField()
read_only = serializers.CharField(read_only=True)
class View(generics.GenericAPIView):
serializer_class = Serializer
view = create_view(
View,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
request_body = inspector._get_request_body(path, method)
assert request_body['content']['application/json']['schema']['required'] == ['text']
assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text']
def test_response_body_generation(self):
path = '/'
method = 'POST'
class Serializer(serializers.Serializer):
text = serializers.CharField()
write_only = serializers.CharField(write_only=True)
class View(generics.GenericAPIView):
serializer_class = Serializer
view = create_view(
View,
method,
create_request(path)
)
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
assert responses['200']['content']['application/json']['schema']['required'] == ['text']
assert list(responses['200']['content']['application/json']['schema']['properties'].keys()) == ['text']
def test_response_body_nested_serializer(self):
path = '/'
method = 'POST'
class NestedSerializer(serializers.Serializer):
number = serializers.IntegerField()
class Serializer(serializers.Serializer):
text = serializers.CharField()
nested = NestedSerializer()
class View(generics.GenericAPIView):
serializer_class = Serializer
view = create_view(
View,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
responses = inspector._get_responses(path, method)
schema = responses['200']['content']['application/json']['schema']
assert sorted(schema['required']) == ['nested', 'text']
assert sorted(list(schema['properties'].keys())) == ['nested', 'text']
assert schema['properties']['nested']['type'] == 'object'
assert list(schema['properties']['nested']['properties'].keys()) == ['number']
assert schema['properties']['nested']['required'] == ['number']
def test_operation_id_generation(self):
path = '/'
method = 'GET'
view = create_view(
views.ExampleGenericAPIView,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view
operationId = inspector._get_operation_id(path, method)
assert operationId == 'ListExamples'
def test_repeat_operation_ids(self):
router = routers.SimpleRouter()
router.register('account', views.ExampleGenericViewSet, basename="account")
urlpatterns = router.urls
generator = SchemaGenerator(patterns=urlpatterns)
request = create_request('/')
schema = generator.get_schema(request=request)
schema_str = str(schema)
print(schema_str)
assert schema_str.count("operationId") == 2
assert schema_str.count("newExample") == 1
assert schema_str.count("oldExample") == 1
@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'})
class TestGenerator(TestCase):
def test_override_settings(self):
assert isinstance(views.ExampleListView.schema, AutoSchema)
def test_paths_construction(self):
"""Construction of the `paths` key."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
generator._initialise_endpoints()
paths = generator.get_paths()
assert '/example/' in paths
example_operations = paths['/example/']
assert len(example_operations) == 2
assert 'get' in example_operations
assert 'post' in example_operations
def test_schema_construction(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert 'openapi' in schema
assert 'paths' in schema
def test_serializer_datefield(self):
patterns = [
url(r'^example/?$', views.ExampleGenericViewSet.as_view({"get": "get"})),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
response = schema['paths']['/example/']['get']['responses']
response_schema = response['200']['content']['application/json']['schema']['properties']
assert response_schema['date']['type'] == response_schema['datetime']['type'] == 'string'
assert response_schema['date']['format'] == 'date'
assert response_schema['datetime']['format'] == 'date-time'

View File

@ -39,11 +39,12 @@ class ExampleViewSet(GenericViewSet):
return Response({}) return Response({})
if coreapi: with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
schema_view = get_schema_view(title='Example API') if coreapi:
else: schema_view = get_schema_view(title='Example API')
def schema_view(request): else:
pass def schema_view(request):
pass
router = SimpleRouter() router = SimpleRouter()
router.register('example', ExampleViewSet, basename='example') router.register('example', ExampleViewSet, basename='example')
@ -54,7 +55,7 @@ urlpatterns = [
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(ROOT_URLCONF='tests.test_schema_with_single_common_prefix') @override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class AnotherTestRouterGeneratedSchema(TestCase): class AnotherTestRouterGeneratedSchema(TestCase):
def test_anonymous_request(self): def test_anonymous_request(self):
client = APIClient() client = APIClient()

58
tests/schemas/views.py Normal file
View File

@ -0,0 +1,58 @@
from rest_framework import generics, permissions, serializers
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet
class ExampleListView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
class ExampleDetailView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, *args, **kwargs):
pass
# Generics.
class ExampleSerializer(serializers.Serializer):
date = serializers.DateField()
datetime = serializers.DateTimeField()
class ExampleGenericAPIView(generics.GenericAPIView):
serializer_class = ExampleSerializer
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)
class ExampleGenericViewSet(GenericViewSet):
serializer_class = ExampleSerializer
def get(self, *args, **kwargs):
from datetime import datetime
now = datetime.now()
serializer = self.get_serializer(data=now.date(), datetime=now)
return Response(serializer.data)
@action(detail=False)
def new(self, *args, **kwargs):
pass
@action(detail=False)
def old(self, *args, **kwargs):
pass

View File

@ -1,7 +1,6 @@
import datetime import datetime
import os import os
import re import re
import unittest
import uuid import uuid
from decimal import ROUND_DOWN, ROUND_UP, Decimal from decimal import ROUND_DOWN, ROUND_UP, Decimal
@ -17,15 +16,10 @@ from rest_framework import exceptions, serializers
from rest_framework.compat import ProhibitNullCharactersValidator from rest_framework.compat import ProhibitNullCharactersValidator
from rest_framework.fields import DjangoImageField, is_simple_callable from rest_framework.fields import DjangoImageField, is_simple_callable
try:
import typing
except ImportError:
typing = False
# Tests for helper functions. # Tests for helper functions.
# --------------------------- # ---------------------------
class TestIsSimpleCallable: class TestIsSimpleCallable:
def test_method(self): def test_method(self):
@ -92,7 +86,6 @@ class TestIsSimpleCallable:
assert is_simple_callable(ChoiceModel().get_choice_field_display) assert is_simple_callable(ChoiceModel().get_choice_field_display)
@unittest.skipUnless(typing, 'requires python 3.5')
def test_type_annotation(self): def test_type_annotation(self):
# The annotation will otherwise raise a syntax error in python < 3.5 # The annotation will otherwise raise a syntax error in python < 3.5
locals = {} locals = {}
@ -1989,6 +1982,7 @@ class TestDictField(FieldValues):
""" """
valid_inputs = [ valid_inputs = [
({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
({}, {}),
] ]
invalid_inputs = [ invalid_inputs = [
({'a': 1, 'b': None, 'c': None}, {'b': ['This field may not be null.'], 'c': ['This field may not be null.']}), ({'a': 1, 'b': None, 'c': None}, {'b': ['This field may not be null.'], 'c': ['This field may not be null.']}),
@ -2016,6 +2010,16 @@ class TestDictField(FieldValues):
output = field.run_validation(None) output = field.run_validation(None)
assert output is None assert output is None
def test_allow_empty_disallowed(self):
"""
If allow_empty is False then an empty dict is not a valid input.
"""
field = serializers.DictField(allow_empty=False)
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation({})
assert exc_info.value.detail == ['This dictionary may not be empty.']
class TestNestedDictField(FieldValues): class TestNestedDictField(FieldValues):
""" """

View File

@ -189,7 +189,7 @@ class JsonFloatTests(TestCase):
json.loads("NaN") json.loads("NaN")
@override_settings(STRICT_JSON=False) @override_settings(REST_FRAMEWORK={'STRICT_JSON': False})
class NonStrictJsonFloatTests(JsonFloatTests): class NonStrictJsonFloatTests(JsonFloatTests):
""" """
'STRICT_JSON = False' should not somehow affect internal json behavior 'STRICT_JSON = False' should not somehow affect internal json behavior

View File

@ -1,7 +1,7 @@
[tox] [tox]
envlist = envlist =
{py34,py35,py36}-django111, {py35,py36}-django111,
{py34,py35,py36,py37}-django20, {py35,py36,py37}-django20,
{py35,py36,py37}-django21 {py35,py36,py37}-django21
{py35,py36,py37}-django22 {py35,py36,py37}-django22
{py36,py37}-djangomaster, {py36,py37}-djangomaster,