mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-10-31 07:57:55 +03:00 
			
		
		
		
	Fix schemas for extra actions (#5992)
* Add failing test for extra action schemas * Add ViewInspector setter to store instances * Fix schema disabling for extra actions * Add docs note about disabling schemas for actions
This commit is contained in:
		
							parent
							
								
									b23cdaff4c
								
							
						
					
					
						commit
						6511b52cca
					
				|  | @ -243,6 +243,14 @@ You may disable schema generation for a view by setting `schema` to `None`: | ||||||
|             ... |             ... | ||||||
|             schema = None  # Will not appear in schema |             schema = None  # Will not appear in schema | ||||||
| 
 | 
 | ||||||
|  | This also applies to extra actions for `ViewSet`s: | ||||||
|  | 
 | ||||||
|  |         class CustomViewSet(viewsets.ModelViewSet): | ||||||
|  | 
 | ||||||
|  |             @action(detail=True, schema=None) | ||||||
|  |             def extra_action(self, request, pk=None): | ||||||
|  |                 ... | ||||||
|  | 
 | ||||||
| --- | --- | ||||||
| 
 | 
 | ||||||
| **Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and | **Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and | ||||||
|  |  | ||||||
|  | @ -218,6 +218,10 @@ class EndpointEnumerator(object): | ||||||
|         if callback.cls.schema is None: |         if callback.cls.schema is None: | ||||||
|             return False |             return False | ||||||
| 
 | 
 | ||||||
|  |         if 'schema' in callback.initkwargs: | ||||||
|  |             if callback.initkwargs['schema'] is None: | ||||||
|  |                 return False | ||||||
|  | 
 | ||||||
|         if path.endswith('.{format}') or path.endswith('.{format}/'): |         if path.endswith('.{format}') or path.endswith('.{format}/'): | ||||||
|             return False  # Ignore .json style URLs. |             return False  # Ignore .json style URLs. | ||||||
| 
 | 
 | ||||||
|  | @ -365,9 +369,7 @@ class SchemaGenerator(object): | ||||||
|         """ |         """ | ||||||
|         Given a callback, return an actual view instance. |         Given a callback, return an actual view instance. | ||||||
|         """ |         """ | ||||||
|         view = callback.cls() |         view = callback.cls(**getattr(callback, 'initkwargs', {})) | ||||||
|         for attr, val in getattr(callback, 'initkwargs', {}).items(): |  | ||||||
|             setattr(view, attr, val) |  | ||||||
|         view.args = () |         view.args = () | ||||||
|         view.kwargs = {} |         view.kwargs = {} | ||||||
|         view.format_kwarg = None |         view.format_kwarg = None | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ See schemas.__init__.py for package overview. | ||||||
| import re | import re | ||||||
| import warnings | import warnings | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
|  | from weakref import WeakKeyDictionary | ||||||
| 
 | 
 | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.utils.encoding import force_text, smart_text | from django.utils.encoding import force_text, smart_text | ||||||
|  | @ -128,6 +129,10 @@ class ViewInspector(object): | ||||||
| 
 | 
 | ||||||
|     Provide subclass for per-view schema generation |     Provide subclass for per-view schema generation | ||||||
|     """ |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self): | ||||||
|  |         self.instance_schemas = WeakKeyDictionary() | ||||||
|  | 
 | ||||||
|     def __get__(self, instance, owner): |     def __get__(self, instance, owner): | ||||||
|         """ |         """ | ||||||
|         Enables `ViewInspector` as a Python _Descriptor_. |         Enables `ViewInspector` as a Python _Descriptor_. | ||||||
|  | @ -144,9 +149,17 @@ class ViewInspector(object): | ||||||
|         See: https://docs.python.org/3/howto/descriptor.html for info on |         See: https://docs.python.org/3/howto/descriptor.html for info on | ||||||
|         descriptor usage. |         descriptor usage. | ||||||
|         """ |         """ | ||||||
|  |         if instance in self.instance_schemas: | ||||||
|  |             return self.instance_schemas[instance] | ||||||
|  | 
 | ||||||
|         self.view = instance |         self.view = instance | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
|  |     def __set__(self, instance, other): | ||||||
|  |         self.instance_schemas[instance] = other | ||||||
|  |         if other is not None: | ||||||
|  |             other.view = instance | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def view(self): |     def view(self): | ||||||
|         """View property.""" |         """View property.""" | ||||||
|  | @ -189,6 +202,7 @@ class AutoSchema(ViewInspector): | ||||||
|         * `manual_fields`: list of `coreapi.Field` instances that |         * `manual_fields`: list of `coreapi.Field` instances that | ||||||
|             will be added to auto-generated fields, overwriting on `Field.name` |             will be added to auto-generated fields, overwriting on `Field.name` | ||||||
|         """ |         """ | ||||||
|  |         super(AutoSchema, self).__init__() | ||||||
|         if manual_fields is None: |         if manual_fields is None: | ||||||
|             manual_fields = [] |             manual_fields = [] | ||||||
|         self._manual_fields = manual_fields |         self._manual_fields = manual_fields | ||||||
|  | @ -455,6 +469,7 @@ class ManualSchema(ViewInspector): | ||||||
|         * `fields`: list of `coreapi.Field` instances. |         * `fields`: list of `coreapi.Field` instances. | ||||||
|         * `descripton`: String description for view. Optional. |         * `descripton`: String description for view. Optional. | ||||||
|         """ |         """ | ||||||
|  |         super(ManualSchema, self).__init__() | ||||||
|         assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" |         assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" | ||||||
|         self._fields = fields |         self._fields = fields | ||||||
|         self._description = description |         self._description = description | ||||||
|  | @ -474,9 +489,13 @@ class ManualSchema(ViewInspector): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DefaultSchema(object): | class DefaultSchema(ViewInspector): | ||||||
|     """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" |     """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" | ||||||
|     def __get__(self, instance, owner): |     def __get__(self, instance, owner): | ||||||
|  |         result = super(DefaultSchema, self).__get__(instance, owner) | ||||||
|  |         if not isinstance(result, DefaultSchema): | ||||||
|  |             return result | ||||||
|  | 
 | ||||||
|         inspector_class = api_settings.DEFAULT_SCHEMA_CLASS |         inspector_class = api_settings.DEFAULT_SCHEMA_CLASS | ||||||
|         assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" |         assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" | ||||||
|         inspector = inspector_class() |         inspector = inspector_class() | ||||||
|  |  | ||||||
|  | @ -105,6 +105,10 @@ class ExampleViewSet(ModelViewSet): | ||||||
|         """Deletion description.""" |         """Deletion description.""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|  |     @action(detail=False, schema=None) | ||||||
|  |     def excluded_action(self, request): | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|     def get_serializer(self, *args, **kwargs): |     def get_serializer(self, *args, **kwargs): | ||||||
|         assert self.request |         assert self.request | ||||||
|         assert self.action |         assert self.action | ||||||
|  | @ -735,6 +739,45 @@ class TestAutoSchema(TestCase): | ||||||
|         assert len(fields) == 2 |         assert len(fields) == 2 | ||||||
|         assert "my_extra_field" in [f.name for f in fields] |         assert "my_extra_field" in [f.name for f in fields] | ||||||
| 
 | 
 | ||||||
|  |     @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') | ||||||
|  |     def test_viewset_action_with_schema(self): | ||||||
|  |         class CustomViewSet(GenericViewSet): | ||||||
|  |             @action(detail=True, schema=AutoSchema(manual_fields=[ | ||||||
|  |                 coreapi.Field( | ||||||
|  |                     "my_extra_field", | ||||||
|  |                     required=True, | ||||||
|  |                     location="path", | ||||||
|  |                     schema=coreschema.String() | ||||||
|  |                 ), | ||||||
|  |             ])) | ||||||
|  |             def extra_action(self, pk, **kwargs): | ||||||
|  |                 pass | ||||||
|  | 
 | ||||||
|  |         router = SimpleRouter() | ||||||
|  |         router.register(r'detail', CustomViewSet, base_name='detail') | ||||||
|  | 
 | ||||||
|  |         generator = SchemaGenerator() | ||||||
|  |         view = generator.create_view(router.urls[0].callback, 'GET') | ||||||
|  |         link = view.schema.get_link('/a/url/{id}/', 'GET', '') | ||||||
|  |         fields = link.fields | ||||||
|  | 
 | ||||||
|  |         assert len(fields) == 2 | ||||||
|  |         assert "my_extra_field" in [f.name for f in fields] | ||||||
|  | 
 | ||||||
|  |     @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') | ||||||
|  |     def test_viewset_action_with_null_schema(self): | ||||||
|  |         class CustomViewSet(GenericViewSet): | ||||||
|  |             @action(detail=True, schema=None) | ||||||
|  |             def extra_action(self, pk, **kwargs): | ||||||
|  |                 pass | ||||||
|  | 
 | ||||||
|  |         router = SimpleRouter() | ||||||
|  |         router.register(r'detail', CustomViewSet, base_name='detail') | ||||||
|  | 
 | ||||||
|  |         generator = SchemaGenerator() | ||||||
|  |         view = generator.create_view(router.urls[0].callback, 'GET') | ||||||
|  |         assert view.schema is None | ||||||
|  | 
 | ||||||
|     @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') |     @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') | ||||||
|     def test_view_with_manual_schema(self): |     def test_view_with_manual_schema(self): | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user