diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 0fbe6e64b..69850c430 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -345,6 +345,29 @@ Included as the complete request body. Typically for `POST`, `PUT` and `PATCH` r These fields will normally correspond with views that use `ListSerializer` to validate the request input, or with file upload views. +#### `encoding` + +**"application/json"** + +JSON encoded request content. Corresponds to views using `JSONParser`. +Valid only if either one or more `location="form"` fields, or a single +`location="body"` field is included on the `Link`. + +**"multipart/form-data"** + +Multipart encoded request content. Corresponds to views using `MultiPartParser`. +Valid only if one or more `location="form"` fields is included on the `Link`. + +**"application/x-www-form-urlencoded"** + +URL encoded request content. Corresponds to views using `FormParser`. Valid +only if one or more `location="form"` fields is included on the `Link`. + +**"application/octet-stream"** + +Binary upload request content. Corresponds to views using `FileUploadParser`. +Valid only if a `location="body"` field is included on the `Link`. + #### `description` A short description of the meaning and intended usage of the input field. diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index d92fac5aa..c2d276250 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -5,12 +5,23 @@ from django.contrib.admindocs.views import simplify_regex from django.core.urlresolvers import RegexURLPattern, RegexURLResolver from django.utils import six -from rest_framework import exceptions +from rest_framework import exceptions, serializers from rest_framework.compat import coreapi, uritemplate from rest_framework.request import clone_request from rest_framework.views import APIView +def as_query_fields(items): + """ + Take a list of Fields and plain strings. + Convert any pain strings into `location='query'` Field instances. + """ + return [ + item if isinstance(item, coreapi.Field) else coreapi.Field(name=item, required=False, location='query') + for item in items + ] + + def is_api_view(callback): """ Return `True` if the given view callback is a REST framework view/viewset. @@ -180,11 +191,47 @@ class SchemaGenerator(object): Return a `coreapi.Link` instance for the given endpoint. """ view = callback.cls() + fields = self.get_path_fields(path, method, callback, view) fields += self.get_serializer_fields(path, method, callback, view) fields += self.get_pagination_fields(path, method, callback, view) fields += self.get_filter_fields(path, method, callback, view) - return coreapi.Link(url=path, action=method.lower(), fields=fields) + + if fields: + encoding = self.get_encoding(path, method, callback, view) + else: + encoding = None + + return coreapi.Link( + url=path, + action=method.lower(), + encoding=encoding, + fields=fields + ) + + def get_encoding(self, path, method, callback, view): + """ + Return the 'encoding' parameter to use for a given endpoint. + """ + if method not in set(('POST', 'PUT', 'PATCH')): + return None + + # Core API supports the following request encodings over HTTP... + supported_media_types = set(( + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data', + )) + parser_classes = getattr(view, 'parser_classes', []) + for parser_class in parser_classes: + media_type = getattr(parser_class, 'media_type', None) + if media_type in supported_media_types: + return media_type + # Raw binary uploads are supported with "application/octet-stream" + if media_type == '*/*': + return 'application/octet-stream' + + return None def get_path_fields(self, path, method, callback, view): """ @@ -211,6 +258,13 @@ class SchemaGenerator(object): serializer_class = view.get_serializer_class() serializer = serializer_class() + + if isinstance(serializer, serializers.ListSerializer): + return coreapi.Field(name='data', location='body', required=True) + + if not isinstance(serializer, serializers.Serializer): + return [] + for field in serializer.fields.values(): if field.read_only: continue @@ -231,7 +285,7 @@ class SchemaGenerator(object): return [] paginator = view.pagination_class() - return paginator.get_fields(view) + return as_query_fields(paginator.get_fields(view)) def get_filter_fields(self, path, method, callback, view): if method != 'GET': @@ -245,5 +299,5 @@ class SchemaGenerator(object): fields = [] for filter_backend in view.filter_backends: - fields += filter_backend().get_fields(view) + fields += as_query_fields(filter_backend().get_fields(view)) return fields diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index f883b4925..e5b52ea5f 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -13,7 +13,7 @@ from django.utils import six, timezone from django.utils.encoding import force_text from django.utils.functional import Promise -from rest_framework.compat import total_seconds +from rest_framework.compat import coreapi, total_seconds class JSONEncoder(json.JSONEncoder): @@ -64,4 +64,9 @@ class JSONEncoder(json.JSONEncoder): pass elif hasattr(obj, '__iter__'): return tuple(item for item in obj) + elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)): + raise RuntimeError( + 'Cannot return a coreapi object from a JSON view. ' + 'You should be using a schema renderer instead for this view.' + ) return super(JSONEncoder, self).default(obj) diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 000000000..0341d0df5 --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,137 @@ +import unittest + +from django.conf.urls import include, url +from django.test import TestCase, override_settings + +from rest_framework import filters, pagination, permissions, serializers +from rest_framework.compat import coreapi +from rest_framework.routers import DefaultRouter +from rest_framework.test import APIClient +from rest_framework.viewsets import ModelViewSet + + +class MockUser(object): + def is_authenticated(self): + return True + + +class ExamplePagination(pagination.PageNumberPagination): + page_size = 100 + + +class ExampleSerializer(serializers.Serializer): + a = serializers.CharField(required=True) + b = serializers.CharField(required=False) + + +class ExampleViewSet(ModelViewSet): + pagination_class = ExamplePagination + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + filter_backends = [filters.OrderingFilter] + serializer_class = ExampleSerializer + + +router = DefaultRouter(schema_title='Example API') +router.register('example', ExampleViewSet, base_name='example') +urlpatterns = [ + url(r'^', include(router.urls)) +] + + +@unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(ROOT_URLCONF='tests.test_schemas') +class TestRouterGeneratedSchema(TestCase): + def test_anonymous_request(self): + client = APIClient() + response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json') + self.assertEqual(response.status_code, 200) + expected = coreapi.Document( + url='', + title='Example API', + content={ + 'example': { + 'list': coreapi.Link( + url='/example/', + action='get', + fields=[ + coreapi.Field('page', required=False, location='query'), + coreapi.Field('ordering', required=False, location='query') + ] + ), + 'retrieve': coreapi.Link( + url='/example/{pk}/', + action='get', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] + ) + } + } + ) + self.assertEqual(response.data, expected) + + def test_authenticated_request(self): + client = APIClient() + client.force_authenticate(MockUser()) + response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json') + self.assertEqual(response.status_code, 200) + expected = coreapi.Document( + url='', + title='Example API', + content={ + 'example': { + 'list': coreapi.Link( + url='/example/', + action='get', + fields=[ + coreapi.Field('page', required=False, location='query'), + coreapi.Field('ordering', required=False, location='query') + ] + ), + 'create': coreapi.Link( + url='/example/', + action='post', + encoding='application/json', + fields=[ + coreapi.Field('a', required=True, location='form'), + coreapi.Field('b', required=False, location='form') + ] + ), + 'retrieve': coreapi.Link( + url='/example/{pk}/', + action='get', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] + ), + 'update': coreapi.Link( + url='/example/{pk}/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('a', required=True, location='form'), + coreapi.Field('b', required=False, location='form') + ] + ), + 'partial_update': coreapi.Link( + url='/example/{pk}/', + action='patch', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('a', required=False, location='form'), + coreapi.Field('b', required=False, location='form') + ] + ), + 'destroy': coreapi.Link( + url='/example/{pk}/', + action='delete', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] + ) + } + } + ) + self.assertEqual(response.data, expected)