Add integration tests for schema generation

This commit is contained in:
Tom Christie 2016-06-22 14:11:37 +01:00
parent 8519b4e24c
commit 2f5c9748d3
4 changed files with 224 additions and 5 deletions

View File

@ -345,6 +345,29 @@ Included as the complete request body. Typically for `POST`, `PUT` and `PATCH` r
These fields will normally correspond with views that use `ListSerializer` to validate the request input, or with file upload views.
#### `encoding`
**"application/json"**
JSON encoded request content. Corresponds to views using `JSONParser`.
Valid only if either one or more `location="form"` fields, or a single
`location="body"` field is included on the `Link`.
**"multipart/form-data"**
Multipart encoded request content. Corresponds to views using `MultiPartParser`.
Valid only if one or more `location="form"` fields is included on the `Link`.
**"application/x-www-form-urlencoded"**
URL encoded request content. Corresponds to views using `FormParser`. Valid
only if one or more `location="form"` fields is included on the `Link`.
**"application/octet-stream"**
Binary upload request content. Corresponds to views using `FileUploadParser`.
Valid only if a `location="body"` field is included on the `Link`.
#### `description`
A short description of the meaning and intended usage of the input field.

View File

@ -5,12 +5,23 @@ from django.contrib.admindocs.views import simplify_regex
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
from django.utils import six
from rest_framework import exceptions
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, uritemplate
from rest_framework.request import clone_request
from rest_framework.views import APIView
def as_query_fields(items):
"""
Take a list of Fields and plain strings.
Convert any pain strings into `location='query'` Field instances.
"""
return [
item if isinstance(item, coreapi.Field) else coreapi.Field(name=item, required=False, location='query')
for item in items
]
def is_api_view(callback):
"""
Return `True` if the given view callback is a REST framework view/viewset.
@ -180,11 +191,47 @@ class SchemaGenerator(object):
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view)
fields += self.get_filter_fields(path, method, callback, view)
return coreapi.Link(url=path, action=method.lower(), fields=fields)
if fields:
encoding = self.get_encoding(path, method, callback, view)
else:
encoding = None
return coreapi.Link(
url=path,
action=method.lower(),
encoding=encoding,
fields=fields
)
def get_encoding(self, path, method, callback, view):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
if method not in set(('POST', 'PUT', 'PATCH')):
return None
# Core API supports the following request encodings over HTTP...
supported_media_types = set((
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
))
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
def get_path_fields(self, path, method, callback, view):
"""
@ -211,6 +258,13 @@ class SchemaGenerator(object):
serializer_class = view.get_serializer_class()
serializer = serializer_class()
if isinstance(serializer, serializers.ListSerializer):
return coreapi.Field(name='data', location='body', required=True)
if not isinstance(serializer, serializers.Serializer):
return []
for field in serializer.fields.values():
if field.read_only:
continue
@ -231,7 +285,7 @@ class SchemaGenerator(object):
return []
paginator = view.pagination_class()
return paginator.get_fields(view)
return as_query_fields(paginator.get_fields(view))
def get_filter_fields(self, path, method, callback, view):
if method != 'GET':
@ -245,5 +299,5 @@ class SchemaGenerator(object):
fields = []
for filter_backend in view.filter_backends:
fields += filter_backend().get_fields(view)
fields += as_query_fields(filter_backend().get_fields(view))
return fields

View File

@ -13,7 +13,7 @@ from django.utils import six, timezone
from django.utils.encoding import force_text
from django.utils.functional import Promise
from rest_framework.compat import total_seconds
from rest_framework.compat import coreapi, total_seconds
class JSONEncoder(json.JSONEncoder):
@ -64,4 +64,9 @@ class JSONEncoder(json.JSONEncoder):
pass
elif hasattr(obj, '__iter__'):
return tuple(item for item in obj)
elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)):
raise RuntimeError(
'Cannot return a coreapi object from a JSON view. '
'You should be using a schema renderer instead for this view.'
)
return super(JSONEncoder, self).default(obj)

137
tests/test_schemas.py Normal file
View File

@ -0,0 +1,137 @@
import unittest
from django.conf.urls import include, url
from django.test import TestCase, override_settings
from rest_framework import filters, pagination, permissions, serializers
from rest_framework.compat import coreapi
from rest_framework.routers import DefaultRouter
from rest_framework.test import APIClient
from rest_framework.viewsets import ModelViewSet
class MockUser(object):
def is_authenticated(self):
return True
class ExamplePagination(pagination.PageNumberPagination):
page_size = 100
class ExampleSerializer(serializers.Serializer):
a = serializers.CharField(required=True)
b = serializers.CharField(required=False)
class ExampleViewSet(ModelViewSet):
pagination_class = ExamplePagination
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filter_backends = [filters.OrderingFilter]
serializer_class = ExampleSerializer
router = DefaultRouter(schema_title='Example API')
router.register('example', ExampleViewSet, base_name='example')
urlpatterns = [
url(r'^', include(router.urls))
]
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(ROOT_URLCONF='tests.test_schemas')
class TestRouterGeneratedSchema(TestCase):
def test_anonymous_request(self):
client = APIClient()
response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json')
self.assertEqual(response.status_code, 200)
expected = coreapi.Document(
url='',
title='Example API',
content={
'example': {
'list': coreapi.Link(
url='/example/',
action='get',
fields=[
coreapi.Field('page', required=False, location='query'),
coreapi.Field('ordering', required=False, location='query')
]
),
'retrieve': coreapi.Link(
url='/example/{pk}/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
]
)
}
}
)
self.assertEqual(response.data, expected)
def test_authenticated_request(self):
client = APIClient()
client.force_authenticate(MockUser())
response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json')
self.assertEqual(response.status_code, 200)
expected = coreapi.Document(
url='',
title='Example API',
content={
'example': {
'list': coreapi.Link(
url='/example/',
action='get',
fields=[
coreapi.Field('page', required=False, location='query'),
coreapi.Field('ordering', required=False, location='query')
]
),
'create': coreapi.Link(
url='/example/',
action='post',
encoding='application/json',
fields=[
coreapi.Field('a', required=True, location='form'),
coreapi.Field('b', required=False, location='form')
]
),
'retrieve': coreapi.Link(
url='/example/{pk}/',
action='get',
fields=[
coreapi.Field('pk', required=True, location='path')
]
),
'update': coreapi.Link(
url='/example/{pk}/',
action='put',
encoding='application/json',
fields=[
coreapi.Field('pk', required=True, location='path'),
coreapi.Field('a', required=True, location='form'),
coreapi.Field('b', required=False, location='form')
]
),
'partial_update': coreapi.Link(
url='/example/{pk}/',
action='patch',
encoding='application/json',
fields=[
coreapi.Field('pk', required=True, location='path'),
coreapi.Field('a', required=False, location='form'),
coreapi.Field('b', required=False, location='form')
]
),
'destroy': coreapi.Link(
url='/example/{pk}/',
action='delete',
fields=[
coreapi.Field('pk', required=True, location='path')
]
)
}
}
)
self.assertEqual(response.data, expected)