diff --git a/rest_framework/test.py b/rest_framework/test.py index e17c19a43..492edac50 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -16,7 +16,7 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode -from rest_framework.compat import requests +from rest_framework.compat import coreapi, requests from rest_framework.settings import api_settings @@ -126,6 +126,14 @@ def get_requests_client(): return DjangoTestSession() +def get_api_client(): + assert coreapi is not None, 'coreapi must be installed' + session = get_requests_client() + return coreapi.Client(transports=[ + coreapi.transports.HTTPTransport(session=session) + ]) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT diff --git a/tests/test_api_client.py b/tests/test_api_client.py new file mode 100644 index 000000000..498941297 --- /dev/null +++ b/tests/test_api_client.py @@ -0,0 +1,33 @@ +from __future__ import unicode_literals + +import unittest + +from django.conf.urls import url +from django.test import override_settings + +from rest_framework.compat import coreapi +from rest_framework.response import Response +from rest_framework.test import APITestCase, get_api_client +from rest_framework.views import APIView + + +class Root(APIView): + def get(self, request): + return Response({ + 'hello': 'world', + }) + + +urlpatterns = [ + url(r'^$', Root.as_view()), +] + + +@unittest.skipUnless(coreapi, 'coreapi not installed') +@override_settings(ROOT_URLCONF='tests.test_api_client') +class APIClientTests(APITestCase): + def test_api_client(self): + client = get_api_client() + schema = client.get('/') + data = client.action(schema, ['echo']) + assert data == {'hello': 'world'}