diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index 18f9e19e9..fdd60b7f4 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -184,6 +184,99 @@ As usual CSRF validation will only apply to any session authenticated views. Th --- +# RequestsClient + +REST framework also includes a client for interacting with your application +using the popular Python library, `requests`. + +This exposes exactly the same interface as if you were using a requests session +directly. + + client = RequestsClient() + response = client.get('http://testserver/users/') + +Note that the requests client requires you to pass fully qualified URLs. + +## Headers & Authentication + +Custom headers and authentication credentials can be provided in the same way +as [when using a standard `requests.Session` instance](http://docs.python-requests.org/en/master/user/advanced/#session-objects). + + from requests.auth import HTTPBasicAuth + + client.auth = HTTPBasicAuth('user', 'pass') + client.headers.update({'x-test': 'true'}) + +## CSRF + +If you're using `SessionAuthentication` then you'll need to include a CSRF token +for any `POST`, `PUT`, `PATCH` or `DELETE` requests. + +You can do so by following the same flow that a JavaScript based client would use. +First make a `GET` request in order to obtain a CRSF token, then present that +token in the following request. + +For example... + + client = RequestsClient() + + # Obtain a CSRF token. + response = client.get('/homepage/') + assert response.status_code == 200 + csrftoken = response.cookies['csrftoken'] + + # Interact with the API. + response = client.post('/organisations/', json={ + 'name': 'MegaCorp', + 'status': 'active' + }, headers={'X-CSRFToken': csrftoken}) + assert response.status_code == 200 + +## Live tests + +With careful usage both the `RequestsClient` and the `CoreAPIClient` provide +the ability to write test cases that can run either in development, or be run +directly against your staging server or production environment. + +Using this style to create basic tests of a few core piece of functionality is +a powerful way to validate your live service. Doing so may require some careful +attention to setup and teardown to ensure that the tests run in a way that they +do not directly affect customer data. + +--- + +# CoreAPIClient + +The CoreAPIClient allows you to interact with your API using the Python +`coreapi` client library. + + # Fetch the API schema + url = reverse('schema') + client = CoreAPIClient() + schema = client.get(url) + + # Create a new organisation + params = {'name': 'MegaCorp', 'status': 'active'} + client.action(schema, ['organisations', 'create'], params) + + # Ensure that the organisation exists in the listing + data = client.action(schema, ['organisations', 'list']) + assert(len(data) == 1) + assert(data == [{'name': 'MegaCorp', 'status': 'active'}]) + +## Headers & Authentication + +Custom headers and authentication may be used with `CoreAPIClient` in a +similar way as with `RequestsClient`. + + from requests.auth import HTTPBasicAuth + + client = CoreAPIClient() + client.session.auth = HTTPBasicAuth('user', 'pass') + client.session.headers.update({'x-test': 'true'}) + +--- + # Test cases REST framework includes the following test case classes, that mirror the existing Django test case classes, but use `APIClient` instead of Django's default `Client`. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index a02a0f1bf..859b60460 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -23,6 +23,7 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, renderers, views from rest_framework.compat import NoReverseMatch +from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator @@ -281,7 +282,7 @@ class DefaultRouter(SimpleRouter): include_root_view = True include_format_suffixes = True root_view_name = 'api-root' - default_schema_renderers = [renderers.CoreJSONRenderer] + default_schema_renderers = [renderers.CoreJSONRenderer, BrowsableAPIRenderer] def __init__(self, *args, **kwargs): if 'schema_renderers' in kwargs: diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 1c2f1a546..0928e7661 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from importlib import import_module from django.conf import settings @@ -55,6 +56,18 @@ def is_custom_action(action): ]) +def endpoint_ordering(endpoint): + path, method, callback = endpoint + method_priority = { + 'GET': 0, + 'POST': 1, + 'PUT': 2, + 'PATCH': 3, + 'DELETE': 4 + }.get(method, 5) + return (path, method_priority) + + class EndpointInspector(object): """ A class to determine the available API endpoints that a project exposes. @@ -101,6 +114,8 @@ class EndpointInspector(object): ) api_endpoints.extend(nested_endpoints) + api_endpoints = sorted(api_endpoints, key=endpoint_ordering) + return api_endpoints def get_path_from_regex(self, path_regex): @@ -183,7 +198,7 @@ class SchemaGenerator(object): Return a dictionary containing all the links that should be included in the API schema. """ - links = {} + links = OrderedDict() for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) if not self.has_view_permissions(view): diff --git a/rest_framework/test.py b/rest_framework/test.py index 1b3ad80c2..16b1b4cd5 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -7,6 +7,7 @@ from __future__ import unicode_literals import io from django.conf import settings +from django.core.exceptions import ImproperlyConfigured from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient @@ -105,36 +106,46 @@ if requests is not None: def close(self): pass - class DjangoTestSession(requests.Session): + class NoExternalRequestsAdapter(requests.adapters.HTTPAdapter): + def send(self, request, *args, **kwargs): + msg = ( + 'RequestsClient refusing to make an outgoing network request ' + 'to "%s". Only "testserver" or hostnames in your ALLOWED_HOSTS ' + 'setting are valid.' % request.url + ) + raise RuntimeError(msg) + + class RequestsClient(requests.Session): def __init__(self, *args, **kwargs): - super(DjangoTestSession, self).__init__(*args, **kwargs) - + super(RequestsClient, self).__init__(*args, **kwargs) adapter = DjangoTestAdapter() - hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] - - for hostname in hostnames: - if hostname == '*': - hostname = '' - self.mount('http://%s' % hostname, adapter) - self.mount('https://%s' % hostname, adapter) + self.mount('http://', adapter) + self.mount('https://', adapter) def request(self, method, url, *args, **kwargs): if ':' not in url: - url = 'http://testserver/' + url.lstrip('/') - return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url) + return super(RequestsClient, self).request(method, url, *args, **kwargs) + +else: + def RequestsClient(*args, **kwargs): + raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.') -def get_requests_client(): - assert requests is not None, 'requests must be installed' - return DjangoTestSession() +if coreapi is not None: + class CoreAPIClient(coreapi.Client): + def __init__(self, *args, **kwargs): + self._session = RequestsClient() + kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)] + return super(CoreAPIClient, self).__init__(*args, **kwargs) + @property + def session(self): + return self._session -def get_api_client(): - assert coreapi is not None, 'coreapi must be installed' - session = get_requests_client() - return coreapi.Client(transports=[ - coreapi.transports.HTTPTransport(session=session) - ]) +else: + def CoreAPIClient(*args, **kwargs): + raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.') class APIRequestFactory(DjangoRequestFactory): diff --git a/tests/test_api_client.py b/tests/test_api_client.py index 9daf3f3fe..a6d72357a 100644 --- a/tests/test_api_client.py +++ b/tests/test_api_client.py @@ -12,7 +12,7 @@ from rest_framework.compat import coreapi from rest_framework.parsers import FileUploadParser from rest_framework.renderers import CoreJSONRenderer from rest_framework.response import Response -from rest_framework.test import APITestCase, get_api_client +from rest_framework.test import APITestCase, CoreAPIClient from rest_framework.views import APIView @@ -22,6 +22,7 @@ def get_schema(): title='Example API', content={ 'simple_link': coreapi.Link('/example/', description='example link'), + 'headers': coreapi.Link('/headers/'), 'location': { 'query': coreapi.Link('/example/', fields=[ coreapi.Field(name='example', description='example field') @@ -165,6 +166,19 @@ class TextView(APIView): return HttpResponse('123', content_type='text/plain') +class HeadersView(APIView): + def get(self, request): + headers = { + key[5:].replace('_', '-'): value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) + + urlpatterns = [ url(r'^$', SchemaView.as_view()), url(r'^example/$', ListView.as_view()), @@ -172,6 +186,7 @@ urlpatterns = [ url(r'^upload/$', UploadView.as_view()), url(r'^download/$', DownloadView.as_view()), url(r'^text/$', TextView.as_view()), + url(r'^headers/$', HeadersView.as_view()), ] @@ -179,7 +194,7 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_api_client') class APIClientTests(APITestCase): def test_api_client(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') assert schema.title == 'Example API' assert schema.url == 'https://api.example.com/' @@ -193,7 +208,7 @@ class APIClientTests(APITestCase): assert data == expected def test_query_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'query'], params={'example': 123}) expected = { @@ -202,8 +217,15 @@ class APIClientTests(APITestCase): } assert data == expected + def test_session_headers(self): + client = CoreAPIClient() + client.session.headers.update({'X-Custom-Header': 'foo'}) + schema = client.get('http://api.example.com/') + data = client.action(schema, ['headers']) + assert data['headers']['X-CUSTOM-HEADER'] == 'foo' + def test_query_params_with_multiple_values(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]}) expected = { @@ -213,7 +235,7 @@ class APIClientTests(APITestCase): assert data == expected def test_form_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'form'], params={'example': 123}) expected = { @@ -226,7 +248,7 @@ class APIClientTests(APITestCase): assert data == expected def test_body_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'body'], params={'example': 123}) expected = { @@ -239,7 +261,7 @@ class APIClientTests(APITestCase): assert data == expected def test_path_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'path'], params={'id': 123}) expected = { @@ -250,7 +272,7 @@ class APIClientTests(APITestCase): assert data == expected def test_multipart_encoding(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') temp = tempfile.NamedTemporaryFile() @@ -272,7 +294,7 @@ class APIClientTests(APITestCase): def test_multipart_encoding_no_file(self): # When no file is included, multipart encoding should still be used. - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'multipart'], params={'example': 123}) @@ -287,7 +309,7 @@ class APIClientTests(APITestCase): assert data == expected def test_multipart_encoding_multiple_values(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]}) @@ -305,7 +327,7 @@ class APIClientTests(APITestCase): # Test for `coreapi.utils.File` support. from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = File(name='example.txt', content='123') @@ -323,7 +345,7 @@ class APIClientTests(APITestCase): def test_multipart_encoding_in_body(self): from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'} @@ -341,7 +363,7 @@ class APIClientTests(APITestCase): # URLencoded def test_urlencoded_encoding(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123}) expected = { @@ -354,7 +376,7 @@ class APIClientTests(APITestCase): assert data == expected def test_urlencoded_encoding_multiple_values(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]}) expected = { @@ -367,7 +389,7 @@ class APIClientTests(APITestCase): assert data == expected def test_urlencoded_encoding_in_body(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}}) expected = { @@ -382,7 +404,7 @@ class APIClientTests(APITestCase): # Raw uploads def test_raw_upload(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') temp = tempfile.NamedTemporaryFile() @@ -403,7 +425,7 @@ class APIClientTests(APITestCase): def test_raw_upload_string_file_content(self): from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = File('example.txt', '123') @@ -419,7 +441,7 @@ class APIClientTests(APITestCase): def test_raw_upload_explicit_content_type(self): from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = File('example.txt', '123', 'text/html') @@ -435,7 +457,7 @@ class APIClientTests(APITestCase): # Responses def test_text_response(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['response', 'text']) @@ -444,7 +466,7 @@ class APIClientTests(APITestCase): assert data == expected def test_download_response(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['response', 'download']) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 37bde1092..791ca4ff2 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -12,7 +12,7 @@ from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response -from rest_framework.test import APITestCase, get_requests_client +from rest_framework.test import APITestCase, RequestsClient from rest_framework.views import APIView @@ -92,10 +92,10 @@ class AuthView(APIView): urlpatterns = [ - url(r'^$', Root.as_view()), - url(r'^headers/$', HeadersView.as_view()), - url(r'^session/$', SessionView.as_view()), - url(r'^auth/$', AuthView.as_view()), + url(r'^$', Root.as_view(), name='root'), + url(r'^headers/$', HeadersView.as_view(), name='headers'), + url(r'^session/$', SessionView.as_view(), name='session'), + url(r'^auth/$', AuthView.as_view(), name='auth'), ] @@ -103,8 +103,8 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): - client = get_requests_client() - response = client.get('/') + client = RequestsClient() + response = client.get('http://testserver/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -114,8 +114,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_in_url(self): - client = get_requests_client() - response = client.get('/?key=value') + client = RequestsClient() + response = client.get('http://testserver/?key=value') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -125,8 +125,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_by_kwarg(self): - client = get_requests_client() - response = client.get('/', params={'key': 'value'}) + client = RequestsClient() + response = client.get('http://testserver/', params={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -136,16 +136,25 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_with_headers(self): - client = get_requests_client() - response = client.get('/headers/', headers={'User-Agent': 'example'}) + client = RequestsClient() + response = client.get('http://testserver/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_get_with_session_headers(self): + client = RequestsClient() + client.headers.update({'User-Agent': 'example'}) + response = client.get('http://testserver/headers/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' headers = response.json()['headers'] assert headers['USER-AGENT'] == 'example' def test_post_form_request(self): - client = get_requests_client() - response = client.post('/', data={'key': 'value'}) + client = RequestsClient() + response = client.post('http://testserver/', data={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -158,8 +167,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_json_request(self): - client = get_requests_client() - response = client.post('/', json={'key': 'value'}) + client = RequestsClient() + response = client.post('http://testserver/', json={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -172,11 +181,11 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_multipart_request(self): - client = get_requests_client() + client = RequestsClient() files = { 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') } - response = client.post('/', files=files) + response = client.post('http://testserver/', files=files) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -189,20 +198,20 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_session(self): - client = get_requests_client() - response = client.get('/session/') + client = RequestsClient() + response = client.get('http://testserver/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {} assert response.json() == expected - response = client.post('/session/', json={'example': 'abc'}) + response = client.post('http://testserver/session/', json={'example': 'abc'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} assert response.json() == expected - response = client.get('/session/') + response = client.get('http://testserver/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} @@ -210,8 +219,8 @@ class RequestsClientTests(APITestCase): def test_auth(self): # Confirm session is not authenticated - client = get_requests_client() - response = client.get('/auth/') + client = RequestsClient() + response = client.get('http://testserver/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -226,7 +235,7 @@ class RequestsClientTests(APITestCase): user.save() # Perform a login - response = client.post('/auth/', json={ + response = client.post('http://testserver/auth/', json={ 'username': 'tom', 'password': 'password' }, headers={'X-CSRFToken': csrftoken}) @@ -238,7 +247,7 @@ class RequestsClientTests(APITestCase): assert response.json() == expected # Confirm session is authenticated - response = client.get('/auth/') + response = client.get('http://testserver/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {