daphne/tests/http_base.py

266 lines
9.4 KiB
Python
Raw Normal View History

2017-11-26 05:23:54 +03:00
import socket
import struct
2017-11-26 05:23:54 +03:00
import time
import unittest
2017-11-29 05:00:15 +03:00
from http.client import HTTPConnection
from urllib import parse
2017-11-26 05:23:54 +03:00
from daphne.testing import DaphneTestingInstance, TestApplication
class DaphneTestCase(unittest.TestCase):
    """
    Base class for Daphne integration test cases.

    Boots up a copy of Daphne on a test port and sends it a request, and
    retrieves the response. Uses a custom ASGI application and temporary files
    to store/retrieve the request/response messages.
    """

    ### Plain HTTP helpers

    def run_daphne_http(
        self, method, path, params, body, responses, headers=None, timeout=1, xff=False
    ):
        """
        Runs Daphne with the given request and response messages.

        ``params`` (if truthy) is url-encoded onto ``path``; ``headers`` is an
        iterable of (name, value) pairs and may contain duplicate names.
        Returns whatever ``test_app.get_received()`` returns with the
        ``http.client`` response appended (callers unpack this as
        scope, messages, response).

        Raises RuntimeError if no response arrives within ``timeout`` seconds
        and the application did not leave an exception of its own.
        """
        with DaphneTestingInstance(xff=xff) as test_app:
            # Add the response messages
            test_app.add_send_messages(responses)
            # Send it the request. We have to do this the long way to allow
            # duplicate headers.
            conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
            if params:
                path += "?" + parse.urlencode(params, doseq=True)
            conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
            # Manually send over headers
            if headers:
                for header_name, header_value in headers:
                    conn.putheader(header_name, header_value)
            # Send body if provided.
            if body:
                conn.putheader("Content-Length", str(len(body)))
                conn.endheaders(message_body=body)
            else:
                conn.endheaders()
            try:
                response = conn.getresponse()
            except socket.timeout:
                # See if they left an exception for us to load
                test_app.get_received()
                raise RuntimeError(
                    "Daphne timed out handling request, no exception found."
                )
            # The connection is deliberately left open: closing it would also
            # close the response object that is handed back to the caller.
            # Return scope, messages, response
            return test_app.get_received() + (response,)

    def run_daphne_raw(self, data, *, responses=None, timeout=1):
        """
        Runs Daphne and sends it the given raw bytestring over a socket.

        Accepts list of response messages the application will reply with.
        Returns what Daphne sends back (a single ``recv`` of up to 1000000
        bytes). Raises RuntimeError if nothing arrives within ``timeout``
        seconds.
        """
        assert isinstance(data, bytes)
        with DaphneTestingInstance() as test_app:
            if responses is not None:
                test_app.add_send_messages(responses)
            # The context manager closes the socket on every exit path.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.settimeout(timeout)
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.connect((test_app.host, test_app.port))
                # sendall, not send: send() may transmit only part of data.
                s.sendall(data)
                try:
                    return s.recv(1000000)
                except socket.timeout:
                    raise RuntimeError(
                        "Daphne timed out handling raw request, no exception found."
                    )

    def run_daphne_request(
        self, method, path, params=None, body=None, headers=None, xff=False
    ):
        """
        Convenience method for just testing request handling.

        The application always answers with a fixed 200 "OK" response.
        Returns (scope, messages)
        """
        scope, messages, _ = self.run_daphne_http(
            method=method,
            path=path,
            params=params,
            body=body,
            headers=headers,
            xff=xff,
            responses=[
                {"type": "http.response.start", "status": 200},
                {"type": "http.response.body", "body": b"OK"},
            ],
        )
        return scope, messages

    def run_daphne_response(self, response_messages):
        """
        Convenience method for just testing response handling.

        Issues a plain ``GET /`` and has the application reply with
        ``response_messages``. Returns the HTTP response object.
        """
        _, _, response = self.run_daphne_http(
            method="GET", path="/", params={}, body=b"", responses=response_messages
        )
        return response

    ### WebSocket helpers

    def websocket_handshake(
        self,
        test_app,
        path="/",
        params=None,
        headers=None,
        subprotocols=None,
        timeout=1,
    ):
        """
        Runs a WebSocket handshake negotiation and returns the raw socket
        object & the selected subprotocol.

        You'll need to inject an accept or reject message before this
        to let it complete.

        ``headers`` is an iterable of extra (name, value) pairs; it is copied,
        never modified. Raises RuntimeError on timeout or when the response
        status is not 101.
        """
        # Send it the request. We have to do this the long way to allow
        # duplicate headers.
        conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
        if params:
            path += "?" + parse.urlencode(params, doseq=True)
        conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
        # Do WebSocket handshake headers + any other headers.
        # Work on a copy so the caller's list is not extended behind its back.
        all_headers = list(headers) if headers is not None else []
        all_headers.extend(
            [
                (b"Host", b"example.com"),
                (b"Upgrade", b"websocket"),
                (b"Connection", b"Upgrade"),
                (b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
                (b"Sec-WebSocket-Version", b"13"),
                (b"Origin", b"http://example.com"),
            ]
        )
        if subprotocols:
            all_headers.append((b"Sec-WebSocket-Protocol", ", ".join(subprotocols)))
        for header_name, header_value in all_headers:
            conn.putheader(header_name, header_value)
        conn.endheaders()
        # Read out the response
        try:
            response = conn.getresponse()
        except socket.timeout:
            # See if they left an exception for us to load
            test_app.get_received()
            raise RuntimeError("Daphne timed out handling request, no exception found.")
        # Check we got a good response code
        if response.status != 101:
            raise RuntimeError("WebSocket upgrade did not result in status code 101")
        # Prepare headers for subprotocol searching
        response_headers = {n.lower(): v for n, v in response.getheaders()}
        response.read()
        assert not response.closed
        # Return the raw socket and any subprotocol
        return conn.sock, response_headers.get("sec-websocket-protocol", None)

    def websocket_send_frame(self, sock, value):
        """
        Sends a WebSocket text or binary frame.

        ``str`` values go out as a text frame (UTF-8 encoded), anything else
        as a binary frame. Payloads longer than 125 bytes use the 16-bit or
        64-bit extended length field.
        """
        # Header and text opcode
        if isinstance(value, str):
            frame = b"\x81"
            value = value.encode("utf8")
        else:
            frame = b"\x82"
        # Length plus masking signal bit
        size = len(value)
        if size <= 125:
            frame += struct.pack("!B", size | 0b10000000)
        elif size <= 0xFFFF:
            # 126 announces a following 16-bit length
            frame += struct.pack("!BH", 126 | 0b10000000, size)
        else:
            # 127 announces a following 64-bit length
            frame += struct.pack("!BQ", 127 | 0b10000000, size)
        # Mask badly: an all-zero key leaves the payload bytes unchanged
        frame += b"\0\0\0\0"
        # Payload
        frame += value
        sock.sendall(frame)

    def receive_from_socket(self, sock, length, timeout=1):
        """
        Receives the given amount of bytes from the socket, or times out.

        Raises ValueError when ``length`` bytes have not been collected
        within ``timeout`` seconds.
        """
        buf = b""
        # Monotonic clock: unaffected by wall-clock adjustments mid-test.
        deadline = time.monotonic() + timeout
        while len(buf) < length:
            buf += sock.recv(length - len(buf))
            if len(buf) >= length:
                # Done; don't sleep or report a timeout for a complete read.
                break
            time.sleep(0.001)
            if time.monotonic() > deadline:
                raise ValueError("Timed out reading from socket")
        return buf

    def websocket_receive_frame(self, sock):
        """
        Receives a WebSocket text or binary frame and returns its payload.

        Text frames are returned as ``str``, binary frames as ``bytes``.
        Understands the 16-bit and 64-bit extended length fields. Raises
        ValueError for any other first byte.
        """
        # Read header byte
        # TODO: Proper receive buffer handling
        opcode = self.receive_from_socket(sock, 1)
        if opcode not in (b"\x81", b"\x82"):
            raise ValueError("Unknown websocket opcode: %r" % opcode)
        # Read length
        length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
        if length == 126:
            length = struct.unpack("!H", self.receive_from_socket(sock, 2))[0]
        elif length == 127:
            length = struct.unpack("!Q", self.receive_from_socket(sock, 8))[0]
        # Read payload
        payload = self.receive_from_socket(sock, length)
        if opcode == b"\x81":
            payload = payload.decode("utf8")
        return payload

    ### Assertions and test management

    def tearDown(self):
        """
        Ensures any storage files are cleared.
        """
        TestApplication.delete_setup()
        TestApplication.delete_result()

    def assert_is_ip_address(self, address):
        """
        Tests whether a given address string is a valid IPv4 or IPv6 address.
        """
        try:
            socket.inet_aton(address)
        except socket.error:
            # inet_aton only understands IPv4; give IPv6 a chance before failing.
            try:
                socket.inet_pton(socket.AF_INET6, address)
            except socket.error:
                self.fail("'%s' is not a valid IP address." % address)

    def assert_key_sets(self, required_keys, optional_keys, actual_keys):
        """
        Asserts that all required_keys are in actual_keys, and that there
        are no keys in actual_keys that aren't required or optional.

        All three arguments may be any iterable of hashable keys.
        """
        required = set(required_keys)
        optional = set(optional_keys)
        present_keys = set(actual_keys)
        # Make sure all required keys are present
        self.assertEqual(set(), required - present_keys, "Required keys are missing")
        # Assert that no other keys are present
        self.assertEqual(set(), present_keys - required - optional)

    def assert_valid_path(self, path):
        """
        Checks the path is valid and already url-decoded.
        """
        self.assertIsInstance(path, str)
        # Assert that it's already url decoded
        self.assertEqual(path, parse.unquote(path))

    def assert_valid_address_and_port(self, host):
        """
        Asserts the value is a valid (host, port) tuple.
        """
        address, port = host
        self.assertIsInstance(address, str)
        self.assert_is_ip_address(address)
        self.assertIsInstance(port, int)