daphne/tests/http_base.py
2018-03-04 09:48:33 -08:00

259 lines
9.4 KiB
Python

import socket
import struct
import time
import unittest
from http.client import HTTPConnection
from urllib import parse
from daphne.testing import DaphneTestingInstance, TestApplication
class DaphneTestCase(unittest.TestCase):
    """
    Base class for Daphne integration test cases.

    Boots up a copy of Daphne on a test port and sends it a request, and
    retrieves the response. Uses a custom ASGI application and temporary files
    to store/retrieve the request/response messages.
    """

    ### Plain HTTP helpers

    @staticmethod
    def _quote_path(path, params):
        """
        Returns the url-quoted path with any params appended as a query string.
        """
        quoted = parse.quote(path)
        if params:
            quoted += "?" + parse.urlencode(params, doseq=True)
        return quoted

    @staticmethod
    def _send_headers(conn, headers):
        """
        Writes each (name, value) pair to the connection one at a time, so
        duplicate header names are preserved. Values are UTF-8 encoded as a
        best effort for non-safe characters.
        """
        for header_name, header_value in headers or ():
            conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))

    def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
        """
        Runs Daphne with the given request callback (given the base URL)
        and response messages.

        Returns whatever the test application received (scope, messages)
        with the HTTP response object appended.

        Raises RuntimeError if Daphne does not answer within the timeout.
        """
        with DaphneTestingInstance(xff=xff) as test_app:
            # Add the response messages
            test_app.add_send_messages(responses)
            # Send it the request. We have to do this the long way to allow
            # duplicate headers.
            conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
            conn.putrequest(
                method,
                self._quote_path(path, params),
                skip_accept_encoding=True,
                skip_host=True,
            )
            self._send_headers(conn, headers)
            # Send body if provided.
            if body:
                conn.putheader("Content-Length", str(len(body)))
                conn.endheaders(message_body=body)
            else:
                conn.endheaders()
            try:
                response = conn.getresponse()
            except socket.timeout:
                # See if they left an exception for us to load
                test_app.get_received()
                raise RuntimeError("Daphne timed out handling request, no exception found.")
            # Return scope, messages, response. The connection is left open
            # on purpose: closing it would also close the response object
            # that is being handed back to the caller.
            return test_app.get_received() + (response, )

    def run_daphne_raw(self, data, timeout=1):
        """
        Runs daphne and sends it the given raw bytestring over a socket.
        Returns what it sends back (a single recv of up to 1000000 bytes).

        Raises RuntimeError if nothing arrives within the timeout.
        """
        assert isinstance(data, bytes)
        with DaphneTestingInstance() as test_app:
            # The context manager guarantees the socket is closed on every
            # exit path, including the timeout error below.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.settimeout(timeout)
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.connect((test_app.host, test_app.port))
                # sendall, so a short write cannot silently truncate the request
                s.sendall(data)
                try:
                    return s.recv(1000000)
                except socket.timeout:
                    raise RuntimeError("Daphne timed out handling raw request, no exception found.")

    def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
        """
        Convenience method for just testing request handling.
        Answers the request with a canned 200 "OK" response.
        Returns (scope, messages)
        """
        scope, messages, _ = self.run_daphne_http(
            method=method,
            path=path,
            params=params,
            body=body,
            headers=headers,
            xff=xff,
            responses=[
                {"type": "http.response.start", "status": 200},
                {"type": "http.response.body", "body": b"OK"},
            ],
        )
        return scope, messages

    def run_daphne_response(self, response_messages):
        """
        Convenience method for just testing response handling.
        Issues a plain "GET /" and returns the HTTP response object.
        """
        _, _, response = self.run_daphne_http(
            method="GET",
            path="/",
            params={},
            body=b"",
            responses=response_messages,
        )
        return response

    ### WebSocket helpers

    def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
        """
        Runs a WebSocket handshake negotiation and returns the raw socket
        object & the selected subprotocol (None if none was selected).

        You'll need to inject an accept or reject message before this
        to let it complete.

        Raises RuntimeError on timeout or if the response is not a 101.
        """
        # Send it the request. We have to do this the long way to allow
        # duplicate headers.
        conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
        conn.putrequest(
            "GET",
            self._quote_path(path, params),
            skip_accept_encoding=True,
            skip_host=True,
        )
        # Do WebSocket handshake headers + any other headers. Work on a copy
        # so the caller's list is not modified.
        headers = list(headers) if headers is not None else []
        headers.extend([
            ("Host", "example.com"),
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
            ("Sec-WebSocket-Version", "13"),
            ("Origin", "http://example.com"),
        ])
        if subprotocols:
            headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
        self._send_headers(conn, headers)
        conn.endheaders()
        # Read out the response
        try:
            response = conn.getresponse()
        except socket.timeout:
            # See if they left an exception for us to load
            test_app.get_received()
            raise RuntimeError("Daphne timed out handling request, no exception found.")
        # Check we got a good response code
        if response.status != 101:
            raise RuntimeError("WebSocket upgrade did not result in status code 101")
        # Prepare headers for subprotocol searching (names lowercased)
        response_headers = {
            name.lower(): value
            for name, value in response.getheaders()
        }
        # Consume whatever body accompanies the handshake response and check
        # the response was not closed, since the socket goes to the caller.
        response.read()
        assert not response.closed
        # Return the raw socket and any subprotocol
        return conn.sock, response_headers.get("sec-websocket-protocol", None)

    def websocket_send_frame(self, sock, value):
        """
        Sends a single WebSocket text (str) or binary (bytes) frame.

        Payloads longer than 125 bytes use the extended 16-bit or 64-bit
        length fields defined by RFC 6455.
        """
        # Header and text/binary opcode
        if isinstance(value, str):
            frame = b"\x81"
            value = value.encode("utf8")
        else:
            frame = b"\x82"
        # Length plus masking signal bit
        length = len(value)
        if length <= 125:
            frame += struct.pack("!B", 0b10000000 | length)
        elif length <= 0xFFFF:
            # 126 signals that a 16-bit length follows
            frame += struct.pack("!BH", 0b10000000 | 126, length)
        else:
            # 127 signals that a 64-bit length follows
            frame += struct.pack("!BQ", 0b10000000 | 127, length)
        # Mask badly: an all-zero key leaves the payload bytes unchanged
        frame += b"\0\0\0\0"
        # Payload
        frame += value
        sock.sendall(frame)

    def receive_from_socket(self, sock, length, timeout=1):
        """
        Receives the given amount of bytes from the socket, or times out.

        Raises ValueError if the bytes have not all arrived after `timeout`
        seconds.
        """
        buf = b""
        # Monotonic clock, so wall-clock adjustments cannot affect the deadline
        deadline = time.monotonic() + timeout
        while len(buf) < length:
            buf += sock.recv(length - len(buf))
            if len(buf) >= length:
                # Done; no need to sleep or re-check the deadline
                break
            time.sleep(0.001)
            if time.monotonic() > deadline:
                raise ValueError("Timed out reading from socket")
        return buf

    def websocket_receive_frame(self, sock):
        """
        Receives a single unfragmented WebSocket text or binary frame and
        returns its payload (str for text frames, bytes for binary ones).

        Extended 16-bit and 64-bit payload lengths are understood.
        Raises ValueError for any other opcode.
        """
        # Read header byte
        # TODO: Proper receive buffer handling
        opcode = self.receive_from_socket(sock, 1)
        if opcode not in (b"\x81", b"\x82"):
            raise ValueError("Unknown websocket opcode: %r" % opcode)
        # Read length, following the extended length markers if present
        length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
        if length == 126:
            length = struct.unpack("!H", self.receive_from_socket(sock, 2))[0]
        elif length == 127:
            length = struct.unpack("!Q", self.receive_from_socket(sock, 8))[0]
        # Read payload
        payload = self.receive_from_socket(sock, length)
        if opcode == b"\x81":
            payload = payload.decode("utf8")
        return payload

    ### Assertions and test management

    def tearDown(self):
        """
        Ensures any storage files are cleared.
        """
        TestApplication.delete_setup()
        TestApplication.delete_result()

    def assert_is_ip_address(self, address):
        """
        Tests whether a given address string is a valid IPv4 or IPv6 address.
        """
        try:
            socket.inet_aton(address)
            return
        except socket.error:
            pass
        # Not IPv4; inet_aton cannot parse IPv6, so try that separately.
        try:
            socket.inet_pton(socket.AF_INET6, address)
        except socket.error:
            self.fail("'%s' is not a valid IP address." % address)

    def assert_key_sets(self, required_keys, optional_keys, actual_keys):
        """
        Asserts that all required_keys are in actual_keys, and that there
        are no keys in actual_keys that aren't required or optional.

        Any iterable of keys is accepted for each argument.
        """
        required = set(required_keys)
        optional = set(optional_keys)
        present = set(actual_keys)
        # Make sure all required keys are present
        self.assertEqual(set(), required - present, "Required keys are missing")
        # Assert that no other keys are present
        self.assertEqual(set(), present - required - optional, "Unexpected keys are present")

    def assert_valid_path(self, path, request_path):
        """
        Checks the path is valid and already url-decoded.
        """
        self.assertIsInstance(path, str)
        self.assertEqual(path, request_path)
        # Assert that it's already url decoded
        self.assertEqual(path, parse.unquote(path))

    def assert_valid_address_and_port(self, host):
        """
        Asserts the value is a valid (host, port) tuple.
        """
        address, port = host
        self.assertIsInstance(address, str)
        self.assert_is_ip_address(address)
        self.assertIsInstance(port, int)