Add websocket tests to make sure everything important is covered.

This commit is contained in:
Andrew Godwin 2017-11-27 00:00:34 -08:00
parent 1ca1c67032
commit 567c27504d
8 changed files with 485 additions and 532 deletions

View File

@ -1,3 +1,5 @@
from concurrent.futures import CancelledError
import logging
import os import os
import pickle import pickle
import tempfile import tempfile
@ -17,21 +19,29 @@ class TestApplication:
self.messages = [] self.messages = []
async def __call__(self, send, receive): async def __call__(self, send, receive):
# Load setup info
setup = self.load_setup()
# Receive input and send output # Receive input and send output
logging.debug("test app coroutine alive")
try: try:
for _ in range(setup["receive_messages"]): while True:
# Receive a message and save it into the result store
self.messages.append(await receive()) self.messages.append(await receive())
for message in setup["response_messages"]: logging.debug("test app received %r", self.messages[-1])
await send(message) self.save_result(self.scope, self.messages)
# See if there are any messages to send back
setup = self.load_setup()
self.delete_setup()
for message in setup["response_messages"]:
await send(message)
logging.debug("test app sent %r", message)
except Exception as e: except Exception as e:
self.save_exception(e) if isinstance(e, CancelledError):
else: # Don't catch task-cancelled errors!
self.save_result() raise
else:
self.save_exception(e)
@classmethod @classmethod
def save_setup(cls, response_messages, receive_messages=1): def save_setup(cls, response_messages):
""" """
Stores setup information. Stores setup information.
""" """
@ -39,7 +49,6 @@ class TestApplication:
pickle.dump( pickle.dump(
{ {
"response_messages": response_messages, "response_messages": response_messages,
"receive_messages": receive_messages,
}, },
fh, fh,
) )
@ -49,29 +58,34 @@ class TestApplication:
""" """
Returns setup details. Returns setup details.
""" """
with open(cls.setup_storage, "rb") as fh: try:
return pickle.load(fh) with open(cls.setup_storage, "rb") as fh:
return pickle.load(fh)
except FileNotFoundError:
return {"response_messages": []}
def save_result(self): @classmethod
def save_result(cls, scope, messages):
""" """
Saves details of what happened to the result storage. Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow. We could use pickle here, but that seems wrong, still, somehow.
""" """
with open(self.result_storage, "wb") as fh: with open(cls.result_storage, "wb") as fh:
pickle.dump( pickle.dump(
{ {
"scope": self.scope, "scope": scope,
"messages": self.messages, "messages": messages,
}, },
fh, fh,
) )
def save_exception(self, exception): @classmethod
def save_exception(cls, exception):
""" """
Saves details of what happened to the result storage. Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow. We could use pickle here, but that seems wrong, still, somehow.
""" """
with open(self.result_storage, "wb") as fh: with open(cls.result_storage, "wb") as fh:
pickle.dump( pickle.dump(
{ {
"exception": exception, "exception": exception,
@ -88,14 +102,20 @@ class TestApplication:
return pickle.load(fh) return pickle.load(fh)
@classmethod @classmethod
def clear_storage(cls): def delete_setup(cls):
""" """
Clears storage files. Clears setup storage files.
""" """
try: try:
os.unlink(cls.setup_storage) os.unlink(cls.setup_storage)
except OSError: except OSError:
pass pass
@classmethod
def delete_result(cls):
"""
Clears result storage files.
"""
try: try:
os.unlink(cls.result_storage) os.unlink(cls.result_storage)
except OSError: except OSError:

View File

@ -1,128 +0,0 @@
from __future__ import unicode_literals
import six
from six.moves.urllib import parse
from asgiref.inmemory import ChannelLayer
from twisted.test import proto_helpers
from daphne.http_protocol import HTTPFactory
def message_for_request(method, path, params=None, headers=None, body=None):
"""
Constructs a HTTP request according to the given parameters, runs
that through daphne and returns the emitted channel message.
"""
request = _build_request(method, path, params, headers, body)
message, factory, transport = _run_through_daphne(request, "http.request")
return message
def response_for_message(message):
"""
Returns the raw HTTP response that Daphne constructs when sending a reply
to a HTTP request.
The current approach actually first builds a HTTP request (similar to
message_for_request) because we need a valid reply channel. I'm sure
this can be streamlined, but it works for now.
"""
request = _build_request("GET", "/")
request_message, factory, transport = _run_through_daphne(request, "http.request")
factory.dispatch_reply(request_message["reply_channel"], message)
return transport.value()
def _build_request(method, path, params=None, headers=None, body=None):
"""
Takes request parameters and returns a byte string of a valid HTTP/1.1 request.
We really shouldn't manually build a HTTP request, and instead try to capture
what e.g. urllib or requests would do. But that is non-trivial, so meanwhile
we hope that our request building doesn't mask any errors.
This code is messy, because urllib behaves rather different between Python 2
and 3. Readability is further obstructed by the fact that Python 3.4 doesn't
support % formatting for bytes, so we need to concat everything.
If we run into more issues with this, the python-future library has a backport
of Python 3's urllib.
:param method: ASCII string of HTTP method.
:param path: unicode string of URL path.
:param params: List of two-tuples of bytestrings, ready for consumption for
urlencode. Encode to utf8 if necessary.
:param headers: List of two-tuples ASCII strings of HTTP header, value.
:param body: ASCII string of request body.
ASCII string is short for a unicode string containing only ASCII characters,
or a byte string with ASCII encoding.
"""
if headers is None:
headers = []
else:
headers = headers[:]
if six.PY3:
quoted_path = parse.quote(path)
if params:
quoted_path += "?" + parse.urlencode(params)
quoted_path = quoted_path.encode("ascii")
else:
quoted_path = parse.quote(path.encode("utf8"))
if params:
quoted_path += b"?" + parse.urlencode(params)
request = method.encode("ascii") + b" " + quoted_path + b" HTTP/1.1\r\n"
for name, value in headers:
request += header_line(name, value)
request += b"\r\n"
if body:
request += body.encode("ascii")
return request
def build_websocket_upgrade(path, params, headers):
ws_headers = [
("Host", "somewhere.com"),
("Upgrade", "websocket"),
("Connection", "Upgrade"),
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
("Sec-WebSocket-Protocol", "chat, superchat"),
("Sec-WebSocket-Version", "13"),
("Origin", "http://example.com")
]
return _build_request("GET", path, params, headers=headers + ws_headers, body=None)
def header_line(name, value):
"""
Given a header name and value, returns the line to use in a HTTP request or response.
"""
return name.encode("ascii") + b": " + value.encode("ascii") + b"\r\n"
def _run_through_daphne(request, channel_name):
"""
Returns Daphne's channel message for a given request.
This helper requires a fair bit of scaffolding and can certainly be improved,
but it works for now.
"""
channel_layer = ChannelLayer()
factory = HTTPFactory(channel_layer, send_channel="test!")
proto = factory.buildProtocol(("127.0.0.1", 0))
transport = proto_helpers.StringTransport()
proto.makeConnection(transport)
proto.dataReceived(request)
_, message = channel_layer.receive([channel_name])
return message, factory, transport
def content_length_header(body):
"""
Returns an appropriate Content-Length HTTP header for a given body.
"""
return "Content-Length", six.text_type(len(body))

View File

@ -1,246 +0,0 @@
# coding: utf8
from __future__ import unicode_literals
from hypothesis import assume, given, strategies, settings
from twisted.test import proto_helpers
from asgiref.inmemory import ChannelLayer
from daphne.http_protocol import HTTPFactory
from daphne.tests import http_strategies, testcases, factories
class WebSocketConnection(object):
"""
Helper class that makes it easier to test Dahpne's WebSocket support.
"""
def __init__(self):
self.last_message = None
self.channel_layer = ChannelLayer()
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
self.transport = proto_helpers.StringTransport()
self.proto.makeConnection(self.transport)
def receive(self, request):
"""
Low-level method to let Daphne handle HTTP/WebSocket data
"""
self.proto.dataReceived(request)
_, self.last_message = self.channel_layer.receive(["websocket.connect"])
return self.last_message
def send(self, content):
"""
Method to respond with a channel message
"""
if self.last_message is None:
# Auto-connect for convenience.
self.connect()
self.factory.dispatch_reply(self.last_message["reply_channel"], content)
response = self.transport.value()
self.transport.clear()
return response
def connect(self, path="/", params=None, headers=None):
"""
High-level method to perform the WebSocket handshake
"""
request = factories.build_websocket_upgrade(path, params, headers or [])
message = self.receive(request)
return message
class TestHandshake(testcases.ASGIWebSocketTestCase):
"""
Tests for the WebSocket handshake
"""
def test_minimal(self):
message = WebSocketConnection().connect()
self.assert_valid_websocket_connect_message(message)
@given(
path=http_strategies.http_path(),
params=http_strategies.query_params(),
headers=http_strategies.headers(),
)
@settings(perform_health_check=False)
def test_connection(self, path, params, headers):
message = WebSocketConnection().connect(path, params, headers)
self.assert_valid_websocket_connect_message(message, path, params, headers)
class TestSendCloseAccept(testcases.ASGIWebSocketTestCase):
"""
Tests that, essentially, try to translate the send/close/accept section of the spec into code.
"""
def test_empty_accept(self):
response = WebSocketConnection().send({"accept": True})
self.assert_websocket_upgrade(response)
@given(text=http_strategies.http_body())
def test_accept_and_text(self, text):
response = WebSocketConnection().send({"accept": True, "text": text})
self.assert_websocket_upgrade(response, text.encode("ascii"))
@given(data=http_strategies.binary_payload())
def test_accept_and_bytes(self, data):
response = WebSocketConnection().send({"accept": True, "bytes": data})
self.assert_websocket_upgrade(response, data)
def test_accept_false(self):
response = WebSocketConnection().send({"accept": False})
self.assert_websocket_denied(response)
def test_accept_false_with_text(self):
"""
Tests that even if text is given, the connection is denied.
We can't easily use Hypothesis to generate data for this test because it's
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
"""
text = "foobar"
response = WebSocketConnection().send({"accept": False, "text": text})
self.assert_websocket_denied(response)
self.assertNotIn(text.encode("ascii"), response)
def test_accept_false_with_bytes(self):
"""
Tests that even if data is given, the connection is denied.
We can't easily use Hypothesis to generate data for this test because it's
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
"""
data = b"foobar"
response = WebSocketConnection().send({"accept": False, "bytes": data})
self.assert_websocket_denied(response)
self.assertNotIn(data, response)
@given(text=http_strategies.http_body())
def test_just_text(self, text):
assume(len(text) > 0)
# If content is sent, accept=True is implied.
response = WebSocketConnection().send({"text": text})
self.assert_websocket_upgrade(response, text.encode("ascii"))
@given(data=http_strategies.binary_payload())
def test_just_bytes(self, data):
assume(len(data) > 0)
# If content is sent, accept=True is implied.
response = WebSocketConnection().send({"bytes": data})
self.assert_websocket_upgrade(response, data)
def test_close_boolean(self):
response = WebSocketConnection().send({"close": True})
self.assert_websocket_denied(response)
@given(number=strategies.integers(min_value=1))
def test_close_integer(self, number):
response = WebSocketConnection().send({"close": number})
self.assert_websocket_denied(response)
@given(text=http_strategies.http_body())
def test_close_with_text(self, text):
assume(len(text) > 0)
response = WebSocketConnection().send({"close": True, "text": text})
self.assert_websocket_upgrade(response, text.encode("ascii"), expect_close=True)
@given(data=http_strategies.binary_payload())
def test_close_with_data(self, data):
assume(len(data) > 0)
response = WebSocketConnection().send({"close": True, "bytes": data})
self.assert_websocket_upgrade(response, data, expect_close=True)
class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase):
"""
Tests that the WebSocket protocol class correctly generates and parses messages.
"""
def setUp(self):
self.connection = WebSocketConnection()
def test_basic(self):
# Send a simple request to the protocol and get the resulting message off
# of the channel layer.
message = self.connection.receive(
b"GET /chat HTTP/1.1\r\n"
b"Host: somewhere.com\r\n"
b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n"
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
b"Sec-WebSocket-Version: 13\r\n"
b"Origin: http://example.com\r\n"
b"\r\n"
)
self.assertEqual(message["path"], "/chat")
self.assertEqual(message["query_string"], b"")
self.assertEqual(
sorted(message["headers"]),
[(b"connection", b"Upgrade"),
(b"host", b"somewhere.com"),
(b"origin", b"http://example.com"),
(b"sec-websocket-key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
(b"sec-websocket-protocol", b"chat, superchat"),
(b"sec-websocket-version", b"13"),
(b"upgrade", b"websocket")]
)
self.assert_valid_websocket_connect_message(message, "/chat")
# Accept the connection
response = self.connection.send({"accept": True})
self.assert_websocket_upgrade(response)
# Send some text
response = self.connection.send({"text": "Hello World!"})
self.assertEqual(response, b"\x81\x0cHello World!")
# Send some bytes
response = self.connection.send({"bytes": b"\xaa\xbb\xcc\xdd"})
self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
# Close the connection
response = self.connection.send({"close": True})
self.assertEqual(response, b"\x88\x02\x03\xe8")
def test_connection_with_file_origin_is_accepted(self):
message = self.connection.receive(
b"GET /chat HTTP/1.1\r\n"
b"Host: somewhere.com\r\n"
b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n"
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
b"Sec-WebSocket-Version: 13\r\n"
b"Origin: file://\r\n"
b"\r\n"
)
self.assertIn((b"origin", b"file://"), message["headers"])
self.assert_valid_websocket_connect_message(message, "/chat")
# Accept the connection
response = self.connection.send({"accept": True})
self.assert_websocket_upgrade(response)
def test_connection_with_no_origin_is_accepted(self):
message = self.connection.receive(
b"GET /chat HTTP/1.1\r\n"
b"Host: somewhere.com\r\n"
b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n"
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
b"Sec-WebSocket-Version: 13\r\n"
b"\r\n"
)
self.assertNotIn(b"origin", [header_tuple[0] for header_tuple in message["headers"]])
self.assert_valid_websocket_connect_message(message, "/chat")
# Accept the connection
response = self.connection.send({"accept": True})
self.assert_websocket_upgrade(response)

View File

@ -73,7 +73,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
"client": self.client_addr, "client": self.client_addr,
"server": self.server_addr, "server": self.server_addr,
"subprotocols": subprotocols, "subprotocols": subprotocols,
"order": 0,
}) })
except: except:
# Exceptions here are not displayed right, just 500. # Exceptions here are not displayed right, just 500.

View File

@ -1,22 +1,26 @@
from urllib import parse
from http.client import HTTPConnection from http.client import HTTPConnection
from urllib import parse
import socket import socket
import struct
import subprocess import subprocess
import time import time
import unittest import unittest
from daphne.test_utils import TestApplication from daphne.test_application import TestApplication
class DaphneTestCase(unittest.TestCase): class DaphneTestingInstance:
""" """
Base class for Daphne integration test cases. Launches an instance of Daphne to test against, with an application
object you can read messages from and feed messages to.
Boots up a copy of Daphne on a test port and sends it a request, and Works as a context manager.
retrieves the response. Uses a custom ASGI application and temporary files
to store/retrieve the request/response messages.
""" """
def __init__(self, xff=False):
self.xff = xff
self.host = "127.0.0.1"
def port_in_use(self, port): def port_in_use(self, port):
""" """
Tests if a port is in use on the local machine. Tests if a port is in use on the local machine.
@ -34,39 +38,90 @@ class DaphneTestCase(unittest.TestCase):
finally: finally:
s.close() s.close()
def run_daphne(self, method, path, params, body, responses, headers=None, timeout=1, xff=False): def find_free_port(self):
"""
Finds an unused port to test stuff on
"""
for i in range(11200, 11300):
if not self.port_in_use(i):
return i
raise RuntimeError("Cannot find a free port to test on")
def __enter__(self):
# Clear result storage
TestApplication.delete_setup()
TestApplication.delete_result()
# Find a port to listen on
self.port = self.find_free_port()
daphne_args = ["daphne", "-p", str(self.port), "-v", "0"]
# Optionally enable X-Forwarded-For support.
if self.xff:
daphne_args += ["--proxy-headers"]
# Start up process and make sure it begins listening.
self.process = subprocess.Popen(daphne_args + ["daphne.test_application:TestApplication"])
for _ in range(100):
time.sleep(0.1)
if self.port_in_use(self.port):
return self
# Daphne didn't start up. Sadface.
self.process.terminate()
raise RuntimeError("Daphne never came up.")
def __exit__(self, exc_type, exc_value, traceback):
# Shut down the process
self.process.terminate()
del self.process
def get_received(self):
"""
Returns the scope and messages the test application has received
so far. Note you'll get all messages since scope start, not just any
new ones since the last call.
Also checks for any exceptions in the application. If there are,
raises them.
"""
try:
inner_result = TestApplication.load_result()
except FileNotFoundError:
raise ValueError("No results available yet.")
# Check for exception
if "exception" in inner_result:
raise inner_result["exception"]
return inner_result["scope"], inner_result["messages"]
def add_send_messages(self, messages):
"""
Adds messages for the application to send back.
The next time it receives an incoming message, it will reply with these.
"""
TestApplication.save_setup(
response_messages=messages,
)
class DaphneTestCase(unittest.TestCase):
"""
Base class for Daphne integration test cases.
Boots up a copy of Daphne on a test port and sends it a request, and
retrieves the response. Uses a custom ASGI application and temporary files
to store/retrieve the request/response messages.
"""
### Plain HTTP helpers
def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
""" """
Runs Daphne with the given request callback (given the base URL) Runs Daphne with the given request callback (given the base URL)
and response messages. and response messages.
""" """
# Store setup info with DaphneTestingInstance(xff=xff) as test_app:
TestApplication.clear_storage() # Add the response messages
TestApplication.save_setup( test_app.add_send_messages(responses)
response_messages=responses,
)
# Find a free port
for i in range(11200, 11300):
if not self.port_in_use(i):
port = i
break
else:
raise RuntimeError("Cannot find a free port to test on")
# Launch daphne on that port
daphne_args = ["daphne", "-p", str(port), "-v", "0"]
if xff:
# Optionally enable X-Forwarded-For support.
daphne_args += ["--proxy-headers"]
process = subprocess.Popen(daphne_args + ["daphne.test_utils:TestApplication"])
try:
for _ in range(100):
time.sleep(0.1)
if self.port_in_use(port):
break
else:
raise RuntimeError("Daphne never came up.")
# Send it the request. We have to do this the long way to allow # Send it the request. We have to do this the long way to allow
# duplicate headers. # duplicate headers.
conn = HTTPConnection("127.0.0.1", port, timeout=timeout) conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
# Make sure path is urlquoted and add any params # Make sure path is urlquoted and add any params
path = parse.quote(path) path = parse.quote(path)
if params: if params:
@ -86,29 +141,17 @@ class DaphneTestCase(unittest.TestCase):
response = conn.getresponse() response = conn.getresponse()
except socket.timeout: except socket.timeout:
# See if they left an exception for us to load # See if they left an exception for us to load
try: test_app.get_received()
exception_result = TestApplication.load_result() raise RuntimeError("Daphne timed out handling request, no exception found.")
except OSError: # Return scope, messages, response
raise RuntimeError("Daphne timed out handling request, no result file") return test_app.get_received() + (response, )
else:
if "exception" in exception_result:
raise exception_result["exception"]
else:
raise RuntimeError("Daphne timed out handling request, no exception found: %r" % exception_result)
finally:
# Shut down daphne
process.terminate()
# Load the information
inner_result = TestApplication.load_result()
# Return the inner result and the response
return inner_result, response
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False): def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
""" """
Convenience method for just testing request handling. Convenience method for just testing request handling.
Returns (scope, messages) Returns (scope, messages)
""" """
inner_result, _ = self.run_daphne( scope, messages, _ = self.run_daphne_http(
method=method, method=method,
path=path, path=path,
params=params, params=params,
@ -117,14 +160,14 @@ class DaphneTestCase(unittest.TestCase):
xff=xff, xff=xff,
responses=[{"type": "http.response", "status": 200, "content": b"OK"}], responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
) )
return inner_result["scope"], inner_result["messages"] return scope, messages
def run_daphne_response(self, response_messages): def run_daphne_response(self, response_messages):
""" """
Convenience method for just testing response handling. Convenience method for just testing response handling.
Returns (scope, messages) Returns (scope, messages)
""" """
_, response = self.run_daphne( _, _, response = self.run_daphne_http(
method="GET", method="GET",
path="/", path="/",
params={}, params={},
@ -133,11 +176,119 @@ class DaphneTestCase(unittest.TestCase):
) )
return response return response
### WebSocket helpers
def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
"""
Runs a WebSocket handshake negotiation and returns the raw socket
object & the selected subprotocol.
You'll need to inject an accept or reject message before this
to let it complete.
"""
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
# Make sure path is urlquoted and add any params
path = parse.quote(path)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
# Do WebSocket handshake headers + any other headers
if headers is None:
headers = []
headers.extend([
("Host", "example.com"),
("Upgrade", "websocket"),
("Connection", "Upgrade"),
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
("Sec-WebSocket-Version", "13"),
("Origin", "http://example.com")
])
if subprotocols:
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
if headers:
for header_name, header_value in headers:
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
conn.endheaders()
# Read out the response
try:
response = conn.getresponse()
except socket.timeout:
# See if they left an exception for us to load
test_app.get_received()
raise RuntimeError("Daphne timed out handling request, no exception found.")
# Check we got a good response code
if response.status != 101:
raise RuntimeError("WebSocket upgrade did not result in status code 101")
# Prepare headers for subprotocol searching
response_headers = dict(
(n.lower(), v)
for n, v in response.getheaders()
)
response.read()
assert not response.closed
# Return the raw socket and any subprotocol
return conn.sock, response_headers.get("sec-websocket-protocol", None)
def websocket_send_frame(self, sock, value):
"""
Sends a WebSocket text or binary frame. Cannot handle long frames.
"""
# Header and text opcode
if isinstance(value, str):
frame = b"\x81"
value = value.encode("utf8")
else:
frame = b"\x82"
# Length plus masking signal bit
frame += struct.pack("!B", len(value) | 0b10000000)
# Mask badly
frame += b"\0\0\0\0"
# Payload
frame += value
print("sending %r" % frame)
sock.sendall(frame)
def receive_from_socket(self, sock, length, timeout=1):
"""
Receives the given amount of bytes from the socket, or times out.
"""
buf = b""
started = time.time()
while len(buf) < length:
buf += sock.recv(length - len(buf))
time.sleep(0.001)
if time.time() - started > timeout:
raise ValueError("Timed out reading from socket")
return buf
def websocket_receive_frame(self, sock):
"""
Receives a WebSocket frame. Cannot handle long frames.
"""
# Read header byte
# TODO: Proper receive buffer handling
opcode = self.receive_from_socket(sock, 1)
if opcode in [b"\x81", b"\x82"]:
# Read length
length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
# Read payload
payload = self.receive_from_socket(sock, length)
if opcode == b"\x81":
payload = payload.decode("utf8")
return payload
else:
raise ValueError("Unknown websocket opcode: %r" % opcode)
### Assertions and test management
def tearDown(self): def tearDown(self):
""" """
Ensures any storage files are cleared. Ensures any storage files are cleared.
""" """
TestApplication.clear_storage() TestApplication.delete_setup()
TestApplication.delete_result()
def assert_is_ip_address(self, address): def assert_is_ip_address(self, address):
""" """
@ -179,85 +330,3 @@ class DaphneTestCase(unittest.TestCase):
self.assertIsInstance(address, str) self.assertIsInstance(address, str)
self.assert_is_ip_address(address) self.assert_is_ip_address(address)
self.assertIsInstance(port, int) self.assertIsInstance(port, int)
# class ASGIWebSocketTestCase(ASGITestCaseBase):
# """
# Test case with helpers for verifying WebSocket channel messages
# """
# def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
# self.assertIn(body, response)
# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
# def assert_websocket_denied(self, response):
# self.assertIn(b"HTTP/1.1 403", response)
# def assert_valid_websocket_connect_message(
# self, channel_message, request_path="/", request_params=None, request_headers=None):
# """
# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
# """
# self.assertTrue(channel_message)
# self.assert_presence_of_message_keys(
# channel_message.keys(),
# {"reply_channel", "path", "headers", "order"},
# {"scheme", "query_string", "root_path", "client", "server"})
# # == Assertions about required channel_message fields ==
# self.assert_valid_reply_channel(channel_message["reply_channel"])
# self.assert_valid_path(channel_message["path"], request_path)
# order = channel_message["order"]
# self.assertIsInstance(order, int)
# self.assertEqual(order, 0)
# # Ordering of header names is not important, but the order of values for a header
# # name is. To assert whether that order is kept, we transform the request
# # headers and the channel message headers into a set
# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
# # get one string per header field with values separated by comma.
# transformed_request_headers = defaultdict(list)
# for name, value in (request_headers or []):
# expected_name = name.lower().strip().encode("ascii")
# expected_value = value.strip().encode("ascii")
# transformed_request_headers[expected_name].append(expected_value)
# final_request_headers = {
# (name, b",".join(value)) for name, value in transformed_request_headers.items()
# }
# # Websockets carry a lot of additional header fields, so instead of verifying that
# # headers look exactly like expected, we just check that the expected header fields
# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
# # and not tested for.
# assert final_request_headers.issubset(set(channel_message["headers"]))
# # == Assertions about optional channel_message fields ==
# scheme = channel_message.get("scheme")
# if scheme:
# self.assertIsInstance(scheme, six.text_type)
# self.assertIn(scheme, ["ws", "wss"])
# query_string = channel_message.get("query_string")
# if query_string:
# # Assert that query_string is a byte string and still url encoded
# self.assertIsInstance(query_string, six.binary_type)
# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
# root_path = channel_message.get("root_path")
# if root_path is not None:
# self.assertIsInstance(root_path, six.text_type)
# client = channel_message.get("client")
# if client is not None:
# self.assert_valid_address_and_port(channel_message["client"])
# server = channel_message.get("server")
# if server is not None:
# self.assert_valid_address_and_port(channel_message["server"])

View File

@ -207,7 +207,7 @@ class TestHTTPRequestSpec(DaphneTestCase):
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
# Note that Daphne returns a list of tuples here, which is fine, because the spec # Note that Daphne returns a list of tuples here, which is fine, because the spec
# asks to treat them interchangeably. # asks to treat them interchangeably.
assert scope["headers"] == [[b"mycustomheader", b"foobar"]] assert [list(x) for x in scope["headers"]] == [[b"mycustomheader", b"foobar"]]
@given(daphne_path=http_strategies.http_path()) @given(daphne_path=http_strategies.http_path())
@settings(max_examples=5, deadline=2000) @settings(max_examples=5, deadline=2000)

239
tests/test_websocket.py Normal file
View File

@ -0,0 +1,239 @@
# coding: utf8
import collections
from urllib import parse
from hypothesis import given, settings
import http_strategies
from http_base import DaphneTestCase, DaphneTestingInstance
class TestWebsocket(DaphneTestCase):
"""
Tests which try to pour the HTTP request section of the ASGI spec into code.
The heavy lifting is done by the assert_valid_http_request_message function,
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
edge cases.
"""
def assert_valid_websocket_scope(
self,
scope,
path="/",
params=None,
headers=None,
scheme=None,
subprotocols=None,
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
and some urlencoding things.
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type", "path", "query_string", "headers"},
optional_keys={"scheme", "root_path", "client", "server", "subprotocols"},
actual_keys=scope.keys(),
)
# Check that it is the right type
self.assertEqual(scope["type"], "websocket")
# Path
self.assert_valid_path(scope["path"], path)
# Scheme
self.assertIn(scope.get("scheme", "ws"), ["ws", "wss"])
if scheme:
self.assertEqual(scheme, scope["scheme"])
# Query string (byte string and still url encoded)
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_scope_headers = collections.defaultdict(list)
for name, value in scope["headers"]:
transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list)
for name, value in (headers or []):
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
for name, value in transformed_request_headers.items():
self.assertIn(name, transformed_scope_headers)
self.assertEqual(value, transformed_scope_headers[name])
# Root path
self.assertIsInstance(scope.get("root_path", ""), str)
# Client and server addresses
client = scope.get("client")
if client is not None:
self.assert_valid_address_and_port(client)
server = scope.get("server")
if server is not None:
self.assert_valid_address_and_port(server)
# Subprotocols
scope_subprotocols = scope.get("subprotocols", [])
if scope_subprotocols:
assert all(isinstance(x, str) for x in scope_subprotocols)
if subprotocols:
assert sorted(scope_subprotocols) == sorted(subprotocols)
def assert_valid_websocket_connect_message(self, message):
"""
Asserts that a message is a valid http.request message
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type"},
optional_keys=set(),
actual_keys=message.keys(),
)
# Check that it is the right type
self.assertEqual(message["type"], "websocket.connect")
def test_accept(self):
"""
Tests we can open and accept a socket.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
self.websocket_handshake(test_app)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope)
self.assert_valid_websocket_connect_message(messages[0])
def test_reject(self):
"""
Tests we can reject a socket and it won't complete the handshake.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.close",
}
])
with self.assertRaises(RuntimeError):
self.websocket_handshake(test_app)
def test_subprotocols(self):
"""
Tests that we can ask for subprotocols and then select one.
"""
subprotocols = ["proto1", "proto2"]
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.accept",
"subprotocol": "proto2",
}
])
_, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
# Validate the scope and messages we got
assert subprotocol == "proto2"
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
self.assert_valid_websocket_connect_message(messages[0])
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params(),
request_headers=http_strategies.headers(),
)
@settings(max_examples=5, deadline=2000)
def test_http_bits(
self,
request_path,
request_params,
request_headers,
):
"""
Tests that various HTTP-level bits (query string params, path, headers)
carry over into the scope.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
self.websocket_handshake(
test_app,
path=request_path,
params=request_params,
headers=request_headers,
)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(
scope,
path=request_path,
params=request_params,
headers=request_headers,
)
self.assert_valid_websocket_connect_message(messages[0])
def test_text_frames(self):
"""
Tests we can send and receive text frames.
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages([
{
"type": "websocket.send",
"text": "here be dragons 🐉",
}
])
# Send it a frame
self.websocket_send_frame(sock, "what is here? 🌍")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"}
def test_binary_frames(self):
"""
Tests we can send and receive binary frames with things that are very
much not valid UTF-8.
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages([
{
"type": "websocket.send",
"bytes": b"here be \xe2 bytes",
}
])
# Send it a frame
self.websocket_send_frame(sock, b"what is here? \xe2")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}