daphne/tests/http_base.py
2017-11-26 00:06:23 -08:00

264 lines
10 KiB
Python

from urllib import parse
from http.client import HTTPConnection
import socket
import subprocess
import time
import unittest
from daphne.test_utils import TestApplication
class DaphneTestCase(unittest.TestCase):
"""
Base class for Daphne integration test cases.
Boots up a copy of Daphne on a test port and sends it a request, and
retrieves the response. Uses a custom ASGI application and temporary files
to store/retrieve the request/response messages.
"""
def port_in_use(self, port):
"""
Tests if a port is in use on the local machine.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.bind(("127.0.0.1", port))
except socket.error as e:
if e.errno in [13, 98]:
return True
else:
raise
else:
return False
finally:
s.close()
def run_daphne(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
"""
Runs Daphne with the given request callback (given the base URL)
and response messages.
"""
# Store setup info
TestApplication.clear_storage()
TestApplication.save_setup(
response_messages=responses,
)
# Find a free port
for i in range(11200, 11300):
if not self.port_in_use(i):
port = i
break
else:
raise RuntimeError("Cannot find a free port to test on")
# Launch daphne on that port
daphne_args = ["daphne", "-p", str(port), "-v", "0"]
if xff:
# Optionally enable X-Forwarded-For support.
daphne_args += ["--proxy-headers"]
process = subprocess.Popen(daphne_args + ["daphne.test_utils:TestApplication"])
try:
for _ in range(100):
time.sleep(0.1)
if self.port_in_use(port):
break
else:
raise RuntimeError("Daphne never came up.")
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection("127.0.0.1", port, timeout=timeout)
# Make sure path is urlquoted and add any params
path = parse.quote(path)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
# Manually send over headers (encoding any non-safe values as best we can)
if headers:
for header_name, header_value in headers:
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
# Send body if provided.
if body:
conn.putheader("Content-Length", str(len(body)))
conn.endheaders(message_body=body)
else:
conn.endheaders()
try:
response = conn.getresponse()
except socket.timeout:
# See if they left an exception for us to load
try:
exception_result = TestApplication.load_result()
except OSError:
raise RuntimeError("Daphne timed out handling request, no result file")
else:
if "exception" in exception_result:
raise exception_result["exception"]
else:
raise RuntimeError("Daphne timed out handling request, no exception found: %r" % exception_result)
finally:
# Shut down daphne
process.terminate()
# Load the information
inner_result = TestApplication.load_result()
# Return the inner result and the response
return inner_result, response
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
"""
Convenience method for just testing request handling.
Returns (scope, messages)
"""
inner_result, _ = self.run_daphne(
method=method,
path=path,
params=params,
body=body,
headers=headers,
xff=xff,
responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
)
return inner_result["scope"], inner_result["messages"]
def run_daphne_response(self, response_messages):
"""
Convenience method for just testing response handling.
Returns (scope, messages)
"""
_, response = self.run_daphne(
method="GET",
path="/",
params={},
body=b"",
responses=response_messages,
)
return response
def tearDown(self):
"""
Ensures any storage files are cleared.
"""
TestApplication.clear_storage()
def assert_is_ip_address(self, address):
"""
Tests whether a given address string is a valid IPv4 or IPv6 address.
"""
try:
socket.inet_aton(address)
except socket.error:
self.fail("'%s' is not a valid IP address." % address)
def assert_key_sets(self, required_keys, optional_keys, actual_keys):
"""
Asserts that all required_keys are in actual_keys, and that there
are no keys in actual_keys that aren't required or optional.
"""
present_keys = set(actual_keys)
# Make sure all required keys are present
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(
set(),
present_keys - required_keys - optional_keys,
)
def assert_valid_path(self, path, request_path):
"""
Checks the path is valid and already url-decoded.
"""
self.assertIsInstance(path, str)
self.assertEqual(path, request_path)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
"""
Asserts the value is a valid (host, port) tuple.
"""
address, port = host
self.assertIsInstance(address, str)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
# class ASGIWebSocketTestCase(ASGITestCaseBase):
# """
# Test case with helpers for verifying WebSocket channel messages
# """
# def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
# self.assertIn(body, response)
# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
# def assert_websocket_denied(self, response):
# self.assertIn(b"HTTP/1.1 403", response)
# def assert_valid_websocket_connect_message(
# self, channel_message, request_path="/", request_params=None, request_headers=None):
# """
# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
# """
# self.assertTrue(channel_message)
# self.assert_presence_of_message_keys(
# channel_message.keys(),
# {"reply_channel", "path", "headers", "order"},
# {"scheme", "query_string", "root_path", "client", "server"})
# # == Assertions about required channel_message fields ==
# self.assert_valid_reply_channel(channel_message["reply_channel"])
# self.assert_valid_path(channel_message["path"], request_path)
# order = channel_message["order"]
# self.assertIsInstance(order, int)
# self.assertEqual(order, 0)
# # Ordering of header names is not important, but the order of values for a header
# # name is. To assert whether that order is kept, we transform the request
# # headers and the channel message headers into a set
# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
# # get one string per header field with values separated by comma.
# transformed_request_headers = defaultdict(list)
# for name, value in (request_headers or []):
# expected_name = name.lower().strip().encode("ascii")
# expected_value = value.strip().encode("ascii")
# transformed_request_headers[expected_name].append(expected_value)
# final_request_headers = {
# (name, b",".join(value)) for name, value in transformed_request_headers.items()
# }
# # Websockets carry a lot of additional header fields, so instead of verifying that
# # headers look exactly like expected, we just check that the expected header fields
# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
# # and not tested for.
# assert final_request_headers.issubset(set(channel_message["headers"]))
# # == Assertions about optional channel_message fields ==
# scheme = channel_message.get("scheme")
# if scheme:
# self.assertIsInstance(scheme, six.text_type)
# self.assertIn(scheme, ["ws", "wss"])
# query_string = channel_message.get("query_string")
# if query_string:
# # Assert that query_string is a byte string and still url encoded
# self.assertIsInstance(query_string, six.binary_type)
# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
# root_path = channel_message.get("root_path")
# if root_path is not None:
# self.assertIsInstance(root_path, six.text_type)
# client = channel_message.get("client")
# if client is not None:
# self.assert_valid_address_and_port(channel_message["client"])
# server = channel_message.get("server")
# if server is not None:
# self.assert_valid_address_and_port(channel_message["server"])