mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 22:04:48 +03:00
Add integration tests for schema generation
This commit is contained in:
parent
8519b4e24c
commit
2f5c9748d3
|
@ -345,6 +345,29 @@ Included as the complete request body. Typically for `POST`, `PUT` and `PATCH` r
|
||||||
|
|
||||||
These fields will normally correspond with views that use `ListSerializer` to validate the request input, or with file upload views.
|
These fields will normally correspond with views that use `ListSerializer` to validate the request input, or with file upload views.
|
||||||
|
|
||||||
|
#### `encoding`
|
||||||
|
|
||||||
|
**"application/json"**
|
||||||
|
|
||||||
|
JSON encoded request content. Corresponds to views using `JSONParser`.
|
||||||
|
Valid only if either one or more `location="form"` fields, or a single
|
||||||
|
`location="body"` field is included on the `Link`.
|
||||||
|
|
||||||
|
**"multipart/form-data"**
|
||||||
|
|
||||||
|
Multipart encoded request content. Corresponds to views using `MultiPartParser`.
|
||||||
|
Valid only if one or more `location="form"` fields is included on the `Link`.
|
||||||
|
|
||||||
|
**"application/x-www-form-urlencoded"**
|
||||||
|
|
||||||
|
URL encoded request content. Corresponds to views using `FormParser`. Valid
|
||||||
|
only if one or more `location="form"` fields is included on the `Link`.
|
||||||
|
|
||||||
|
**"application/octet-stream"**
|
||||||
|
|
||||||
|
Binary upload request content. Corresponds to views using `FileUploadParser`.
|
||||||
|
Valid only if a `location="body"` field is included on the `Link`.
|
||||||
|
|
||||||
#### `description`
|
#### `description`
|
||||||
|
|
||||||
A short description of the meaning and intended usage of the input field.
|
A short description of the meaning and intended usage of the input field.
|
||||||
|
|
|
@ -5,12 +5,23 @@ from django.contrib.admindocs.views import simplify_regex
|
||||||
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
|
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
from rest_framework import exceptions
|
from rest_framework import exceptions, serializers
|
||||||
from rest_framework.compat import coreapi, uritemplate
|
from rest_framework.compat import coreapi, uritemplate
|
||||||
from rest_framework.request import clone_request
|
from rest_framework.request import clone_request
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
|
||||||
|
def as_query_fields(items):
|
||||||
|
"""
|
||||||
|
Take a list of Fields and plain strings.
|
||||||
|
Convert any pain strings into `location='query'` Field instances.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
item if isinstance(item, coreapi.Field) else coreapi.Field(name=item, required=False, location='query')
|
||||||
|
for item in items
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def is_api_view(callback):
|
def is_api_view(callback):
|
||||||
"""
|
"""
|
||||||
Return `True` if the given view callback is a REST framework view/viewset.
|
Return `True` if the given view callback is a REST framework view/viewset.
|
||||||
|
@ -180,11 +191,47 @@ class SchemaGenerator(object):
|
||||||
Return a `coreapi.Link` instance for the given endpoint.
|
Return a `coreapi.Link` instance for the given endpoint.
|
||||||
"""
|
"""
|
||||||
view = callback.cls()
|
view = callback.cls()
|
||||||
|
|
||||||
fields = self.get_path_fields(path, method, callback, view)
|
fields = self.get_path_fields(path, method, callback, view)
|
||||||
fields += self.get_serializer_fields(path, method, callback, view)
|
fields += self.get_serializer_fields(path, method, callback, view)
|
||||||
fields += self.get_pagination_fields(path, method, callback, view)
|
fields += self.get_pagination_fields(path, method, callback, view)
|
||||||
fields += self.get_filter_fields(path, method, callback, view)
|
fields += self.get_filter_fields(path, method, callback, view)
|
||||||
return coreapi.Link(url=path, action=method.lower(), fields=fields)
|
|
||||||
|
if fields:
|
||||||
|
encoding = self.get_encoding(path, method, callback, view)
|
||||||
|
else:
|
||||||
|
encoding = None
|
||||||
|
|
||||||
|
return coreapi.Link(
|
||||||
|
url=path,
|
||||||
|
action=method.lower(),
|
||||||
|
encoding=encoding,
|
||||||
|
fields=fields
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_encoding(self, path, method, callback, view):
|
||||||
|
"""
|
||||||
|
Return the 'encoding' parameter to use for a given endpoint.
|
||||||
|
"""
|
||||||
|
if method not in set(('POST', 'PUT', 'PATCH')):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Core API supports the following request encodings over HTTP...
|
||||||
|
supported_media_types = set((
|
||||||
|
'application/json',
|
||||||
|
'application/x-www-form-urlencoded',
|
||||||
|
'multipart/form-data',
|
||||||
|
))
|
||||||
|
parser_classes = getattr(view, 'parser_classes', [])
|
||||||
|
for parser_class in parser_classes:
|
||||||
|
media_type = getattr(parser_class, 'media_type', None)
|
||||||
|
if media_type in supported_media_types:
|
||||||
|
return media_type
|
||||||
|
# Raw binary uploads are supported with "application/octet-stream"
|
||||||
|
if media_type == '*/*':
|
||||||
|
return 'application/octet-stream'
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def get_path_fields(self, path, method, callback, view):
|
def get_path_fields(self, path, method, callback, view):
|
||||||
"""
|
"""
|
||||||
|
@ -211,6 +258,13 @@ class SchemaGenerator(object):
|
||||||
|
|
||||||
serializer_class = view.get_serializer_class()
|
serializer_class = view.get_serializer_class()
|
||||||
serializer = serializer_class()
|
serializer = serializer_class()
|
||||||
|
|
||||||
|
if isinstance(serializer, serializers.ListSerializer):
|
||||||
|
return coreapi.Field(name='data', location='body', required=True)
|
||||||
|
|
||||||
|
if not isinstance(serializer, serializers.Serializer):
|
||||||
|
return []
|
||||||
|
|
||||||
for field in serializer.fields.values():
|
for field in serializer.fields.values():
|
||||||
if field.read_only:
|
if field.read_only:
|
||||||
continue
|
continue
|
||||||
|
@ -231,7 +285,7 @@ class SchemaGenerator(object):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
paginator = view.pagination_class()
|
paginator = view.pagination_class()
|
||||||
return paginator.get_fields(view)
|
return as_query_fields(paginator.get_fields(view))
|
||||||
|
|
||||||
def get_filter_fields(self, path, method, callback, view):
|
def get_filter_fields(self, path, method, callback, view):
|
||||||
if method != 'GET':
|
if method != 'GET':
|
||||||
|
@ -245,5 +299,5 @@ class SchemaGenerator(object):
|
||||||
|
|
||||||
fields = []
|
fields = []
|
||||||
for filter_backend in view.filter_backends:
|
for filter_backend in view.filter_backends:
|
||||||
fields += filter_backend().get_fields(view)
|
fields += as_query_fields(filter_backend().get_fields(view))
|
||||||
return fields
|
return fields
|
||||||
|
|
|
@ -13,7 +13,7 @@ from django.utils import six, timezone
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
from django.utils.functional import Promise
|
from django.utils.functional import Promise
|
||||||
|
|
||||||
from rest_framework.compat import total_seconds
|
from rest_framework.compat import coreapi, total_seconds
|
||||||
|
|
||||||
|
|
||||||
class JSONEncoder(json.JSONEncoder):
|
class JSONEncoder(json.JSONEncoder):
|
||||||
|
@ -64,4 +64,9 @@ class JSONEncoder(json.JSONEncoder):
|
||||||
pass
|
pass
|
||||||
elif hasattr(obj, '__iter__'):
|
elif hasattr(obj, '__iter__'):
|
||||||
return tuple(item for item in obj)
|
return tuple(item for item in obj)
|
||||||
|
elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)):
|
||||||
|
raise RuntimeError(
|
||||||
|
'Cannot return a coreapi object from a JSON view. '
|
||||||
|
'You should be using a schema renderer instead for this view.'
|
||||||
|
)
|
||||||
return super(JSONEncoder, self).default(obj)
|
return super(JSONEncoder, self).default(obj)
|
||||||
|
|
137
tests/test_schemas.py
Normal file
137
tests/test_schemas.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from django.conf.urls import include, url
|
||||||
|
from django.test import TestCase, override_settings
|
||||||
|
|
||||||
|
from rest_framework import filters, pagination, permissions, serializers
|
||||||
|
from rest_framework.compat import coreapi
|
||||||
|
from rest_framework.routers import DefaultRouter
|
||||||
|
from rest_framework.test import APIClient
|
||||||
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
|
|
||||||
|
class MockUser(object):
|
||||||
|
def is_authenticated(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ExamplePagination(pagination.PageNumberPagination):
|
||||||
|
page_size = 100
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleSerializer(serializers.Serializer):
|
||||||
|
a = serializers.CharField(required=True)
|
||||||
|
b = serializers.CharField(required=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleViewSet(ModelViewSet):
|
||||||
|
pagination_class = ExamplePagination
|
||||||
|
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||||
|
filter_backends = [filters.OrderingFilter]
|
||||||
|
serializer_class = ExampleSerializer
|
||||||
|
|
||||||
|
|
||||||
|
router = DefaultRouter(schema_title='Example API')
|
||||||
|
router.register('example', ExampleViewSet, base_name='example')
|
||||||
|
urlpatterns = [
|
||||||
|
url(r'^', include(router.urls))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||||
|
@override_settings(ROOT_URLCONF='tests.test_schemas')
|
||||||
|
class TestRouterGeneratedSchema(TestCase):
|
||||||
|
def test_anonymous_request(self):
|
||||||
|
client = APIClient()
|
||||||
|
response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json')
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
expected = coreapi.Document(
|
||||||
|
url='',
|
||||||
|
title='Example API',
|
||||||
|
content={
|
||||||
|
'example': {
|
||||||
|
'list': coreapi.Link(
|
||||||
|
url='/example/',
|
||||||
|
action='get',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('page', required=False, location='query'),
|
||||||
|
coreapi.Field('ordering', required=False, location='query')
|
||||||
|
]
|
||||||
|
),
|
||||||
|
'retrieve': coreapi.Link(
|
||||||
|
url='/example/{pk}/',
|
||||||
|
action='get',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('pk', required=True, location='path')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(response.data, expected)
|
||||||
|
|
||||||
|
def test_authenticated_request(self):
|
||||||
|
client = APIClient()
|
||||||
|
client.force_authenticate(MockUser())
|
||||||
|
response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json')
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
expected = coreapi.Document(
|
||||||
|
url='',
|
||||||
|
title='Example API',
|
||||||
|
content={
|
||||||
|
'example': {
|
||||||
|
'list': coreapi.Link(
|
||||||
|
url='/example/',
|
||||||
|
action='get',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('page', required=False, location='query'),
|
||||||
|
coreapi.Field('ordering', required=False, location='query')
|
||||||
|
]
|
||||||
|
),
|
||||||
|
'create': coreapi.Link(
|
||||||
|
url='/example/',
|
||||||
|
action='post',
|
||||||
|
encoding='application/json',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('a', required=True, location='form'),
|
||||||
|
coreapi.Field('b', required=False, location='form')
|
||||||
|
]
|
||||||
|
),
|
||||||
|
'retrieve': coreapi.Link(
|
||||||
|
url='/example/{pk}/',
|
||||||
|
action='get',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('pk', required=True, location='path')
|
||||||
|
]
|
||||||
|
),
|
||||||
|
'update': coreapi.Link(
|
||||||
|
url='/example/{pk}/',
|
||||||
|
action='put',
|
||||||
|
encoding='application/json',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('pk', required=True, location='path'),
|
||||||
|
coreapi.Field('a', required=True, location='form'),
|
||||||
|
coreapi.Field('b', required=False, location='form')
|
||||||
|
]
|
||||||
|
),
|
||||||
|
'partial_update': coreapi.Link(
|
||||||
|
url='/example/{pk}/',
|
||||||
|
action='patch',
|
||||||
|
encoding='application/json',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('pk', required=True, location='path'),
|
||||||
|
coreapi.Field('a', required=False, location='form'),
|
||||||
|
coreapi.Field('b', required=False, location='form')
|
||||||
|
]
|
||||||
|
),
|
||||||
|
'destroy': coreapi.Link(
|
||||||
|
url='/example/{pk}/',
|
||||||
|
action='delete',
|
||||||
|
fields=[
|
||||||
|
coreapi.Field('pk', required=True, location='path')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(response.data, expected)
|
Loading…
Reference in New Issue
Block a user