From 6ede654315e415362f4c7c8e38a3f641039bfc46 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 12:11:01 +0100 Subject: [PATCH] Graceful fallback if requests is not installed. --- rest_framework/compat.py | 7 ++ rest_framework/test.py | 133 +++++++++++++++++----------------- tests/test_requests_client.py | 6 +- 3 files changed, 78 insertions(+), 68 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84..bda346fa8 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -178,6 +178,13 @@ except (ImportError, SyntaxError): uritemplate = None +# requests is optional +try: + import requests +except ImportError: + requests = None + + # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Fixes (#1712). We keep the try/except for the test suite. guardian = None diff --git a/rest_framework/test.py b/rest_framework/test.py index e1d8eff82..eba4b96cf 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -15,12 +15,8 @@ from django.test.client import ClientHandler from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode -from requests import Session -from requests.adapters import BaseAdapter -from requests.models import Response -from requests.structures import CaseInsensitiveDict -from requests.utils import get_encoding_from_headers +from rest_framework.compat import requests from rest_framework.settings import api_settings @@ -29,81 +25,81 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token -class DjangoTestAdapter(BaseAdapter): - """ - A transport adapter for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests over the network. - """ - def __init__(self): - self.app = WSGIHandler() - self.factory = DjangoRequestFactory() - - def get_environ(self, request): +if requests is not None: + class DjangoTestAdapter(requests.adapters.BaseAdapter): """ - Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ - method = request.method - url = request.url - kwargs = {} + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() - # Set request content, if any exists. - if request.body is not None: - kwargs['data'] = request.body - if 'content-type' in request.headers: - kwargs['content_type'] = request.headers['content-type'] + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} - # Set request headers. - for key, value in request.headers.items(): - key = key.upper() - if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): - continue - kwargs['HTTP_%s' % key.replace('-', '_')] = value + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] - return self.factory.generic(method, url, **kwargs).environ + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key.replace('-', '_')] = value - def send(self, request, *args, **kwargs): - """ - Make an outgoing request to the Django WSGI application. - """ - response = Response() + return self.factory.generic(method, url, **kwargs).environ - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = CaseInsensitiveDict(headers) - response.encoding = get_encoding_from_headers(response.headers) + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = requests.models.Response() - environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = requests.structures.CaseInsensitiveDict(headers) + response.encoding = requests.utils.get_encoding_from_headers(response.headers) - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) - return response + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) - def close(self): - pass + return response + def close(self): + pass -class DjangoTestSession(Session): - def __init__(self, *args, **kwargs): - super(DjangoTestSession, self).__init__(*args, **kwargs) + class DjangoTestSession(requests.Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) - adapter = DjangoTestAdapter() - hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] - for hostname in hostnames: - if hostname == '*': - hostname = '' - self.mount('http://%s' % hostname, adapter) - self.mount('https://%s' % hostname, adapter) + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) - def request(self, method, url, *args, **kwargs): - if ':' not in url: - url = 'http://testserver/' + url.lstrip('/') - return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) class APIRequestFactory(DjangoRequestFactory): @@ -306,9 +302,12 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - def _pre_setup(self): - super(APITestCase, self)._pre_setup() - self.requests = DjangoTestSession() + @property + def requests(self): + if not hasattr(self, '_requests'): + assert requests is not None, 'requests must be installed' + self._requests = DjangoTestSession() + return self._requests class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 0687dd92e..24e29d3b8 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -1,8 +1,11 @@ from __future__ import unicode_literals +import unittest + from django.conf.urls import url from django.test import override_settings +from rest_framework.compat import requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -37,7 +40,7 @@ class Root(APIView): class Headers(APIView): def get(self, request): headers = { - key[5:]: value + key[5:].replace('_', '-'): value for key, value in request.META.items() if key.startswith('HTTP_') } @@ -53,6 +56,7 @@ urlpatterns = [ ] +@unittest.skipUnless(requests, 'requests not installed') @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self):