mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 13:54:47 +03:00
Graceful fallback if requests is not installed.
This commit is contained in:
parent
e76ca6eb88
commit
6ede654315
|
@ -178,6 +178,13 @@ except (ImportError, SyntaxError):
|
|||
uritemplate = None
|
||||
|
||||
|
||||
# requests is optional
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
|
||||
# Fixes (#1712). We keep the try/except for the test suite.
|
||||
guardian = None
|
||||
|
|
|
@ -15,12 +15,8 @@ from django.test.client import ClientHandler
|
|||
from django.utils import six
|
||||
from django.utils.encoding import force_bytes
|
||||
from django.utils.http import urlencode
|
||||
from requests import Session
|
||||
from requests.adapters import BaseAdapter
|
||||
from requests.models import Response
|
||||
from requests.structures import CaseInsensitiveDict
|
||||
from requests.utils import get_encoding_from_headers
|
||||
|
||||
from rest_framework.compat import requests
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
|
@ -29,81 +25,81 @@ def force_authenticate(request, user=None, token=None):
|
|||
request._force_auth_token = token
|
||||
|
||||
|
||||
class DjangoTestAdapter(BaseAdapter):
|
||||
"""
|
||||
A transport adapter for `requests`, that makes requests via the
|
||||
Django WSGI app, rather than making actual HTTP requests over the network.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.app = WSGIHandler()
|
||||
self.factory = DjangoRequestFactory()
|
||||
|
||||
def get_environ(self, request):
|
||||
if requests is not None:
|
||||
class DjangoTestAdapter(requests.adapters.BaseAdapter):
|
||||
"""
|
||||
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
|
||||
A transport adapter for `requests`, that makes requests via the
|
||||
Django WSGI app, rather than making actual HTTP requests over the network.
|
||||
"""
|
||||
method = request.method
|
||||
url = request.url
|
||||
kwargs = {}
|
||||
def __init__(self):
|
||||
self.app = WSGIHandler()
|
||||
self.factory = DjangoRequestFactory()
|
||||
|
||||
# Set request content, if any exists.
|
||||
if request.body is not None:
|
||||
kwargs['data'] = request.body
|
||||
if 'content-type' in request.headers:
|
||||
kwargs['content_type'] = request.headers['content-type']
|
||||
def get_environ(self, request):
|
||||
"""
|
||||
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
|
||||
"""
|
||||
method = request.method
|
||||
url = request.url
|
||||
kwargs = {}
|
||||
|
||||
# Set request headers.
|
||||
for key, value in request.headers.items():
|
||||
key = key.upper()
|
||||
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
|
||||
continue
|
||||
kwargs['HTTP_%s' % key.replace('-', '_')] = value
|
||||
# Set request content, if any exists.
|
||||
if request.body is not None:
|
||||
kwargs['data'] = request.body
|
||||
if 'content-type' in request.headers:
|
||||
kwargs['content_type'] = request.headers['content-type']
|
||||
|
||||
return self.factory.generic(method, url, **kwargs).environ
|
||||
# Set request headers.
|
||||
for key, value in request.headers.items():
|
||||
key = key.upper()
|
||||
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
|
||||
continue
|
||||
kwargs['HTTP_%s' % key.replace('-', '_')] = value
|
||||
|
||||
def send(self, request, *args, **kwargs):
|
||||
"""
|
||||
Make an outgoing request to the Django WSGI application.
|
||||
"""
|
||||
response = Response()
|
||||
return self.factory.generic(method, url, **kwargs).environ
|
||||
|
||||
def start_response(status, headers):
|
||||
status_code, _, reason_phrase = status.partition(' ')
|
||||
response.status_code = int(status_code)
|
||||
response.reason = reason_phrase
|
||||
response.headers = CaseInsensitiveDict(headers)
|
||||
response.encoding = get_encoding_from_headers(response.headers)
|
||||
def send(self, request, *args, **kwargs):
|
||||
"""
|
||||
Make an outgoing request to the Django WSGI application.
|
||||
"""
|
||||
response = requests.models.Response()
|
||||
|
||||
environ = self.get_environ(request)
|
||||
raw_bytes = self.app(environ, start_response)
|
||||
def start_response(status, headers):
|
||||
status_code, _, reason_phrase = status.partition(' ')
|
||||
response.status_code = int(status_code)
|
||||
response.reason = reason_phrase
|
||||
response.headers = requests.structures.CaseInsensitiveDict(headers)
|
||||
response.encoding = requests.utils.get_encoding_from_headers(response.headers)
|
||||
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.raw = io.BytesIO(b''.join(raw_bytes))
|
||||
environ = self.get_environ(request)
|
||||
raw_bytes = self.app(environ, start_response)
|
||||
|
||||
return response
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.raw = io.BytesIO(b''.join(raw_bytes))
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
return response
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
class DjangoTestSession(Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DjangoTestSession, self).__init__(*args, **kwargs)
|
||||
class DjangoTestSession(requests.Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DjangoTestSession, self).__init__(*args, **kwargs)
|
||||
|
||||
adapter = DjangoTestAdapter()
|
||||
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
|
||||
adapter = DjangoTestAdapter()
|
||||
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
|
||||
|
||||
for hostname in hostnames:
|
||||
if hostname == '*':
|
||||
hostname = ''
|
||||
self.mount('http://%s' % hostname, adapter)
|
||||
self.mount('https://%s' % hostname, adapter)
|
||||
for hostname in hostnames:
|
||||
if hostname == '*':
|
||||
hostname = ''
|
||||
self.mount('http://%s' % hostname, adapter)
|
||||
self.mount('https://%s' % hostname, adapter)
|
||||
|
||||
def request(self, method, url, *args, **kwargs):
|
||||
if ':' not in url:
|
||||
url = 'http://testserver/' + url.lstrip('/')
|
||||
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
|
||||
def request(self, method, url, *args, **kwargs):
|
||||
if ':' not in url:
|
||||
url = 'http://testserver/' + url.lstrip('/')
|
||||
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
|
||||
|
||||
|
||||
class APIRequestFactory(DjangoRequestFactory):
|
||||
|
@ -306,9 +302,12 @@ class APITransactionTestCase(testcases.TransactionTestCase):
|
|||
class APITestCase(testcases.TestCase):
|
||||
client_class = APIClient
|
||||
|
||||
def _pre_setup(self):
|
||||
super(APITestCase, self)._pre_setup()
|
||||
self.requests = DjangoTestSession()
|
||||
@property
|
||||
def requests(self):
|
||||
if not hasattr(self, '_requests'):
|
||||
assert requests is not None, 'requests must be installed'
|
||||
self._requests = DjangoTestSession()
|
||||
return self._requests
|
||||
|
||||
|
||||
class APISimpleTestCase(testcases.SimpleTestCase):
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.test import override_settings
|
||||
|
||||
from rest_framework.compat import requests
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APITestCase
|
||||
from rest_framework.views import APIView
|
||||
|
@ -37,7 +40,7 @@ class Root(APIView):
|
|||
class Headers(APIView):
|
||||
def get(self, request):
|
||||
headers = {
|
||||
key[5:]: value
|
||||
key[5:].replace('_', '-'): value
|
||||
for key, value in request.META.items()
|
||||
if key.startswith('HTTP_')
|
||||
}
|
||||
|
@ -53,6 +56,7 @@ urlpatterns = [
|
|||
]
|
||||
|
||||
|
||||
@unittest.skipUnless(requests, 'requests not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_requests_client')
|
||||
class RequestsClientTests(APITestCase):
|
||||
def test_get_request(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user