mirror of
https://github.com/django/daphne.git
synced 2024-11-22 07:56:34 +03:00
249 lines
9.7 KiB
Python
249 lines
9.7 KiB
Python
import errno
import ipaddress
import socket
import subprocess
import time
import unittest
from urllib import parse

import requests

from daphne.test_utils import TestApplication


class DaphneTestCase(unittest.TestCase):
    """
    Base class for Daphne integration test cases.

    Boots up a copy of Daphne on a test port and sends it a request, and
    retrieves the response. Uses a custom ASGI application and temporary files
    to store/retrieve the request/response messages.
    """

    def port_in_use(self, port):
        """
        Tests if a port is in use on the local machine.

        Attempts to bind a TCP socket to ("127.0.0.1", port). Returns True if
        the bind fails with EACCES or EADDRINUSE, False if it succeeds. Any
        other socket error is re-raised. The probe socket is always closed.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(("127.0.0.1", port))
            except OSError as e:
                # Compare against the symbolic errno names rather than the
                # Linux-specific numbers 13/98: EADDRINUSE is 48 on macOS, for
                # example, where a hard-coded 98 made this method raise
                # instead of reporting the port as busy.
                if e.errno in (errno.EACCES, errno.EADDRINUSE):
                    return True
                raise
            return False

    def _find_free_port(self, start=11200, end=11300):
        """
        Returns the first port in range(start, end) that is not in use.

        Raises RuntimeError if every port in the range is taken.
        """
        for candidate in range(start, end):
            if not self.port_in_use(candidate):
                return candidate
        raise RuntimeError("Cannot find a free port to test on")

    def _wait_for_port(self, process, port, attempts=100, interval=0.1):
        """
        Polls until `port` is bound on 127.0.0.1.

        Sleeps `interval` seconds before each of up to `attempts` checks.
        Raises RuntimeError if the port never becomes busy, or as soon as
        `process` has exited (a process that is gone will never bind the
        port, so there is no point waiting out the full timeout).
        """
        for _ in range(attempts):
            time.sleep(interval)
            if self.port_in_use(port):
                return
            if process.poll() is not None:
                raise RuntimeError(
                    "Daphne exited with code %s before it came up." % process.returncode
                )
        raise RuntimeError("Daphne never came up.")

    def _stop_process(self, process, timeout=5):
        """
        Terminates `process` and waits for it to exit.

        If it has not exited within `timeout` seconds it is killed. Either way
        it is reaped via wait(), so no unreaped child is left behind.
        """
        process.terminate()
        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()

    def run_daphne(self, method, path, params, data, responses, headers=None, timeout=1):
        """
        Runs Daphne on a free local port, sends it one HTTP request, and
        returns what the test application recorded plus the HTTP response.

        `responses` is stored via TestApplication.save_setup() as the
        response messages the test application should send. `method`, `path`,
        `params`, `data` and `headers` describe the request, and are passed to
        requests.request() together with `timeout`.

        Returns (inner_result, response): the value of
        TestApplication.load_result() and the requests response object.
        Raises RuntimeError if no free port is found or Daphne does not come up.
        """
        # Store setup info
        TestApplication.clear_storage()
        TestApplication.save_setup(
            response_messages=responses,
        )
        port = self._find_free_port()
        # Launch daphne on that port
        process = subprocess.Popen(["daphne", "-p", str(port), "daphne.test_utils:TestApplication"])
        try:
            self._wait_for_port(process, port)
            # Send it the request
            url = "http://127.0.0.1:%i%s" % (port, path)
            response = requests.request(method, url, params=params, data=data, headers=headers, timeout=timeout)
        finally:
            # Shut down daphne and reap it, so the child has fully exited
            # before the next test goes looking for a port.
            self._stop_process(process)
        # Load the information
        inner_result = TestApplication.load_result()
        # Return the inner result and the response
        return inner_result, response

    def run_daphne_request(self, method, path, params=None, data=None, headers=None):
        """
        Convenience method for just testing request handling.

        Has the test application reply with a fixed 200 "OK" response and
        returns (scope, messages) as recorded by the application.
        """
        # Copy headers so any dict-like / pairs input becomes a plain dict.
        if headers is not None:
            headers = dict(headers)
        inner_result, _ = self.run_daphne(
            method=method,
            path=path,
            params=params,
            data=data,
            headers=headers,
            responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
        )
        return inner_result["scope"], inner_result["messages"]

    def tearDown(self):
        """
        Ensures any storage files are cleared.
        """
        TestApplication.clear_storage()
        super().tearDown()

    def assert_is_ip_address(self, address):
        """
        Tests whether a given address string is a valid IPv4 or IPv6 address.

        Uses ipaddress.ip_address(), which accepts both families and only
        canonical forms (shorthand such as "127.1" is rejected).
        """
        try:
            ipaddress.ip_address(address)
        except ValueError:
            self.fail("'%s' is not a valid IP address." % address)

    def assert_key_sets(self, required_keys, optional_keys, actual_keys):
        """
        Asserts that all required_keys are in actual_keys, and that there
        are no keys in actual_keys that aren't required or optional.

        All three arguments may be any iterable of keys (sets, lists, dicts,
        dict views); they are converted to sets before comparison.
        """
        required_keys = set(required_keys)
        optional_keys = set(optional_keys)
        present_keys = set(actual_keys)
        # Make sure all required keys are present
        self.assertEqual(set(), required_keys - present_keys, "Missing required keys")
        # Assert that no other keys are present
        self.assertEqual(
            set(),
            present_keys - required_keys - optional_keys,
            "Unexpected keys present",
        )

    def assert_valid_path(self, path, request_path):
        """
        Checks the path is a str equal to request_path and is already
        url-decoded (unquoting it again leaves it unchanged).
        """
        self.assertIsInstance(path, str)
        self.assertEqual(path, request_path)
        # Assert that it's already url decoded
        self.assertEqual(path, parse.unquote(path))

    def assert_valid_address_and_port(self, host):
        """
        Asserts the value is a valid (host, port) pair: a str IP address and
        an int port.
        """
        address, port = host
        self.assertIsInstance(address, str)
        self.assert_is_ip_address(address)
        self.assertIsInstance(port, int)
# class ASGIHTTPTestCase(ASGITestCaseBase):
#     """
#     Test case with helpers for verifying HTTP channel messages
#     """

#     def assert_valid_http_response_message(self, message, response):
#         self.assertTrue(message)
#         self.assertTrue(response.startswith(b"HTTP"))

#         status_code_bytes = str(message["status"]).encode("ascii")
#         self.assertIn(status_code_bytes, response)

#         if "content" in message:
#             self.assertIn(message["content"], response)

#         # Check that headers are in the given order.
#         # N.b. HTTP spec only enforces that the order of header values is kept, but
#         # the ASGI spec requires that order of all headers is kept. This code
#         # checks conformance with the stricter ASGI spec.
#         if "headers" in message:
#             for name, value in message["headers"]:
#                 expected_header = factories.header_line(name, value)
#                 # Daphne or Twisted turn our lower cased header names ('foo-bar') into title
#                 # case ('Foo-Bar'). So technically we want to match that the header name is
#                 # present while ignoring casing, and want to ensure the value is present without
#                 # altered casing. The approach below does this well enough.
#                 self.assertIn(expected_header.lower(), response.lower())
#                 self.assertIn(value.encode("ascii"), response)


# class ASGIWebSocketTestCase(ASGITestCaseBase):
#     """
#     Test case with helpers for verifying WebSocket channel messages
#     """

#     def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
#         self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
#         self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
#         self.assertIn(body, response)
#         self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))

#     def assert_websocket_denied(self, response):
#         self.assertIn(b"HTTP/1.1 403", response)

#     def assert_valid_websocket_connect_message(
#             self, channel_message, request_path="/", request_params=None, request_headers=None):
#         """
#         Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
#         """

#         self.assertTrue(channel_message)

#         self.assert_presence_of_message_keys(
#             channel_message.keys(),
#             {"reply_channel", "path", "headers", "order"},
#             {"scheme", "query_string", "root_path", "client", "server"})

#         # == Assertions about required channel_message fields ==
#         self.assert_valid_reply_channel(channel_message["reply_channel"])
#         self.assert_valid_path(channel_message["path"], request_path)

#         order = channel_message["order"]
#         self.assertIsInstance(order, int)
#         self.assertEqual(order, 0)

#         # Ordering of header names is not important, but the order of values for a header
#         # name is. To assert whether that order is kept, we transform the request
#         # headers and the channel message headers into a set
#         # {('name1', 'value1,value2'), ('name2', 'value3')} and check if they're equal.
#         # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
#         # get one string per header field with values separated by comma.
#         transformed_request_headers = defaultdict(list)
#         for name, value in (request_headers or []):
#             expected_name = name.lower().strip().encode("ascii")
#             expected_value = value.strip().encode("ascii")
#             transformed_request_headers[expected_name].append(expected_value)
#         final_request_headers = {
#             (name, b",".join(value)) for name, value in transformed_request_headers.items()
#         }

#         # Websockets carry a lot of additional header fields, so instead of verifying that
#         # headers look exactly like expected, we just check that the expected header fields
#         # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
#         # and not tested for.
#         assert final_request_headers.issubset(set(channel_message["headers"]))

#         # == Assertions about optional channel_message fields ==
#         scheme = channel_message.get("scheme")
#         if scheme:
#             self.assertIsInstance(scheme, six.text_type)
#             self.assertIn(scheme, ["ws", "wss"])

#         query_string = channel_message.get("query_string")
#         if query_string:
#             # Assert that query_string is a byte string and still url encoded
#             self.assertIsInstance(query_string, six.binary_type)
#             self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))

#         root_path = channel_message.get("root_path")
#         if root_path is not None:
#             self.assertIsInstance(root_path, six.text_type)

#         client = channel_message.get("client")
#         if client is not None:
#             self.assert_valid_address_and_port(channel_message["client"])

#         server = channel_message.get("server")
#         if server is not None:
#             self.assert_valid_address_and_port(channel_message["server"])