From 966934c0a2f5ec2e46f5dc5043709c862802f40d Mon Sep 17 00:00:00 2001 From: Nik Date: Thu, 11 Aug 2016 22:10:41 +0300 Subject: [PATCH] Add test for scheme allowed methods, fix categories for nested endpoints --- rest_framework/schemas.py | 5 ++-- tests/test_schemas.py | 51 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index c635c1fd6..26d5e5dac 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -64,7 +64,8 @@ class SchemaGenerator(object): def get_schema(self, request=None): if self.endpoints is None: - self.endpoints = self.get_api_endpoints(self.patterns) + endpoints = self.get_api_endpoints(self.patterns) + self.endpoints = self.add_categories(endpoints) links = [] for path, method, category, action, callback in self.endpoints: @@ -127,7 +128,7 @@ class SchemaGenerator(object): ) api_endpoints.extend(nested_endpoints) - return self.add_categories(api_endpoints) + return api_endpoints def add_categories(self, api_endpoints): """ diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 31ba97f2d..1b6ee3060 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -9,7 +9,7 @@ from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView from rest_framework.viewsets import ModelViewSet @@ -62,6 +62,15 @@ class ExampleViewSet(ModelViewSet): return super(ExampleViewSet, self).get_serializer(*args, **kwargs) +class RestrictiveViewSet(ModelViewSet): + permission_classes = [ForbidAll] + serializer_class = ExampleSerializer + + @detail_route(methods=['put'], permission_classes=[permissions.AllowAny]) + def allowed_action(self, request): + return super(RestrictiveViewSet, self).update(self, request) + + class ExampleView(APIView): permission_classes = [permissions.IsAuthenticatedOrReadOnly] @@ -77,7 +86,14 @@ router.register('example', ExampleViewSet, base_name='example') urlpatterns = [ url(r'^', include(router.urls)) ] -urlpatterns2 = [ + +router = DefaultRouter(schema_title='Restrictive API' if coreapi else None) +router.register('example', RestrictiveViewSet, base_name='example') +urlpatterns_restrict = [ + url(r'^', include(router.urls)) +] + +urlpatterns_view = [ url(r'^example-view/$', ExampleView.as_view(), name='example-view') ] @@ -209,10 +225,39 @@ class TestRouterGeneratedSchema(TestCase): self.assertEqual(response.data, expected) +@unittest.skipUnless(coreapi, 'coreapi is not installed') +class TestSchemaForRestrictedMethods(TestCase): + def test_resctricted_methods(self): + schema_generator = SchemaGenerator(title='Restrictive API', patterns=urlpatterns_restrict) + factory = APIRequestFactory() + from rest_framework.request import Request + mock_request = factory.get('/') + schema = schema_generator.get_schema(request=Request(mock_request)) + expected = coreapi.Document( + url='', + title='Restrictive API', + content={ + 'example': { + 'allowed_action': coreapi.Link( + url='/example/{pk}/allowed_action/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('a', required=True, location='form', description='A field description'), + coreapi.Field('b', required=False, location='form') + ] + ), + } + } + ) + self.assertEqual(schema, expected) + + @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGenerator(TestCase): def test_view(self): - schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2) + schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns_view) schema = schema_generator.get_schema() expected = coreapi.Document( url='',