Graceful fallback if requests is not installed.

This commit is contained in:
Tom Christie 2016-08-17 12:11:01 +01:00
parent e76ca6eb88
commit 6ede654315
3 changed files with 78 additions and 68 deletions

View File

@ -178,6 +178,13 @@ except (ImportError, SyntaxError):
uritemplate = None uritemplate = None
# requests is optional
try:
import requests
except ImportError:
requests = None
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
# Fixes (#1712). We keep the try/except for the test suite. # Fixes (#1712). We keep the try/except for the test suite.
guardian = None guardian = None

View File

@ -15,12 +15,8 @@ from django.test.client import ClientHandler
from django.utils import six from django.utils import six
from django.utils.encoding import force_bytes from django.utils.encoding import force_bytes
from django.utils.http import urlencode from django.utils.http import urlencode
from requests import Session
from requests.adapters import BaseAdapter
from requests.models import Response
from requests.structures import CaseInsensitiveDict
from requests.utils import get_encoding_from_headers
from rest_framework.compat import requests
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -29,81 +25,81 @@ def force_authenticate(request, user=None, token=None):
request._force_auth_token = token request._force_auth_token = token
class DjangoTestAdapter(BaseAdapter): if requests is not None:
""" class DjangoTestAdapter(requests.adapters.BaseAdapter):
A transport adapter for `requests`, that makes requests via the
Django WSGI app, rather than making actual HTTP requests over the network.
"""
def __init__(self):
self.app = WSGIHandler()
self.factory = DjangoRequestFactory()
def get_environ(self, request):
""" """
Given a `requests.PreparedRequest` instance, return a WSGI environ dict. A transport adapter for `requests`, that makes requests via the
Django WSGI app, rather than making actual HTTP requests over the network.
""" """
method = request.method def __init__(self):
url = request.url self.app = WSGIHandler()
kwargs = {} self.factory = DjangoRequestFactory()
# Set request content, if any exists. def get_environ(self, request):
if request.body is not None: """
kwargs['data'] = request.body Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
if 'content-type' in request.headers: """
kwargs['content_type'] = request.headers['content-type'] method = request.method
url = request.url
kwargs = {}
# Set request headers. # Set request content, if any exists.
for key, value in request.headers.items(): if request.body is not None:
key = key.upper() kwargs['data'] = request.body
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): if 'content-type' in request.headers:
continue kwargs['content_type'] = request.headers['content-type']
kwargs['HTTP_%s' % key.replace('-', '_')] = value
return self.factory.generic(method, url, **kwargs).environ # Set request headers.
for key, value in request.headers.items():
key = key.upper()
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
continue
kwargs['HTTP_%s' % key.replace('-', '_')] = value
def send(self, request, *args, **kwargs): return self.factory.generic(method, url, **kwargs).environ
"""
Make an outgoing request to the Django WSGI application.
"""
response = Response()
def start_response(status, headers): def send(self, request, *args, **kwargs):
status_code, _, reason_phrase = status.partition(' ') """
response.status_code = int(status_code) Make an outgoing request to the Django WSGI application.
response.reason = reason_phrase """
response.headers = CaseInsensitiveDict(headers) response = requests.models.Response()
response.encoding = get_encoding_from_headers(response.headers)
environ = self.get_environ(request) def start_response(status, headers):
raw_bytes = self.app(environ, start_response) status_code, _, reason_phrase = status.partition(' ')
response.status_code = int(status_code)
response.reason = reason_phrase
response.headers = requests.structures.CaseInsensitiveDict(headers)
response.encoding = requests.utils.get_encoding_from_headers(response.headers)
response.request = request environ = self.get_environ(request)
response.url = request.url raw_bytes = self.app(environ, start_response)
response.raw = io.BytesIO(b''.join(raw_bytes))
return response response.request = request
response.url = request.url
response.raw = io.BytesIO(b''.join(raw_bytes))
def close(self): return response
pass
def close(self):
pass
class DjangoTestSession(Session): class DjangoTestSession(requests.Session):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DjangoTestSession, self).__init__(*args, **kwargs) super(DjangoTestSession, self).__init__(*args, **kwargs)
adapter = DjangoTestAdapter() adapter = DjangoTestAdapter()
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
for hostname in hostnames: for hostname in hostnames:
if hostname == '*': if hostname == '*':
hostname = '' hostname = ''
self.mount('http://%s' % hostname, adapter) self.mount('http://%s' % hostname, adapter)
self.mount('https://%s' % hostname, adapter) self.mount('https://%s' % hostname, adapter)
def request(self, method, url, *args, **kwargs): def request(self, method, url, *args, **kwargs):
if ':' not in url: if ':' not in url:
url = 'http://testserver/' + url.lstrip('/') url = 'http://testserver/' + url.lstrip('/')
return super(DjangoTestSession, self).request(method, url, *args, **kwargs) return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
class APIRequestFactory(DjangoRequestFactory): class APIRequestFactory(DjangoRequestFactory):
@ -306,9 +302,12 @@ class APITransactionTestCase(testcases.TransactionTestCase):
class APITestCase(testcases.TestCase): class APITestCase(testcases.TestCase):
client_class = APIClient client_class = APIClient
def _pre_setup(self): @property
super(APITestCase, self)._pre_setup() def requests(self):
self.requests = DjangoTestSession() if not hasattr(self, '_requests'):
assert requests is not None, 'requests must be installed'
self._requests = DjangoTestSession()
return self._requests
class APISimpleTestCase(testcases.SimpleTestCase): class APISimpleTestCase(testcases.SimpleTestCase):

View File

@ -1,8 +1,11 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import unittest
from django.conf.urls import url from django.conf.urls import url
from django.test import override_settings from django.test import override_settings
from rest_framework.compat import requests
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from rest_framework.views import APIView from rest_framework.views import APIView
@ -37,7 +40,7 @@ class Root(APIView):
class Headers(APIView): class Headers(APIView):
def get(self, request): def get(self, request):
headers = { headers = {
key[5:]: value key[5:].replace('_', '-'): value
for key, value in request.META.items() for key, value in request.META.items()
if key.startswith('HTTP_') if key.startswith('HTTP_')
} }
@ -53,6 +56,7 @@ urlpatterns = [
] ]
@unittest.skipUnless(requests, 'requests not installed')
@override_settings(ROOT_URLCONF='tests.test_requests_client') @override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase): class RequestsClientTests(APITestCase):
def test_get_request(self): def test_get_request(self):