From 08c7853655252cb9de0ade24eddfc9762a01cee7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Aug 2016 14:15:35 +0100 Subject: [PATCH 01/10] Start test case --- rest_framework/test.py | 5 +++++ tests/test_requests_client.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tests/test_requests_client.py diff --git a/rest_framework/test.py b/rest_framework/test.py index 3ba4059a9..fb08c4a73 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -12,6 +12,7 @@ from django.test.client import ClientHandler from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode +from requests import Session from rest_framework.settings import api_settings @@ -221,6 +222,10 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient + def _pre_setup(self): + super(APITestCase, self)._pre_setup() + self.requests = Session() + class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py new file mode 100644 index 000000000..a36349a3f --- /dev/null +++ b/tests/test_requests_client.py @@ -0,0 +1,24 @@ +from __future__ import unicode_literals + +from django.conf.urls import url +from django.test import override_settings + +from rest_framework.response import Response +from rest_framework.test import APITestCase +from rest_framework.views import APIView + + +class Root(APIView): + def get(self, request): + return Response({'hello': 'world'}) + + +urlpatterns = [ + url(r'^$', Root.as_view()), +] + + +@override_settings(ROOT_URLCONF='tests.test_requests_client') +class RequestsClientTests(APITestCase): + def test_get_root(self): + print self.requests.get('http://example.com') From 3d1fff3f26835612be17b6624d766a56520880ec Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 15:42:35 +0100 Subject: [PATCH 02/10] Added 'requests' test client --- rest_framework/request.py | 2 +- rest_framework/test.py | 86 +++++++++++++++++++++++- tests/test_requests_client.py | 119 +++++++++++++++++++++++++++++++++- 3 files changed, 202 insertions(+), 5 deletions(-) diff --git a/rest_framework/request.py b/rest_framework/request.py index aafafcb32..f5738bfd5 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -373,7 +373,7 @@ class Request(object): if not _hasattr(self, '_data'): self._load_data_and_files() if is_form_media_type(self.content_type): - return self.data + return self._data return QueryDict('', encoding=self._request._encoding) @property diff --git a/rest_framework/test.py b/rest_framework/test.py index fb08c4a73..1fd530a0c 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -4,7 +4,10 @@ # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals +import io + from django.conf import settings +from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory @@ -13,6 +16,10 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode from requests import Session +from requests.adapters import BaseAdapter +from requests.models import Response +from requests.structures import CaseInsensitiveDict +from requests.utils import get_encoding_from_headers from rest_framework.settings import api_settings @@ -22,6 +29,83 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token +class DjangoTestAdapter(BaseAdapter): + """ + A transport adaptor for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests ovet the network. + """ + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = Response() + + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = CaseInsensitiveDict(headers) + response.encoding = get_encoding_from_headers(response.headers) + + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) + + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) + + return response + + def close(self): + pass + + +class DjangoTestSession(Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -224,7 +308,7 @@ class APITestCase(testcases.TestCase): def _pre_setup(self): super(APITestCase, self)._pre_setup() - self.requests = Session() + self.requests = DjangoTestSession() class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index a36349a3f..0687dd92e 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -10,15 +10,128 @@ from rest_framework.views import APIView class Root(APIView): def get(self, request): - return Response({'hello': 'world'}) + return Response({ + 'method': request.method, + 'query_params': request.query_params, + }) + + def post(self, request): + files = { + key: (value.name, value.read()) + for key, value in request.FILES.items() + } + post = request.POST + json = None + if request.META.get('CONTENT_TYPE') == 'application/json': + json = request.data + + return Response({ + 'method': request.method, + 'query_params': request.query_params, + 'POST': post, + 'FILES': files, + 'JSON': json + }) + + +class Headers(APIView): + def get(self, request): + headers = { + key[5:]: value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) urlpatterns = [ url(r'^$', Root.as_view()), + url(r'^headers/$', Headers.as_view()), ] @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): - def test_get_root(self): - print self.requests.get('http://example.com') + def test_get_request(self): + response = self.requests.get('/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {} + } + assert response.json() == expected + + def test_get_request_query_params_in_url(self): + response = self.requests.get('/?key=value') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_request_query_params_by_kwarg(self): + response = self.requests.get('/', params={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_with_headers(self): + response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_post_form_request(self): + response = self.requests.post('/', data={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {'key': 'value'}, + 'FILES': {}, + 'JSON': None + } + assert response.json() == expected + + def test_post_json_request(self): + response = self.requests.post('/', json={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {}, + 'FILES': {}, + 'JSON': {'key': 'value'} + } + assert response.json() == expected + + def test_post_multipart_request(self): + files = { + 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') + } + response = self.requests.post('/', files=files) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']}, + 'POST': {}, + 'JSON': None + } + assert response.json() == expected + + # cookies/session auth From e76ca6eb8838148ccb1c25a2a8a735a42f644d99 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 16:06:04 +0100 Subject: [PATCH 03/10] Address typos --- rest_framework/test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index 1fd530a0c..e1d8eff82 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -31,8 +31,8 @@ def force_authenticate(request, user=None, token=None): class DjangoTestAdapter(BaseAdapter): """ - A transport adaptor for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests ovet the network. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ def __init__(self): self.app = WSGIHandler() @@ -55,9 +55,9 @@ class DjangoTestAdapter(BaseAdapter): # Set request headers. for key, value in request.headers.items(): key = key.upper() - if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): continue - kwargs['HTTP_%s' % key] = value + kwargs['HTTP_%s' % key.replace('-', '_')] = value return self.factory.generic(method, url, **kwargs).environ From 6ede654315e415362f4c7c8e38a3f641039bfc46 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 12:11:01 +0100 Subject: [PATCH 04/10] Graceful fallback if requests is not installed. --- rest_framework/compat.py | 7 ++ rest_framework/test.py | 133 +++++++++++++++++----------------- tests/test_requests_client.py | 6 +- 3 files changed, 78 insertions(+), 68 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84..bda346fa8 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -178,6 +178,13 @@ except (ImportError, SyntaxError): uritemplate = None +# requests is optional +try: + import requests +except ImportError: + requests = None + + # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Fixes (#1712). We keep the try/except for the test suite. guardian = None diff --git a/rest_framework/test.py b/rest_framework/test.py index e1d8eff82..eba4b96cf 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -15,12 +15,8 @@ from django.test.client import ClientHandler from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode -from requests import Session -from requests.adapters import BaseAdapter -from requests.models import Response -from requests.structures import CaseInsensitiveDict -from requests.utils import get_encoding_from_headers +from rest_framework.compat import requests from rest_framework.settings import api_settings @@ -29,81 +25,81 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token -class DjangoTestAdapter(BaseAdapter): - """ - A transport adapter for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests over the network. - """ - def __init__(self): - self.app = WSGIHandler() - self.factory = DjangoRequestFactory() - - def get_environ(self, request): +if requests is not None: + class DjangoTestAdapter(requests.adapters.BaseAdapter): """ - Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ - method = request.method - url = request.url - kwargs = {} + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() - # Set request content, if any exists. - if request.body is not None: - kwargs['data'] = request.body - if 'content-type' in request.headers: - kwargs['content_type'] = request.headers['content-type'] + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} - # Set request headers. - for key, value in request.headers.items(): - key = key.upper() - if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): - continue - kwargs['HTTP_%s' % key.replace('-', '_')] = value + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] - return self.factory.generic(method, url, **kwargs).environ + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key.replace('-', '_')] = value - def send(self, request, *args, **kwargs): - """ - Make an outgoing request to the Django WSGI application. - """ - response = Response() + return self.factory.generic(method, url, **kwargs).environ - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = CaseInsensitiveDict(headers) - response.encoding = get_encoding_from_headers(response.headers) + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = requests.models.Response() - environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = requests.structures.CaseInsensitiveDict(headers) + response.encoding = requests.utils.get_encoding_from_headers(response.headers) - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) - return response + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) - def close(self): - pass + return response + def close(self): + pass -class DjangoTestSession(Session): - def __init__(self, *args, **kwargs): - super(DjangoTestSession, self).__init__(*args, **kwargs) + class DjangoTestSession(requests.Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) - adapter = DjangoTestAdapter() - hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] - for hostname in hostnames: - if hostname == '*': - hostname = '' - self.mount('http://%s' % hostname, adapter) - self.mount('https://%s' % hostname, adapter) + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) - def request(self, method, url, *args, **kwargs): - if ':' not in url: - url = 'http://testserver/' + url.lstrip('/') - return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) class APIRequestFactory(DjangoRequestFactory): @@ -306,9 +302,12 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - def _pre_setup(self): - super(APITestCase, self)._pre_setup() - self.requests = DjangoTestSession() + @property + def requests(self): + if not hasattr(self, '_requests'): + assert requests is not None, 'requests must be installed' + self._requests = DjangoTestSession() + return self._requests class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 0687dd92e..24e29d3b8 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -1,8 +1,11 @@ from __future__ import unicode_literals +import unittest + from django.conf.urls import url from django.test import override_settings +from rest_framework.compat import requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -37,7 +40,7 @@ class Root(APIView): class Headers(APIView): def get(self, request): headers = { - key[5:]: value + key[5:].replace('_', '-'): value for key, value in request.META.items() if key.startswith('HTTP_') } @@ -53,6 +56,7 @@ urlpatterns = [ ] +@unittest.skipUnless(requests, 'requests not installed') @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): From 049a39e060ab8bbd028ffaa0fb2f5104a71f400d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 15:43:12 +0100 Subject: [PATCH 05/10] Add cookie support --- rest_framework/test.py | 41 +++++++++++++------ tests/test_requests_client.py | 76 ++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index eba4b96cf..bc8ecc5db 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,7 +26,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: - class DjangoTestAdapter(requests.adapters.BaseAdapter): + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the Django WSGI app, rather than making actual HTTP requests over the network. @@ -62,23 +62,38 @@ if requests is not None: """ Make an outgoing request to the Django WSGI application. """ - response = requests.models.Response() + raw_kwargs = {} - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = requests.structures.CaseInsensitiveDict(headers) - response.encoding = requests.utils.get_encoding_from_headers(response.headers) + def start_response(wsgi_status, wsgi_headers): + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) + self.closed = False + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + + status, _, reason = wsgi_status.partition(' ') + raw_kwargs['status'] = int(status) + raw_kwargs['reason'] = reason + raw_kwargs['headers'] = wsgi_headers + raw_kwargs['version'] = 11 + raw_kwargs['preload_content'] = False + raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) + + # Make the outgoing request via WSGI. environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) + wsgi_response = self.app(environ, start_response) - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) + # Build the underlying urllib3.HTTPResponse + raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) - return response + # Build the requests.Response + return self.build_response(request, raw) def close(self): pass diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 24e29d3b8..10158efa7 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -37,7 +37,7 @@ class Root(APIView): }) -class Headers(APIView): +class HeadersView(APIView): def get(self, request): headers = { key[5:].replace('_', '-'): value @@ -50,9 +50,32 @@ class Headers(APIView): }) +class SessionView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.session.items() + }) + + def post(self, request): + for key, value in request.data.items(): + request.session[key] = value + return Response({ + key: value for key, value in request.session.items() + }) + + +class CookiesView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.COOKIES.items() + }) + + urlpatterns = [ url(r'^$', Root.as_view()), - url(r'^headers/$', Headers.as_view()), + url(r'^headers/$', HeadersView.as_view()), + url(r'^session/$', SessionView.as_view()), + url(r'^cookies/$', CookiesView.as_view()), ] @@ -138,4 +161,53 @@ class RequestsClientTests(APITestCase): } assert response.json() == expected + def test_session(self): + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {} + assert response.json() == expected + + response = self.requests.post('/session/', json={'example': 'abc'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + def test_cookies(self): + """ + Test for explicitly setting a cookie. + """ + my_cookie = { + "version": 0, + "name": 'COOKIE_NAME', + "value": 'COOKIE_VALUE', + "port": None, + # "port_specified":False, + "domain": 'testserver.local', + # "domain_specified":False, + # "domain_initial_dot":False, + "path": '/', + # "path_specified":True, + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": {}, + "rfc2109": False + } + self.requests.cookies.set(**my_cookie) + response = self.requests.get('/cookies/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + assert response.json() == expected + # cookies/session auth From 64e19c738fce463df6fafd8f377ecc702f068813 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 17:54:03 +0100 Subject: [PATCH 06/10] Tests for auth and CSRF --- tests/test_requests_client.py | 87 ++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 10158efa7..aa99a71da 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -3,9 +3,14 @@ from __future__ import unicode_literals import unittest from django.conf.urls import url +from django.contrib.auth import authenticate, login +from django.contrib.auth.models import User +from django.shortcuts import redirect from django.test import override_settings +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie -from rest_framework.compat import requests +from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -64,18 +69,33 @@ class SessionView(APIView): }) -class CookiesView(APIView): +class AuthView(APIView): + @method_decorator(ensure_csrf_cookie) def get(self, request): + if is_authenticated(request.user): + username = request.user.username + else: + username = None return Response({ - key: value for key, value in request.COOKIES.items() + 'username': username }) + @method_decorator(csrf_protect) + def post(self, request): + username = request.data['username'] + password = request.data['password'] + user = authenticate(username=username, password=password) + if user is None: + return Response({'error': 'incorrect credentials'}) + login(request, user) + return redirect('/auth/') + urlpatterns = [ url(r'^$', Root.as_view()), url(r'^headers/$', HeadersView.as_view()), url(r'^session/$', SessionView.as_view()), - url(r'^cookies/$', CookiesView.as_view()), + url(r'^auth/$', AuthView.as_view()), ] @@ -180,34 +200,39 @@ class RequestsClientTests(APITestCase): expected = {'example': 'abc'} assert response.json() == expected - def test_cookies(self): - """ - Test for explicitly setting a cookie. - """ - my_cookie = { - "version": 0, - "name": 'COOKIE_NAME', - "value": 'COOKIE_VALUE', - "port": None, - # "port_specified":False, - "domain": 'testserver.local', - # "domain_specified":False, - # "domain_initial_dot":False, - "path": '/', - # "path_specified":True, - "secure": False, - "expires": None, - "discard": True, - "comment": None, - "comment_url": None, - "rest": {}, - "rfc2109": False - } - self.requests.cookies.set(**my_cookie) - response = self.requests.get('/cookies/') + def test_auth(self): + # Confirm session is not authenticated + response = self.requests.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' - expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + expected = { + 'username': None + } + assert response.json() == expected + assert 'csrftoken' in response.cookies + csrftoken = response.cookies['csrftoken'] + + user = User.objects.create(username='tom') + user.set_password('password') + user.save() + + # Perform a login + response = self.requests.post('/auth/', json={ + 'username': 'tom', + 'password': 'password' + }, headers={'X-CSRFToken': csrftoken}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } assert response.json() == expected - # cookies/session auth + # Confirm session is authenticated + response = self.requests.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } + assert response.json() == expected From da47c345c09eaf4fffca9cfcd18bba7acd844e8b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:09:19 +0100 Subject: [PATCH 07/10] Py3 compat --- rest_framework/test.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bc8ecc5db..a95d18537 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,6 +26,21 @@ def force_authenticate(request, user=None, token=None): if requests is not None: + class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key): + return self.getheaders(self, key) + + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = HeaderDict(headers) + self.closed = False + + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the @@ -65,17 +80,6 @@ if requests is not None: raw_kwargs = {} def start_response(wsgi_status, wsgi_headers): - class MockOriginalResponse(object): - def __init__(self, headers): - self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) - self.closed = False - - def isclosed(self): - return self.closed - - def close(self): - self.closed = True - status, _, reason = wsgi_status.partition(' ') raw_kwargs['status'] = int(status) raw_kwargs['reason'] = reason From 53117698e042691a7045688d41edae9cc118643b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:47:01 +0100 Subject: [PATCH 08/10] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index a95d18537..bf22ff08d 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -27,7 +27,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): - def get_all(self, key): + def get_all(self, key, default): return self.getheaders(self, key) class MockOriginalResponse(object): From 0b3db028a2a7a0a91c3111fc6febbdbcd9cbd6b5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:50:02 +0100 Subject: [PATCH 09/10] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bf22ff08d..ded9d5fe9 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -28,7 +28,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): def get_all(self, key, default): - return self.getheaders(self, key) + return self.getheaders(key) class MockOriginalResponse(object): def __init__(self, headers): From 0cc3f5008fcd41d1597776c43dc57502ed2a7542 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Aug 2016 15:34:19 +0100 Subject: [PATCH 10/10] Add get_requests_client --- rest_framework/test.py | 12 +++++------- tests/test_requests_client.py | 37 ++++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index ded9d5fe9..e17c19a43 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -121,6 +121,11 @@ if requests is not None: return super(DjangoTestSession, self).request(method, url, *args, **kwargs) +def get_requests_client(): + assert requests is not None, 'requests must be installed' + return DjangoTestSession() + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -321,13 +326,6 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - @property - def requests(self): - if not hasattr(self, '_requests'): - assert requests is not None, 'requests must be installed' - self._requests = DjangoTestSession() - return self._requests - class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index aa99a71da..37bde1092 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -12,7 +12,7 @@ from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response -from rest_framework.test import APITestCase +from rest_framework.test import APITestCase, get_requests_client from rest_framework.views import APIView @@ -103,7 +103,8 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): - response = self.requests.get('/') + client = get_requests_client() + response = client.get('/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -113,7 +114,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_in_url(self): - response = self.requests.get('/?key=value') + client = get_requests_client() + response = client.get('/?key=value') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -123,7 +125,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_by_kwarg(self): - response = self.requests.get('/', params={'key': 'value'}) + client = get_requests_client() + response = client.get('/', params={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -133,14 +136,16 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_with_headers(self): - response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + client = get_requests_client() + response = client.get('/headers/', headers={'User-Agent': 'example'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' headers = response.json()['headers'] assert headers['USER-AGENT'] == 'example' def test_post_form_request(self): - response = self.requests.post('/', data={'key': 'value'}) + client = get_requests_client() + response = client.post('/', data={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -153,7 +158,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_json_request(self): - response = self.requests.post('/', json={'key': 'value'}) + client = get_requests_client() + response = client.post('/', json={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -166,10 +172,11 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_multipart_request(self): + client = get_requests_client() files = { 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') } - response = self.requests.post('/', files=files) + response = client.post('/', files=files) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -182,19 +189,20 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_session(self): - response = self.requests.get('/session/') + client = get_requests_client() + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {} assert response.json() == expected - response = self.requests.post('/session/', json={'example': 'abc'}) + response = client.post('/session/', json={'example': 'abc'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} assert response.json() == expected - response = self.requests.get('/session/') + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} @@ -202,7 +210,8 @@ class RequestsClientTests(APITestCase): def test_auth(self): # Confirm session is not authenticated - response = self.requests.get('/auth/') + client = get_requests_client() + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -217,7 +226,7 @@ class RequestsClientTests(APITestCase): user.save() # Perform a login - response = self.requests.post('/auth/', json={ + response = client.post('/auth/', json={ 'username': 'tom', 'password': 'password' }, headers={'X-CSRFToken': csrftoken}) @@ -229,7 +238,7 @@ class RequestsClientTests(APITestCase): assert response.json() == expected # Confirm session is authenticated - response = self.requests.get('/auth/') + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {