daphne/tests/test_websocket.py
2022-10-07 13:22:40 +02:00

315 lines
13 KiB
Python

import collections
import time
from urllib import parse
import http_strategies
from http_base import DaphneTestCase, DaphneTestingInstance
from hypothesis import given, settings
from daphne.testing import BaseDaphneTestingInstance
class TestWebsocket(DaphneTestCase):
"""
Tests WebSocket handshake, send and receive.
"""
def assert_valid_websocket_scope(
self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
and some urlencoding things.
"""
# Check overall keys
self.assert_key_sets(
required_keys={
"asgi",
"type",
"path",
"raw_path",
"query_string",
"headers",
},
optional_keys={"scheme", "root_path", "client", "server", "subprotocols"},
actual_keys=scope.keys(),
)
self.assertEqual(scope["asgi"]["version"], "3.0")
# Check that it is the right type
self.assertEqual(scope["type"], "websocket")
# Path
self.assert_valid_path(scope["path"])
# Scheme
self.assertIn(scope.get("scheme", "ws"), ["ws", "wss"])
if scheme:
self.assertEqual(scheme, scope["scheme"])
# Query string (byte string and still url encoded)
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(
query_string, parse.urlencode(params or []).encode("ascii")
)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_scope_headers = collections.defaultdict(list)
for name, value in scope["headers"]:
transformed_scope_headers.setdefault(name, [])
# Make sure to split out any headers collapsed with commas
for bit in value.split(b","):
if bit.strip():
transformed_scope_headers[name].append(bit.strip())
transformed_request_headers = collections.defaultdict(list)
for name, value in headers or []:
expected_name = name.lower().strip()
expected_value = value.strip()
# Make sure to split out any headers collapsed with commas
transformed_request_headers.setdefault(expected_name, [])
for bit in expected_value.split(b","):
if bit.strip():
transformed_request_headers[expected_name].append(bit.strip())
for name, value in transformed_request_headers.items():
self.assertIn(name, transformed_scope_headers)
self.assertEqual(value, transformed_scope_headers[name])
# Root path
self.assertIsInstance(scope.get("root_path", ""), str)
# Client and server addresses
client = scope.get("client")
if client is not None:
self.assert_valid_address_and_port(client)
server = scope.get("server")
if server is not None:
self.assert_valid_address_and_port(server)
# Subprotocols
scope_subprotocols = scope.get("subprotocols", [])
if scope_subprotocols:
assert all(isinstance(x, str) for x in scope_subprotocols)
if subprotocols:
assert sorted(scope_subprotocols) == sorted(subprotocols)
def assert_valid_websocket_connect_message(self, message):
"""
Asserts that a message is a valid http.request message
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
)
# Check that it is the right type
self.assertEqual(message["type"], "websocket.connect")
def test_accept(self):
"""
Tests we can open and accept a socket.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope)
self.assert_valid_websocket_connect_message(messages[0])
def test_reject(self):
"""
Tests we can reject a socket and it won't complete the handshake.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.close"}])
with self.assertRaises(RuntimeError):
self.websocket_handshake(test_app)
def test_subprotocols(self):
"""
Tests that we can ask for subprotocols and then select one.
"""
subprotocols = ["proto1", "proto2"]
with DaphneTestingInstance() as test_app:
test_app.add_send_messages(
[{"type": "websocket.accept", "subprotocol": "proto2"}]
)
_, subprotocol = self.websocket_handshake(
test_app, subprotocols=subprotocols
)
# Validate the scope and messages we got
assert subprotocol == "proto2"
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
self.assert_valid_websocket_connect_message(messages[0])
def test_xff(self):
"""
Tests that X-Forwarded-For headers get parsed right
"""
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
with DaphneTestingInstance(xff=True) as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app, headers=headers)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope)
self.assert_valid_websocket_connect_message(messages[0])
assert scope["client"] == ["10.1.2.3", 80]
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params(),
request_headers=http_strategies.headers(),
)
@settings(max_examples=5, deadline=2000)
def test_http_bits(self, request_path, request_params, request_headers):
"""
Tests that various HTTP-level bits (query string params, path, headers)
carry over into the scope.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(
test_app,
path=parse.quote(request_path),
params=request_params,
headers=request_headers,
)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(
scope, path=request_path, params=request_params, headers=request_headers
)
self.assert_valid_websocket_connect_message(messages[0])
def test_raw_path(self):
"""
Tests that /foo%2Fbar produces raw_path and a decoded path
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app, path="/foo%2Fbar")
# Validate the scope and messages we got
scope, _ = test_app.get_received()
self.assertEqual(scope["path"], "/foo/bar")
self.assertEqual(scope["raw_path"], b"/foo%2Fbar")
def test_text_frames(self):
"""
Tests we can send and receive text frames.
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages(
[{"type": "websocket.send", "text": "here be dragons 🐉"}]
)
# Send it a frame
self.websocket_send_frame(sock, "what is here? 🌍")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {
"type": "websocket.receive",
"text": "what is here? 🌍",
}
def test_binary_frames(self):
"""
Tests we can send and receive binary frames with things that are very
much not valid UTF-8.
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages(
[{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
)
# Send it a frame
self.websocket_send_frame(sock, b"what is here? \xe2")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {
"type": "websocket.receive",
"bytes": b"what is here? \xe2",
}
def test_http_timeout(self):
"""
Tests that the HTTP timeout doesn't kick in for WebSockets
"""
with DaphneTestingInstance(http_timeout=1) as test_app:
# Connect
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Wait 2 seconds
time.sleep(2)
# Prep frame for it to send
test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
# Send it a frame
self.websocket_send_frame(sock, "still alive?")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "cake"
def test_application_checker_handles_asyncio_cancellederror(self):
with CancellingTestingInstance() as app:
# Connect to the websocket app, it will immediately raise
# asyncio.CancelledError
sock, _ = self.websocket_handshake(app)
# Disconnect from the socket
sock.close()
# Wait for application_checker to clean up the applications for
# disconnected clients, and for the server to be stopped.
time.sleep(3)
# Make sure we received either no error, or a ConnectionsNotEmpty
while not app.process.errors.empty():
err, _tb = app.process.errors.get()
if not isinstance(err, ConnectionsNotEmpty):
raise err
self.fail(
"Server connections were not cleaned up after an asyncio.CancelledError was raised"
)
async def cancelling_application(scope, receive, send):
import asyncio
from twisted.internet import reactor
# Stop the server after a short delay so that the teardown is run.
reactor.callLater(2, reactor.stop)
await send({"type": "websocket.accept"})
raise asyncio.CancelledError()
class ConnectionsNotEmpty(Exception):
pass
class CancellingTestingInstance(BaseDaphneTestingInstance):
def __init__(self):
super().__init__(application=cancelling_application)
def process_teardown(self):
import multiprocessing
# Get a hold of the enclosing DaphneProcess (we're currently running in
# the same process as the application).
proc = multiprocessing.current_process()
# By now the (only) socket should have disconnected, and the
# application_checker should have run. If there are any connections
# still, it means that the application_checker did not clean them up.
if proc.server.connections:
raise ConnectionsNotEmpty()