Fix schemas for extra actions (#5992)

* Add failing test for extra action schemas

* Add ViewInspector setter to store instances

* Fix schema disabling for extra actions

* Add docs note about disabling schemas for actions
This commit is contained in:
Ryan P Kilby 2018-07-06 04:35:36 -04:00 committed by Carlton Gibson
parent b23cdaff4c
commit 6511b52cca
4 changed files with 76 additions and 4 deletions

View File

@ -243,6 +243,14 @@ You may disable schema generation for a view by setting `schema` to `None`:
... ...
schema = None # Will not appear in schema schema = None # Will not appear in schema
This also applies to extra actions for `ViewSet`s:
class CustomViewSet(viewsets.ModelViewSet):
@action(detail=True, schema=None)
def extra_action(self, request, pk=None):
...
--- ---
**Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and **Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and

View File

@ -218,6 +218,10 @@ class EndpointEnumerator(object):
if callback.cls.schema is None: if callback.cls.schema is None:
return False return False
if 'schema' in callback.initkwargs:
if callback.initkwargs['schema'] is None:
return False
if path.endswith('.{format}') or path.endswith('.{format}/'): if path.endswith('.{format}') or path.endswith('.{format}/'):
return False # Ignore .json style URLs. return False # Ignore .json style URLs.
@ -365,9 +369,7 @@ class SchemaGenerator(object):
""" """
Given a callback, return an actual view instance. Given a callback, return an actual view instance.
""" """
view = callback.cls() view = callback.cls(**getattr(callback, 'initkwargs', {}))
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view.args = () view.args = ()
view.kwargs = {} view.kwargs = {}
view.format_kwarg = None view.format_kwarg = None

View File

@ -7,6 +7,7 @@ See schemas.__init__.py for package overview.
import re import re
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from weakref import WeakKeyDictionary
from django.db import models from django.db import models
from django.utils.encoding import force_text, smart_text from django.utils.encoding import force_text, smart_text
@ -128,6 +129,10 @@ class ViewInspector(object):
Provide subclass for per-view schema generation Provide subclass for per-view schema generation
""" """
def __init__(self):
self.instance_schemas = WeakKeyDictionary()
def __get__(self, instance, owner): def __get__(self, instance, owner):
""" """
Enables `ViewInspector` as a Python _Descriptor_. Enables `ViewInspector` as a Python _Descriptor_.
@ -144,9 +149,17 @@ class ViewInspector(object):
See: https://docs.python.org/3/howto/descriptor.html for info on See: https://docs.python.org/3/howto/descriptor.html for info on
descriptor usage. descriptor usage.
""" """
if instance in self.instance_schemas:
return self.instance_schemas[instance]
self.view = instance self.view = instance
return self return self
def __set__(self, instance, other):
self.instance_schemas[instance] = other
if other is not None:
other.view = instance
@property @property
def view(self): def view(self):
"""View property.""" """View property."""
@ -189,6 +202,7 @@ class AutoSchema(ViewInspector):
* `manual_fields`: list of `coreapi.Field` instances that * `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name` will be added to auto-generated fields, overwriting on `Field.name`
""" """
super(AutoSchema, self).__init__()
if manual_fields is None: if manual_fields is None:
manual_fields = [] manual_fields = []
self._manual_fields = manual_fields self._manual_fields = manual_fields
@ -455,6 +469,7 @@ class ManualSchema(ViewInspector):
* `fields`: list of `coreapi.Field` instances. * `fields`: list of `coreapi.Field` instances.
* `descripton`: String description for view. Optional. * `descripton`: String description for view. Optional.
""" """
super(ManualSchema, self).__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields self._fields = fields
self._description = description self._description = description
@ -474,9 +489,13 @@ class ManualSchema(ViewInspector):
) )
class DefaultSchema(object): class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
def __get__(self, instance, owner): def __get__(self, instance, owner):
result = super(DefaultSchema, self).__get__(instance, owner)
if not isinstance(result, DefaultSchema):
return result
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
inspector = inspector_class() inspector = inspector_class()

View File

@ -105,6 +105,10 @@ class ExampleViewSet(ModelViewSet):
"""Deletion description.""" """Deletion description."""
raise NotImplementedError raise NotImplementedError
@action(detail=False, schema=None)
def excluded_action(self, request):
pass
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
assert self.request assert self.request
assert self.action assert self.action
@ -735,6 +739,45 @@ class TestAutoSchema(TestCase):
assert len(fields) == 2 assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields] assert "my_extra_field" in [f.name for f in fields]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_viewset_action_with_schema(self):
class CustomViewSet(GenericViewSet):
@action(detail=True, schema=AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
]))
def extra_action(self, pk, **kwargs):
pass
router = SimpleRouter()
router.register(r'detail', CustomViewSet, base_name='detail')
generator = SchemaGenerator()
view = generator.create_view(router.urls[0].callback, 'GET')
link = view.schema.get_link('/a/url/{id}/', 'GET', '')
fields = link.fields
assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_viewset_action_with_null_schema(self):
class CustomViewSet(GenericViewSet):
@action(detail=True, schema=None)
def extra_action(self, pk, **kwargs):
pass
router = SimpleRouter()
router.register(r'detail', CustomViewSet, base_name='detail')
generator = SchemaGenerator()
view = generator.create_view(router.urls[0].callback, 'GET')
assert view.schema is None
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_view_with_manual_schema(self): def test_view_with_manual_schema(self):