diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 41dc82da1..a2e54347f 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -4,8 +4,9 @@ from django.conf import settings from django.contrib.admindocs.views import simplify_regex from django.core.urlresolvers import RegexURLPattern, RegexURLResolver from django.utils import six +from django.utils.encoding import force_text -from rest_framework import exceptions, serializers +from rest_framework import exceptions, serializers, viewsets from rest_framework.compat import coreapi, uritemplate, urlparse from rest_framework.request import clone_request from rest_framework.views import APIView @@ -175,18 +176,16 @@ class SchemaGenerator(object): Return a tuple of strings, indicating the identity to use for a given endpoint. eg. ('users', 'list'). """ - category = None + category = [] for item in path.strip('/').split('/'): if '{' in item: - break - category = item + continue + category.append(item.capitalize()) actions = getattr(callback, 'actions', self.default_mapping) action = actions[method.lower()] - if category: - return (category, action) - return (action,) + return (' '.join(category), action) # Methods for generating each individual `Link` instance... @@ -206,10 +205,18 @@ class SchemaGenerator(object): else: encoding = None + if isinstance(view, viewsets.GenericViewSet): + actions = getattr(callback, 'actions', self.default_mapping) + action = actions[method.lower()] + view_fn = getattr(callback.cls, action, None) + else: + view_fn = getattr(callback.cls, method.lower(), None) + return coreapi.Link( url=urlparse.urljoin(self.url, path), action=method.lower(), encoding=encoding, + description=view_fn.__doc__ if view_fn else '', fields=fields ) @@ -239,10 +246,16 @@ class SchemaGenerator(object): Return a list of `coreapi.Field` instances corresponding to any templated path variables. """ + path_descriptions = getattr(view, 'path_fields_descriptions', {}) + fields = [] for variable in uritemplate.variables(path): - field = coreapi.Field(name=variable, location='path', required=True) + field = coreapi.Field(name=variable, + location='path', + required=True, + description=path_descriptions.get(variable, ''), + ) fields.append(field) return fields @@ -258,8 +271,6 @@ class SchemaGenerator(object): if not hasattr(view, 'get_serializer_class'): return [] - fields = [] - serializer_class = view.get_serializer_class() serializer = serializer_class() @@ -269,11 +280,17 @@ class SchemaGenerator(object): if not isinstance(serializer, serializers.Serializer): return [] + fields = [] for field in serializer.fields.values(): if field.read_only: continue required = field.required and method != 'PATCH' - field = coreapi.Field(name=field.source, location='form', required=required) + field = coreapi.Field( + name=field.source, + location='form', + required=required, + description=force_text(field.help_text), + ) fields.append(field) return fields diff --git a/tests/test_schemas.py b/tests/test_schemas.py index a32b8a117..155ad1ff1 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -6,7 +6,7 @@ from django.test import TestCase, override_settings from rest_framework import filters, pagination, permissions, serializers from rest_framework.compat import coreapi from rest_framework.response import Response -from rest_framework.routers import DefaultRouter +from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework.schemas import SchemaGenerator from rest_framework.test import APIClient from rest_framework.views import APIView @@ -23,8 +23,8 @@ class ExamplePagination(pagination.PageNumberPagination): class ExampleSerializer(serializers.Serializer): - a = serializers.CharField(required=True) - b = serializers.CharField(required=False) + a = serializers.CharField(required=True, help_text='About a') + b = serializers.CharField(required=False, help_text='About b') class ExampleViewSet(ModelViewSet): @@ -36,8 +36,12 @@ class ExampleViewSet(ModelViewSet): class ExampleView(APIView): permission_classes = [permissions.IsAuthenticatedOrReadOnly] + path_fields_descriptions = { + 'example_id': 'Description of example_id path parameter', + } def get(self, request, *args, **kwargs): + """get documentation""" return Response() def post(self, request, *args, **kwargs): @@ -54,6 +58,14 @@ urlpatterns2 = [ ] +router = SimpleRouter() +router.register('example', ExampleViewSet, base_name='example') +urlpatterns3 = [ + url(r'^', include(router.urls)), + url(r'^(?P\w+)/example-view/$', ExampleView.as_view(), name='example-view') +] + + @unittest.skipUnless(coreapi, 'coreapi is not installed') @override_settings(ROOT_URLCONF='tests.test_schemas') class TestRouterGeneratedSchema(TestCase): @@ -65,7 +77,7 @@ class TestRouterGeneratedSchema(TestCase): url='', title='Example API', content={ - 'example': { + 'Example': { 'list': coreapi.Link( url='/example/', action='get', @@ -95,7 +107,7 @@ class TestRouterGeneratedSchema(TestCase): url='', title='Example API', content={ - 'example': { + 'Example': { 'list': coreapi.Link( url='/example/', action='get', @@ -109,8 +121,8 @@ class TestRouterGeneratedSchema(TestCase): action='post', encoding='application/json', fields=[ - coreapi.Field('a', required=True, location='form'), - coreapi.Field('b', required=False, location='form') + coreapi.Field('a', required=True, location='form', description='About a'), + coreapi.Field('b', required=False, location='form', description='About b') ] ), 'retrieve': coreapi.Link( @@ -126,8 +138,8 @@ class TestRouterGeneratedSchema(TestCase): encoding='application/json', fields=[ coreapi.Field('pk', required=True, location='path'), - coreapi.Field('a', required=True, location='form'), - coreapi.Field('b', required=False, location='form') + coreapi.Field('a', required=True, location='form', description='About a'), + coreapi.Field('b', required=False, location='form', description='About b') ] ), 'partial_update': coreapi.Link( @@ -136,8 +148,8 @@ class TestRouterGeneratedSchema(TestCase): encoding='application/json', fields=[ coreapi.Field('pk', required=True, location='path'), - coreapi.Field('a', required=False, location='form'), - coreapi.Field('b', required=False, location='form') + coreapi.Field('a', required=False, location='form', description='About a'), + coreapi.Field('b', required=False, location='form', description='About b') ] ), 'destroy': coreapi.Link( @@ -162,7 +174,7 @@ class TestSchemaGenerator(TestCase): url='', title='Test View', content={ - 'example-view': { + 'Example-view': { 'create': coreapi.Link( url='/example-view/', action='post', @@ -171,9 +183,94 @@ class TestSchemaGenerator(TestCase): 'read': coreapi.Link( url='/example-view/', action='get', + description='get documentation', fields=[] ) } } ) self.assertEquals(schema, expected) + + +@unittest.skipUnless(coreapi, 'coreapi is not installed') +class TestSchemaAndSubSchemaGenerator(TestCase): + def test_view(self): + schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns3) + schema = schema_generator.get_schema() + expected = coreapi.Document( + url='', + title='Test View', + content={ + 'Example': { + 'list': coreapi.Link( + url='/example/', + action='get', + fields=[ + coreapi.Field('page', required=False, location='query'), + coreapi.Field('ordering', required=False, location='query') + ] + ), + 'create': coreapi.Link( + url='/example/', + action='post', + encoding='application/json', + fields=[ + coreapi.Field('a', required=True, location='form', description='About a'), + coreapi.Field('b', required=False, location='form', description='About b') + ] + ), + 'retrieve': coreapi.Link( + url='/example/{pk}/', + action='get', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] + ), + 'update': coreapi.Link( + url='/example/{pk}/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('a', required=True, location='form', description='About a'), + coreapi.Field('b', required=False, location='form', description='About b') + ] + ), + 'partial_update': coreapi.Link( + url='/example/{pk}/', + action='patch', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('a', required=False, location='form', description='About a'), + coreapi.Field('b', required=False, location='form', description='About b') + ] + ), + 'destroy': coreapi.Link( + url='/example/{pk}/', + action='delete', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] + ) + }, + 'Example-view': { + 'create': coreapi.Link( + url='/{example_id}/example-view/', + action='post', + fields=[ + coreapi.Field('example_id', required=True, location='path', description='Description of example_id path parameter') + ] + ), + 'read': coreapi.Link( + url='/{example_id}/example-view/', + action='get', + description='get documentation', + fields=[ + coreapi.Field('example_id', required=True, location='path', description='Description of example_id path parameter') + ] + ) + }, + } + ) + self.assertEquals(schema, expected)