RequestsClient, CoreAPIClient

This commit is contained in:
Tom Christie 2016-10-04 13:57:34 +01:00
parent 4ad5256e88
commit 8044d38c21
6 changed files with 221 additions and 70 deletions

View File

@ -184,6 +184,99 @@ As usual CSRF validation will only apply to any session authenticated views. Th
---
# RequestsClient
REST framework also includes a client for interacting with your application
using the popular Python library, `requests`.
This exposes exactly the same interface as if you were using a requests session
directly.
client = RequestsClient()
response = client.get('http://testserver/users/')
Note that the requests client requires you to pass fully qualified URLs.
## Headers & Authentication
Custom headers and authentication credentials can be provided in the same way
as [when using a standard `requests.Session` instance](http://docs.python-requests.org/en/master/user/advanced/#session-objects).
from requests.auth import HTTPBasicAuth
client.auth = HTTPBasicAuth('user', 'pass')
client.headers.update({'x-test': 'true'})
## CSRF
If you're using `SessionAuthentication` then you'll need to include a CSRF token
for any `POST`, `PUT`, `PATCH` or `DELETE` requests.
You can do so by following the same flow that a JavaScript based client would use.
First make a `GET` request in order to obtain a CRSF token, then present that
token in the following request.
For example...
client = RequestsClient()
# Obtain a CSRF token.
response = client.get('/homepage/')
assert response.status_code == 200
csrftoken = response.cookies['csrftoken']
# Interact with the API.
response = client.post('/organisations/', json={
'name': 'MegaCorp',
'status': 'active'
}, headers={'X-CSRFToken': csrftoken})
assert response.status_code == 200
## Live tests
With careful usage both the `RequestsClient` and the `CoreAPIClient` provide
the ability to write test cases that can run either in development, or be run
directly against your staging server or production environment.
Using this style to create basic tests of a few core piece of functionality is
a powerful way to validate your live service. Doing so may require some careful
attention to setup and teardown to ensure that the tests run in a way that they
do not directly affect customer data.
---
# CoreAPIClient
The CoreAPIClient allows you to interact with your API using the Python
`coreapi` client library.
# Fetch the API schema
url = reverse('schema')
client = CoreAPIClient()
schema = client.get(url)
# Create a new organisation
params = {'name': 'MegaCorp', 'status': 'active'}
client.action(schema, ['organisations', 'create'], params)
# Ensure that the organisation exists in the listing
data = client.action(schema, ['organisations', 'list'])
assert(len(data) == 1)
assert(data == [{'name': 'MegaCorp', 'status': 'active'}])
## Headers & Authentication
Custom headers and authentication may be used with `CoreAPIClient` in a
similar way as with `RequestsClient`.
from requests.auth import HTTPBasicAuth
client = CoreAPIClient()
client.session.auth = HTTPBasicAuth('user', 'pass')
client.session.headers.update({'x-test': 'true'})
---
# Test cases
REST framework includes the following test case classes, that mirror the existing Django test case classes, but use `APIClient` instead of Django's default `Client`.

View File

@ -23,6 +23,7 @@ from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, renderers, views
from rest_framework.compat import NoReverseMatch
from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
@ -281,7 +282,7 @@ class DefaultRouter(SimpleRouter):
include_root_view = True
include_format_suffixes = True
root_view_name = 'api-root'
default_schema_renderers = [renderers.CoreJSONRenderer]
default_schema_renderers = [renderers.CoreJSONRenderer, BrowsableAPIRenderer]
def __init__(self, *args, **kwargs):
if 'schema_renderers' in kwargs:

View File

@ -1,3 +1,4 @@
from collections import OrderedDict
from importlib import import_module
from django.conf import settings
@ -55,6 +56,18 @@ def is_custom_action(action):
])
def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
'GET': 0,
'POST': 1,
'PUT': 2,
'PATCH': 3,
'DELETE': 4
}.get(method, 5)
return (path, method_priority)
class EndpointInspector(object):
"""
A class to determine the available API endpoints that a project exposes.
@ -101,6 +114,8 @@ class EndpointInspector(object):
)
api_endpoints.extend(nested_endpoints)
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
return api_endpoints
def get_path_from_regex(self, path_regex):
@ -183,7 +198,7 @@ class SchemaGenerator(object):
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = {}
links = OrderedDict()
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
if not self.has_view_permissions(view):

View File

@ -7,6 +7,7 @@ from __future__ import unicode_literals
import io
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler
from django.test import testcases
from django.test.client import Client as DjangoClient
@ -105,36 +106,46 @@ if requests is not None:
def close(self):
pass
class DjangoTestSession(requests.Session):
class NoExternalRequestsAdapter(requests.adapters.HTTPAdapter):
def send(self, request, *args, **kwargs):
msg = (
'RequestsClient refusing to make an outgoing network request '
'to "%s". Only "testserver" or hostnames in your ALLOWED_HOSTS '
'setting are valid.' % request.url
)
raise RuntimeError(msg)
class RequestsClient(requests.Session):
def __init__(self, *args, **kwargs):
super(DjangoTestSession, self).__init__(*args, **kwargs)
super(RequestsClient, self).__init__(*args, **kwargs)
adapter = DjangoTestAdapter()
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
for hostname in hostnames:
if hostname == '*':
hostname = ''
self.mount('http://%s' % hostname, adapter)
self.mount('https://%s' % hostname, adapter)
self.mount('http://', adapter)
self.mount('https://', adapter)
def request(self, method, url, *args, **kwargs):
if ':' not in url:
url = 'http://testserver/' + url.lstrip('/')
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
return super(RequestsClient, self).request(method, url, *args, **kwargs)
else:
def RequestsClient(*args, **kwargs):
raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
def get_requests_client():
assert requests is not None, 'requests must be installed'
return DjangoTestSession()
if coreapi is not None:
class CoreAPIClient(coreapi.Client):
def __init__(self, *args, **kwargs):
self._session = RequestsClient()
kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
return super(CoreAPIClient, self).__init__(*args, **kwargs)
@property
def session(self):
return self._session
def get_api_client():
assert coreapi is not None, 'coreapi must be installed'
session = get_requests_client()
return coreapi.Client(transports=[
coreapi.transports.HTTPTransport(session=session)
])
else:
def CoreAPIClient(*args, **kwargs):
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
class APIRequestFactory(DjangoRequestFactory):

View File

@ -12,7 +12,7 @@ from rest_framework.compat import coreapi
from rest_framework.parsers import FileUploadParser
from rest_framework.renderers import CoreJSONRenderer
from rest_framework.response import Response
from rest_framework.test import APITestCase, get_api_client
from rest_framework.test import APITestCase, CoreAPIClient
from rest_framework.views import APIView
@ -22,6 +22,7 @@ def get_schema():
title='Example API',
content={
'simple_link': coreapi.Link('/example/', description='example link'),
'headers': coreapi.Link('/headers/'),
'location': {
'query': coreapi.Link('/example/', fields=[
coreapi.Field(name='example', description='example field')
@ -165,6 +166,19 @@ class TextView(APIView):
return HttpResponse('123', content_type='text/plain')
class HeadersView(APIView):
def get(self, request):
headers = {
key[5:].replace('_', '-'): value
for key, value in request.META.items()
if key.startswith('HTTP_')
}
return Response({
'method': request.method,
'headers': headers
})
urlpatterns = [
url(r'^$', SchemaView.as_view()),
url(r'^example/$', ListView.as_view()),
@ -172,6 +186,7 @@ urlpatterns = [
url(r'^upload/$', UploadView.as_view()),
url(r'^download/$', DownloadView.as_view()),
url(r'^text/$', TextView.as_view()),
url(r'^headers/$', HeadersView.as_view()),
]
@ -179,7 +194,7 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_api_client')
class APIClientTests(APITestCase):
def test_api_client(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
assert schema.title == 'Example API'
assert schema.url == 'https://api.example.com/'
@ -193,7 +208,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_query_params(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'query'], params={'example': 123})
expected = {
@ -202,8 +217,15 @@ class APIClientTests(APITestCase):
}
assert data == expected
def test_session_headers(self):
client = CoreAPIClient()
client.session.headers.update({'X-Custom-Header': 'foo'})
schema = client.get('http://api.example.com/')
data = client.action(schema, ['headers'])
assert data['headers']['X-CUSTOM-HEADER'] == 'foo'
def test_query_params_with_multiple_values(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]})
expected = {
@ -213,7 +235,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_form_params(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'form'], params={'example': 123})
expected = {
@ -226,7 +248,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_body_params(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'body'], params={'example': 123})
expected = {
@ -239,7 +261,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_path_params(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'path'], params={'id': 123})
expected = {
@ -250,7 +272,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_multipart_encoding(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
temp = tempfile.NamedTemporaryFile()
@ -272,7 +294,7 @@ class APIClientTests(APITestCase):
def test_multipart_encoding_no_file(self):
# When no file is included, multipart encoding should still be used.
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'multipart'], params={'example': 123})
@ -287,7 +309,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_multipart_encoding_multiple_values(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]})
@ -305,7 +327,7 @@ class APIClientTests(APITestCase):
# Test for `coreapi.utils.File` support.
from coreapi.utils import File
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
example = File(name='example.txt', content='123')
@ -323,7 +345,7 @@ class APIClientTests(APITestCase):
def test_multipart_encoding_in_body(self):
from coreapi.utils import File
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'}
@ -341,7 +363,7 @@ class APIClientTests(APITestCase):
# URLencoded
def test_urlencoded_encoding(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123})
expected = {
@ -354,7 +376,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_urlencoded_encoding_multiple_values(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]})
expected = {
@ -367,7 +389,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_urlencoded_encoding_in_body(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}})
expected = {
@ -382,7 +404,7 @@ class APIClientTests(APITestCase):
# Raw uploads
def test_raw_upload(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
temp = tempfile.NamedTemporaryFile()
@ -403,7 +425,7 @@ class APIClientTests(APITestCase):
def test_raw_upload_string_file_content(self):
from coreapi.utils import File
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
example = File('example.txt', '123')
@ -419,7 +441,7 @@ class APIClientTests(APITestCase):
def test_raw_upload_explicit_content_type(self):
from coreapi.utils import File
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
example = File('example.txt', '123', 'text/html')
@ -435,7 +457,7 @@ class APIClientTests(APITestCase):
# Responses
def test_text_response(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['response', 'text'])
@ -444,7 +466,7 @@ class APIClientTests(APITestCase):
assert data == expected
def test_download_response(self):
client = get_api_client()
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['response', 'download'])

View File

@ -12,7 +12,7 @@ from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
from rest_framework.compat import is_authenticated, requests
from rest_framework.response import Response
from rest_framework.test import APITestCase, get_requests_client
from rest_framework.test import APITestCase, RequestsClient
from rest_framework.views import APIView
@ -92,10 +92,10 @@ class AuthView(APIView):
urlpatterns = [
url(r'^$', Root.as_view()),
url(r'^headers/$', HeadersView.as_view()),
url(r'^session/$', SessionView.as_view()),
url(r'^auth/$', AuthView.as_view()),
url(r'^$', Root.as_view(), name='root'),
url(r'^headers/$', HeadersView.as_view(), name='headers'),
url(r'^session/$', SessionView.as_view(), name='session'),
url(r'^auth/$', AuthView.as_view(), name='auth'),
]
@ -103,8 +103,8 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase):
def test_get_request(self):
client = get_requests_client()
response = client.get('/')
client = RequestsClient()
response = client.get('http://testserver/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
@ -114,8 +114,8 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected
def test_get_request_query_params_in_url(self):
client = get_requests_client()
response = client.get('/?key=value')
client = RequestsClient()
response = client.get('http://testserver/?key=value')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
@ -125,8 +125,8 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected
def test_get_request_query_params_by_kwarg(self):
client = get_requests_client()
response = client.get('/', params={'key': 'value'})
client = RequestsClient()
response = client.get('http://testserver/', params={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
@ -136,16 +136,25 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected
def test_get_with_headers(self):
client = get_requests_client()
response = client.get('/headers/', headers={'User-Agent': 'example'})
client = RequestsClient()
response = client.get('http://testserver/headers/', headers={'User-Agent': 'example'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example'
def test_get_with_session_headers(self):
client = RequestsClient()
client.headers.update({'User-Agent': 'example'})
response = client.get('http://testserver/headers/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example'
def test_post_form_request(self):
client = get_requests_client()
response = client.post('/', data={'key': 'value'})
client = RequestsClient()
response = client.post('http://testserver/', data={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
@ -158,8 +167,8 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected
def test_post_json_request(self):
client = get_requests_client()
response = client.post('/', json={'key': 'value'})
client = RequestsClient()
response = client.post('http://testserver/', json={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
@ -172,11 +181,11 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected
def test_post_multipart_request(self):
client = get_requests_client()
client = RequestsClient()
files = {
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
}
response = client.post('/', files=files)
response = client.post('http://testserver/', files=files)
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
@ -189,20 +198,20 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected
def test_session(self):
client = get_requests_client()
response = client.get('/session/')
client = RequestsClient()
response = client.get('http://testserver/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {}
assert response.json() == expected
response = client.post('/session/', json={'example': 'abc'})
response = client.post('http://testserver/session/', json={'example': 'abc'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
assert response.json() == expected
response = client.get('/session/')
response = client.get('http://testserver/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
@ -210,8 +219,8 @@ class RequestsClientTests(APITestCase):
def test_auth(self):
# Confirm session is not authenticated
client = get_requests_client()
response = client.get('/auth/')
client = RequestsClient()
response = client.get('http://testserver/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
@ -226,7 +235,7 @@ class RequestsClientTests(APITestCase):
user.save()
# Perform a login
response = client.post('/auth/', json={
response = client.post('http://testserver/auth/', json={
'username': 'tom',
'password': 'password'
}, headers={'X-CSRFToken': csrftoken})
@ -238,7 +247,7 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected
# Confirm session is authenticated
response = client.get('/auth/')
response = client.get('http://testserver/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {