diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index b5d2e0254..688deec88 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -195,6 +195,8 @@ class SchemaGenerator(object): Return a `coreapi.Link` instance for the given endpoint. """ view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).items(): + setattr(view, attr, val) fields = self.get_path_fields(path, method, callback, view) fields += self.get_serializer_fields(path, method, callback, view) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 7687448c4..2f440c567 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -97,6 +97,7 @@ class ViewSetMixin(object): # generation can pick out these bits of information from a # resolved URL. view.cls = cls + view.initkwargs = initkwargs view.suffix = initkwargs.get('suffix', None) view.actions = actions return csrf_exempt(view) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index a32b8a117..6c02c9d23 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -5,6 +5,7 @@ from django.test import TestCase, override_settings from rest_framework import filters, pagination, permissions, serializers from rest_framework.compat import coreapi +from rest_framework.decorators import detail_route from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator @@ -27,12 +28,21 @@ class ExampleSerializer(serializers.Serializer): b = serializers.CharField(required=False) +class AnotherSerializer(serializers.Serializer): + c = serializers.CharField(required=True) + d = serializers.CharField(required=False) + + class ExampleViewSet(ModelViewSet): pagination_class = ExamplePagination permission_classes = [permissions.IsAuthenticatedOrReadOnly] filter_backends = [filters.OrderingFilter] serializer_class = ExampleSerializer + @detail_route(methods=['post'], serializer_class=AnotherSerializer) + def custom_action(self, request, pk): + return super(ExampleSerializer, self).retrieve(self, request) + class ExampleView(APIView): permission_classes = [permissions.IsAuthenticatedOrReadOnly] @@ -120,6 +130,16 @@ class TestRouterGeneratedSchema(TestCase): coreapi.Field('pk', required=True, location='path') ] ), + 'custom_action': coreapi.Link( + url='/example/{pk}/custom_action/', + action='post', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('c', required=True, location='form'), + coreapi.Field('d', required=False, location='form'), + ] + ), 'update': coreapi.Link( url='/example/{pk}/', action='put',