mirror of
https://github.com/django/daphne.git
synced 2024-11-21 15:36:33 +03:00
Add websocket tests to make sure everything important is covered.
This commit is contained in:
parent
1ca1c67032
commit
567c27504d
|
@ -1,3 +1,5 @@
|
|||
from concurrent.futures import CancelledError
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
|
@ -17,21 +19,29 @@ class TestApplication:
|
|||
self.messages = []
|
||||
|
||||
async def __call__(self, send, receive):
|
||||
# Load setup info
|
||||
setup = self.load_setup()
|
||||
# Receive input and send output
|
||||
logging.debug("test app coroutine alive")
|
||||
try:
|
||||
for _ in range(setup["receive_messages"]):
|
||||
while True:
|
||||
# Receive a message and save it into the result store
|
||||
self.messages.append(await receive())
|
||||
for message in setup["response_messages"]:
|
||||
await send(message)
|
||||
logging.debug("test app received %r", self.messages[-1])
|
||||
self.save_result(self.scope, self.messages)
|
||||
# See if there are any messages to send back
|
||||
setup = self.load_setup()
|
||||
self.delete_setup()
|
||||
for message in setup["response_messages"]:
|
||||
await send(message)
|
||||
logging.debug("test app sent %r", message)
|
||||
except Exception as e:
|
||||
self.save_exception(e)
|
||||
else:
|
||||
self.save_result()
|
||||
if isinstance(e, CancelledError):
|
||||
# Don't catch task-cancelled errors!
|
||||
raise
|
||||
else:
|
||||
self.save_exception(e)
|
||||
|
||||
@classmethod
|
||||
def save_setup(cls, response_messages, receive_messages=1):
|
||||
def save_setup(cls, response_messages):
|
||||
"""
|
||||
Stores setup information.
|
||||
"""
|
||||
|
@ -39,7 +49,6 @@ class TestApplication:
|
|||
pickle.dump(
|
||||
{
|
||||
"response_messages": response_messages,
|
||||
"receive_messages": receive_messages,
|
||||
},
|
||||
fh,
|
||||
)
|
||||
|
@ -49,29 +58,34 @@ class TestApplication:
|
|||
"""
|
||||
Returns setup details.
|
||||
"""
|
||||
with open(cls.setup_storage, "rb") as fh:
|
||||
return pickle.load(fh)
|
||||
try:
|
||||
with open(cls.setup_storage, "rb") as fh:
|
||||
return pickle.load(fh)
|
||||
except FileNotFoundError:
|
||||
return {"response_messages": []}
|
||||
|
||||
def save_result(self):
|
||||
@classmethod
|
||||
def save_result(cls, scope, messages):
|
||||
"""
|
||||
Saves details of what happened to the result storage.
|
||||
We could use pickle here, but that seems wrong, still, somehow.
|
||||
"""
|
||||
with open(self.result_storage, "wb") as fh:
|
||||
with open(cls.result_storage, "wb") as fh:
|
||||
pickle.dump(
|
||||
{
|
||||
"scope": self.scope,
|
||||
"messages": self.messages,
|
||||
"scope": scope,
|
||||
"messages": messages,
|
||||
},
|
||||
fh,
|
||||
)
|
||||
|
||||
def save_exception(self, exception):
|
||||
@classmethod
|
||||
def save_exception(cls, exception):
|
||||
"""
|
||||
Saves details of what happened to the result storage.
|
||||
We could use pickle here, but that seems wrong, still, somehow.
|
||||
"""
|
||||
with open(self.result_storage, "wb") as fh:
|
||||
with open(cls.result_storage, "wb") as fh:
|
||||
pickle.dump(
|
||||
{
|
||||
"exception": exception,
|
||||
|
@ -88,14 +102,20 @@ class TestApplication:
|
|||
return pickle.load(fh)
|
||||
|
||||
@classmethod
|
||||
def clear_storage(cls):
|
||||
def delete_setup(cls):
|
||||
"""
|
||||
Clears storage files.
|
||||
Clears setup storage files.
|
||||
"""
|
||||
try:
|
||||
os.unlink(cls.setup_storage)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def delete_result(cls):
|
||||
"""
|
||||
Clears result storage files.
|
||||
"""
|
||||
try:
|
||||
os.unlink(cls.result_storage)
|
||||
except OSError:
|
|
@ -1,128 +0,0 @@
|
|||
from __future__ import unicode_literals
|
||||
import six
|
||||
from six.moves.urllib import parse
|
||||
|
||||
from asgiref.inmemory import ChannelLayer
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
from daphne.http_protocol import HTTPFactory
|
||||
|
||||
|
||||
def message_for_request(method, path, params=None, headers=None, body=None):
|
||||
"""
|
||||
Constructs a HTTP request according to the given parameters, runs
|
||||
that through daphne and returns the emitted channel message.
|
||||
"""
|
||||
request = _build_request(method, path, params, headers, body)
|
||||
message, factory, transport = _run_through_daphne(request, "http.request")
|
||||
return message
|
||||
|
||||
|
||||
def response_for_message(message):
|
||||
"""
|
||||
Returns the raw HTTP response that Daphne constructs when sending a reply
|
||||
to a HTTP request.
|
||||
|
||||
The current approach actually first builds a HTTP request (similar to
|
||||
message_for_request) because we need a valid reply channel. I'm sure
|
||||
this can be streamlined, but it works for now.
|
||||
"""
|
||||
request = _build_request("GET", "/")
|
||||
request_message, factory, transport = _run_through_daphne(request, "http.request")
|
||||
factory.dispatch_reply(request_message["reply_channel"], message)
|
||||
return transport.value()
|
||||
|
||||
|
||||
def _build_request(method, path, params=None, headers=None, body=None):
|
||||
"""
|
||||
Takes request parameters and returns a byte string of a valid HTTP/1.1 request.
|
||||
|
||||
We really shouldn't manually build a HTTP request, and instead try to capture
|
||||
what e.g. urllib or requests would do. But that is non-trivial, so meanwhile
|
||||
we hope that our request building doesn't mask any errors.
|
||||
|
||||
This code is messy, because urllib behaves rather different between Python 2
|
||||
and 3. Readability is further obstructed by the fact that Python 3.4 doesn't
|
||||
support % formatting for bytes, so we need to concat everything.
|
||||
If we run into more issues with this, the python-future library has a backport
|
||||
of Python 3's urllib.
|
||||
|
||||
:param method: ASCII string of HTTP method.
|
||||
:param path: unicode string of URL path.
|
||||
:param params: List of two-tuples of bytestrings, ready for consumption for
|
||||
urlencode. Encode to utf8 if necessary.
|
||||
:param headers: List of two-tuples ASCII strings of HTTP header, value.
|
||||
:param body: ASCII string of request body.
|
||||
|
||||
ASCII string is short for a unicode string containing only ASCII characters,
|
||||
or a byte string with ASCII encoding.
|
||||
"""
|
||||
if headers is None:
|
||||
headers = []
|
||||
else:
|
||||
headers = headers[:]
|
||||
|
||||
if six.PY3:
|
||||
quoted_path = parse.quote(path)
|
||||
if params:
|
||||
quoted_path += "?" + parse.urlencode(params)
|
||||
quoted_path = quoted_path.encode("ascii")
|
||||
else:
|
||||
quoted_path = parse.quote(path.encode("utf8"))
|
||||
if params:
|
||||
quoted_path += b"?" + parse.urlencode(params)
|
||||
|
||||
request = method.encode("ascii") + b" " + quoted_path + b" HTTP/1.1\r\n"
|
||||
for name, value in headers:
|
||||
request += header_line(name, value)
|
||||
|
||||
request += b"\r\n"
|
||||
|
||||
if body:
|
||||
request += body.encode("ascii")
|
||||
|
||||
return request
|
||||
|
||||
|
||||
def build_websocket_upgrade(path, params, headers):
|
||||
ws_headers = [
|
||||
("Host", "somewhere.com"),
|
||||
("Upgrade", "websocket"),
|
||||
("Connection", "Upgrade"),
|
||||
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
||||
("Sec-WebSocket-Protocol", "chat, superchat"),
|
||||
("Sec-WebSocket-Version", "13"),
|
||||
("Origin", "http://example.com")
|
||||
]
|
||||
return _build_request("GET", path, params, headers=headers + ws_headers, body=None)
|
||||
|
||||
|
||||
def header_line(name, value):
|
||||
"""
|
||||
Given a header name and value, returns the line to use in a HTTP request or response.
|
||||
"""
|
||||
return name.encode("ascii") + b": " + value.encode("ascii") + b"\r\n"
|
||||
|
||||
|
||||
def _run_through_daphne(request, channel_name):
|
||||
"""
|
||||
Returns Daphne's channel message for a given request.
|
||||
|
||||
This helper requires a fair bit of scaffolding and can certainly be improved,
|
||||
but it works for now.
|
||||
"""
|
||||
channel_layer = ChannelLayer()
|
||||
factory = HTTPFactory(channel_layer, send_channel="test!")
|
||||
proto = factory.buildProtocol(("127.0.0.1", 0))
|
||||
transport = proto_helpers.StringTransport()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(request)
|
||||
_, message = channel_layer.receive([channel_name])
|
||||
return message, factory, transport
|
||||
|
||||
|
||||
def content_length_header(body):
|
||||
"""
|
||||
Returns an appropriate Content-Length HTTP header for a given body.
|
||||
"""
|
||||
return "Content-Length", six.text_type(len(body))
|
|
@ -1,246 +0,0 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from hypothesis import assume, given, strategies, settings
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
from asgiref.inmemory import ChannelLayer
|
||||
from daphne.http_protocol import HTTPFactory
|
||||
from daphne.tests import http_strategies, testcases, factories
|
||||
|
||||
|
||||
class WebSocketConnection(object):
|
||||
"""
|
||||
Helper class that makes it easier to test Dahpne's WebSocket support.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_message = None
|
||||
|
||||
self.channel_layer = ChannelLayer()
|
||||
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
|
||||
self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
|
||||
self.transport = proto_helpers.StringTransport()
|
||||
self.proto.makeConnection(self.transport)
|
||||
|
||||
def receive(self, request):
|
||||
"""
|
||||
Low-level method to let Daphne handle HTTP/WebSocket data
|
||||
"""
|
||||
self.proto.dataReceived(request)
|
||||
_, self.last_message = self.channel_layer.receive(["websocket.connect"])
|
||||
return self.last_message
|
||||
|
||||
def send(self, content):
|
||||
"""
|
||||
Method to respond with a channel message
|
||||
"""
|
||||
if self.last_message is None:
|
||||
# Auto-connect for convenience.
|
||||
self.connect()
|
||||
self.factory.dispatch_reply(self.last_message["reply_channel"], content)
|
||||
response = self.transport.value()
|
||||
self.transport.clear()
|
||||
return response
|
||||
|
||||
def connect(self, path="/", params=None, headers=None):
|
||||
"""
|
||||
High-level method to perform the WebSocket handshake
|
||||
"""
|
||||
request = factories.build_websocket_upgrade(path, params, headers or [])
|
||||
message = self.receive(request)
|
||||
return message
|
||||
|
||||
|
||||
class TestHandshake(testcases.ASGIWebSocketTestCase):
|
||||
"""
|
||||
Tests for the WebSocket handshake
|
||||
"""
|
||||
|
||||
def test_minimal(self):
|
||||
message = WebSocketConnection().connect()
|
||||
self.assert_valid_websocket_connect_message(message)
|
||||
|
||||
@given(
|
||||
path=http_strategies.http_path(),
|
||||
params=http_strategies.query_params(),
|
||||
headers=http_strategies.headers(),
|
||||
)
|
||||
@settings(perform_health_check=False)
|
||||
def test_connection(self, path, params, headers):
|
||||
message = WebSocketConnection().connect(path, params, headers)
|
||||
self.assert_valid_websocket_connect_message(message, path, params, headers)
|
||||
|
||||
|
||||
class TestSendCloseAccept(testcases.ASGIWebSocketTestCase):
|
||||
"""
|
||||
Tests that, essentially, try to translate the send/close/accept section of the spec into code.
|
||||
"""
|
||||
|
||||
def test_empty_accept(self):
|
||||
response = WebSocketConnection().send({"accept": True})
|
||||
self.assert_websocket_upgrade(response)
|
||||
|
||||
@given(text=http_strategies.http_body())
|
||||
def test_accept_and_text(self, text):
|
||||
response = WebSocketConnection().send({"accept": True, "text": text})
|
||||
self.assert_websocket_upgrade(response, text.encode("ascii"))
|
||||
|
||||
@given(data=http_strategies.binary_payload())
|
||||
def test_accept_and_bytes(self, data):
|
||||
response = WebSocketConnection().send({"accept": True, "bytes": data})
|
||||
self.assert_websocket_upgrade(response, data)
|
||||
|
||||
def test_accept_false(self):
|
||||
response = WebSocketConnection().send({"accept": False})
|
||||
self.assert_websocket_denied(response)
|
||||
|
||||
def test_accept_false_with_text(self):
|
||||
"""
|
||||
Tests that even if text is given, the connection is denied.
|
||||
|
||||
We can't easily use Hypothesis to generate data for this test because it's
|
||||
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
|
||||
"""
|
||||
text = "foobar"
|
||||
response = WebSocketConnection().send({"accept": False, "text": text})
|
||||
self.assert_websocket_denied(response)
|
||||
self.assertNotIn(text.encode("ascii"), response)
|
||||
|
||||
def test_accept_false_with_bytes(self):
|
||||
"""
|
||||
Tests that even if data is given, the connection is denied.
|
||||
|
||||
We can't easily use Hypothesis to generate data for this test because it's
|
||||
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
|
||||
"""
|
||||
data = b"foobar"
|
||||
response = WebSocketConnection().send({"accept": False, "bytes": data})
|
||||
self.assert_websocket_denied(response)
|
||||
self.assertNotIn(data, response)
|
||||
|
||||
@given(text=http_strategies.http_body())
|
||||
def test_just_text(self, text):
|
||||
assume(len(text) > 0)
|
||||
# If content is sent, accept=True is implied.
|
||||
response = WebSocketConnection().send({"text": text})
|
||||
self.assert_websocket_upgrade(response, text.encode("ascii"))
|
||||
|
||||
@given(data=http_strategies.binary_payload())
|
||||
def test_just_bytes(self, data):
|
||||
assume(len(data) > 0)
|
||||
# If content is sent, accept=True is implied.
|
||||
response = WebSocketConnection().send({"bytes": data})
|
||||
self.assert_websocket_upgrade(response, data)
|
||||
|
||||
def test_close_boolean(self):
|
||||
response = WebSocketConnection().send({"close": True})
|
||||
self.assert_websocket_denied(response)
|
||||
|
||||
@given(number=strategies.integers(min_value=1))
|
||||
def test_close_integer(self, number):
|
||||
response = WebSocketConnection().send({"close": number})
|
||||
self.assert_websocket_denied(response)
|
||||
|
||||
@given(text=http_strategies.http_body())
|
||||
def test_close_with_text(self, text):
|
||||
assume(len(text) > 0)
|
||||
response = WebSocketConnection().send({"close": True, "text": text})
|
||||
self.assert_websocket_upgrade(response, text.encode("ascii"), expect_close=True)
|
||||
|
||||
@given(data=http_strategies.binary_payload())
|
||||
def test_close_with_data(self, data):
|
||||
assume(len(data) > 0)
|
||||
response = WebSocketConnection().send({"close": True, "bytes": data})
|
||||
self.assert_websocket_upgrade(response, data, expect_close=True)
|
||||
|
||||
|
||||
class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase):
|
||||
"""
|
||||
Tests that the WebSocket protocol class correctly generates and parses messages.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.connection = WebSocketConnection()
|
||||
|
||||
def test_basic(self):
|
||||
# Send a simple request to the protocol and get the resulting message off
|
||||
# of the channel layer.
|
||||
message = self.connection.receive(
|
||||
b"GET /chat HTTP/1.1\r\n"
|
||||
b"Host: somewhere.com\r\n"
|
||||
b"Upgrade: websocket\r\n"
|
||||
b"Connection: Upgrade\r\n"
|
||||
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
|
||||
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
|
||||
b"Sec-WebSocket-Version: 13\r\n"
|
||||
b"Origin: http://example.com\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
self.assertEqual(message["path"], "/chat")
|
||||
self.assertEqual(message["query_string"], b"")
|
||||
self.assertEqual(
|
||||
sorted(message["headers"]),
|
||||
[(b"connection", b"Upgrade"),
|
||||
(b"host", b"somewhere.com"),
|
||||
(b"origin", b"http://example.com"),
|
||||
(b"sec-websocket-key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
|
||||
(b"sec-websocket-protocol", b"chat, superchat"),
|
||||
(b"sec-websocket-version", b"13"),
|
||||
(b"upgrade", b"websocket")]
|
||||
)
|
||||
self.assert_valid_websocket_connect_message(message, "/chat")
|
||||
|
||||
# Accept the connection
|
||||
response = self.connection.send({"accept": True})
|
||||
self.assert_websocket_upgrade(response)
|
||||
|
||||
# Send some text
|
||||
response = self.connection.send({"text": "Hello World!"})
|
||||
self.assertEqual(response, b"\x81\x0cHello World!")
|
||||
|
||||
# Send some bytes
|
||||
response = self.connection.send({"bytes": b"\xaa\xbb\xcc\xdd"})
|
||||
self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
|
||||
|
||||
# Close the connection
|
||||
response = self.connection.send({"close": True})
|
||||
self.assertEqual(response, b"\x88\x02\x03\xe8")
|
||||
|
||||
def test_connection_with_file_origin_is_accepted(self):
|
||||
message = self.connection.receive(
|
||||
b"GET /chat HTTP/1.1\r\n"
|
||||
b"Host: somewhere.com\r\n"
|
||||
b"Upgrade: websocket\r\n"
|
||||
b"Connection: Upgrade\r\n"
|
||||
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
|
||||
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
|
||||
b"Sec-WebSocket-Version: 13\r\n"
|
||||
b"Origin: file://\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
self.assertIn((b"origin", b"file://"), message["headers"])
|
||||
self.assert_valid_websocket_connect_message(message, "/chat")
|
||||
|
||||
# Accept the connection
|
||||
response = self.connection.send({"accept": True})
|
||||
self.assert_websocket_upgrade(response)
|
||||
|
||||
def test_connection_with_no_origin_is_accepted(self):
|
||||
message = self.connection.receive(
|
||||
b"GET /chat HTTP/1.1\r\n"
|
||||
b"Host: somewhere.com\r\n"
|
||||
b"Upgrade: websocket\r\n"
|
||||
b"Connection: Upgrade\r\n"
|
||||
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
|
||||
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
|
||||
b"Sec-WebSocket-Version: 13\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
self.assertNotIn(b"origin", [header_tuple[0] for header_tuple in message["headers"]])
|
||||
self.assert_valid_websocket_connect_message(message, "/chat")
|
||||
|
||||
# Accept the connection
|
||||
response = self.connection.send({"accept": True})
|
||||
self.assert_websocket_upgrade(response)
|
|
@ -73,7 +73,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
"subprotocols": subprotocols,
|
||||
"order": 0,
|
||||
})
|
||||
except:
|
||||
# Exceptions here are not displayed right, just 500.
|
||||
|
|
|
@ -1,22 +1,26 @@
|
|||
from urllib import parse
|
||||
from http.client import HTTPConnection
|
||||
from urllib import parse
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from daphne.test_utils import TestApplication
|
||||
from daphne.test_application import TestApplication
|
||||
|
||||
|
||||
class DaphneTestCase(unittest.TestCase):
|
||||
class DaphneTestingInstance:
|
||||
"""
|
||||
Base class for Daphne integration test cases.
|
||||
Launches an instance of Daphne to test against, with an application
|
||||
object you can read messages from and feed messages to.
|
||||
|
||||
Boots up a copy of Daphne on a test port and sends it a request, and
|
||||
retrieves the response. Uses a custom ASGI application and temporary files
|
||||
to store/retrieve the request/response messages.
|
||||
Works as a context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, xff=False):
|
||||
self.xff = xff
|
||||
self.host = "127.0.0.1"
|
||||
|
||||
def port_in_use(self, port):
|
||||
"""
|
||||
Tests if a port is in use on the local machine.
|
||||
|
@ -34,39 +38,90 @@ class DaphneTestCase(unittest.TestCase):
|
|||
finally:
|
||||
s.close()
|
||||
|
||||
def run_daphne(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
|
||||
def find_free_port(self):
|
||||
"""
|
||||
Finds an unused port to test stuff on
|
||||
"""
|
||||
for i in range(11200, 11300):
|
||||
if not self.port_in_use(i):
|
||||
return i
|
||||
raise RuntimeError("Cannot find a free port to test on")
|
||||
|
||||
def __enter__(self):
|
||||
# Clear result storage
|
||||
TestApplication.delete_setup()
|
||||
TestApplication.delete_result()
|
||||
# Find a port to listen on
|
||||
self.port = self.find_free_port()
|
||||
daphne_args = ["daphne", "-p", str(self.port), "-v", "0"]
|
||||
# Optionally enable X-Forwarded-For support.
|
||||
if self.xff:
|
||||
daphne_args += ["--proxy-headers"]
|
||||
# Start up process and make sure it begins listening.
|
||||
self.process = subprocess.Popen(daphne_args + ["daphne.test_application:TestApplication"])
|
||||
for _ in range(100):
|
||||
time.sleep(0.1)
|
||||
if self.port_in_use(self.port):
|
||||
return self
|
||||
# Daphne didn't start up. Sadface.
|
||||
self.process.terminate()
|
||||
raise RuntimeError("Daphne never came up.")
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# Shut down the process
|
||||
self.process.terminate()
|
||||
del self.process
|
||||
|
||||
def get_received(self):
|
||||
"""
|
||||
Returns the scope and messages the test application has received
|
||||
so far. Note you'll get all messages since scope start, not just any
|
||||
new ones since the last call.
|
||||
|
||||
Also checks for any exceptions in the application. If there are,
|
||||
raises them.
|
||||
"""
|
||||
try:
|
||||
inner_result = TestApplication.load_result()
|
||||
except FileNotFoundError:
|
||||
raise ValueError("No results available yet.")
|
||||
# Check for exception
|
||||
if "exception" in inner_result:
|
||||
raise inner_result["exception"]
|
||||
return inner_result["scope"], inner_result["messages"]
|
||||
|
||||
def add_send_messages(self, messages):
|
||||
"""
|
||||
Adds messages for the application to send back.
|
||||
The next time it receives an incoming message, it will reply with these.
|
||||
"""
|
||||
TestApplication.save_setup(
|
||||
response_messages=messages,
|
||||
)
|
||||
|
||||
|
||||
class DaphneTestCase(unittest.TestCase):
|
||||
"""
|
||||
Base class for Daphne integration test cases.
|
||||
|
||||
Boots up a copy of Daphne on a test port and sends it a request, and
|
||||
retrieves the response. Uses a custom ASGI application and temporary files
|
||||
to store/retrieve the request/response messages.
|
||||
"""
|
||||
|
||||
### Plain HTTP helpers
|
||||
|
||||
def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
|
||||
"""
|
||||
Runs Daphne with the given request callback (given the base URL)
|
||||
and response messages.
|
||||
"""
|
||||
# Store setup info
|
||||
TestApplication.clear_storage()
|
||||
TestApplication.save_setup(
|
||||
response_messages=responses,
|
||||
)
|
||||
# Find a free port
|
||||
for i in range(11200, 11300):
|
||||
if not self.port_in_use(i):
|
||||
port = i
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Cannot find a free port to test on")
|
||||
# Launch daphne on that port
|
||||
daphne_args = ["daphne", "-p", str(port), "-v", "0"]
|
||||
if xff:
|
||||
# Optionally enable X-Forwarded-For support.
|
||||
daphne_args += ["--proxy-headers"]
|
||||
process = subprocess.Popen(daphne_args + ["daphne.test_utils:TestApplication"])
|
||||
try:
|
||||
for _ in range(100):
|
||||
time.sleep(0.1)
|
||||
if self.port_in_use(port):
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Daphne never came up.")
|
||||
with DaphneTestingInstance(xff=xff) as test_app:
|
||||
# Add the response messages
|
||||
test_app.add_send_messages(responses)
|
||||
# Send it the request. We have to do this the long way to allow
|
||||
# duplicate headers.
|
||||
conn = HTTPConnection("127.0.0.1", port, timeout=timeout)
|
||||
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
|
||||
# Make sure path is urlquoted and add any params
|
||||
path = parse.quote(path)
|
||||
if params:
|
||||
|
@ -86,29 +141,17 @@ class DaphneTestCase(unittest.TestCase):
|
|||
response = conn.getresponse()
|
||||
except socket.timeout:
|
||||
# See if they left an exception for us to load
|
||||
try:
|
||||
exception_result = TestApplication.load_result()
|
||||
except OSError:
|
||||
raise RuntimeError("Daphne timed out handling request, no result file")
|
||||
else:
|
||||
if "exception" in exception_result:
|
||||
raise exception_result["exception"]
|
||||
else:
|
||||
raise RuntimeError("Daphne timed out handling request, no exception found: %r" % exception_result)
|
||||
finally:
|
||||
# Shut down daphne
|
||||
process.terminate()
|
||||
# Load the information
|
||||
inner_result = TestApplication.load_result()
|
||||
# Return the inner result and the response
|
||||
return inner_result, response
|
||||
test_app.get_received()
|
||||
raise RuntimeError("Daphne timed out handling request, no exception found.")
|
||||
# Return scope, messages, response
|
||||
return test_app.get_received() + (response, )
|
||||
|
||||
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
|
||||
"""
|
||||
Convenience method for just testing request handling.
|
||||
Returns (scope, messages)
|
||||
"""
|
||||
inner_result, _ = self.run_daphne(
|
||||
scope, messages, _ = self.run_daphne_http(
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
|
@ -117,14 +160,14 @@ class DaphneTestCase(unittest.TestCase):
|
|||
xff=xff,
|
||||
responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
|
||||
)
|
||||
return inner_result["scope"], inner_result["messages"]
|
||||
return scope, messages
|
||||
|
||||
def run_daphne_response(self, response_messages):
|
||||
"""
|
||||
Convenience method for just testing response handling.
|
||||
Returns (scope, messages)
|
||||
"""
|
||||
_, response = self.run_daphne(
|
||||
_, _, response = self.run_daphne_http(
|
||||
method="GET",
|
||||
path="/",
|
||||
params={},
|
||||
|
@ -133,11 +176,119 @@ class DaphneTestCase(unittest.TestCase):
|
|||
)
|
||||
return response
|
||||
|
||||
### WebSocket helpers
|
||||
|
||||
def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
|
||||
"""
|
||||
Runs a WebSocket handshake negotiation and returns the raw socket
|
||||
object & the selected subprotocol.
|
||||
|
||||
You'll need to inject an accept or reject message before this
|
||||
to let it complete.
|
||||
"""
|
||||
# Send it the request. We have to do this the long way to allow
|
||||
# duplicate headers.
|
||||
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
|
||||
# Make sure path is urlquoted and add any params
|
||||
path = parse.quote(path)
|
||||
if params:
|
||||
path += "?" + parse.urlencode(params, doseq=True)
|
||||
conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
|
||||
# Do WebSocket handshake headers + any other headers
|
||||
if headers is None:
|
||||
headers = []
|
||||
headers.extend([
|
||||
("Host", "example.com"),
|
||||
("Upgrade", "websocket"),
|
||||
("Connection", "Upgrade"),
|
||||
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
||||
("Sec-WebSocket-Version", "13"),
|
||||
("Origin", "http://example.com")
|
||||
])
|
||||
if subprotocols:
|
||||
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
|
||||
if headers:
|
||||
for header_name, header_value in headers:
|
||||
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
|
||||
conn.endheaders()
|
||||
# Read out the response
|
||||
try:
|
||||
response = conn.getresponse()
|
||||
except socket.timeout:
|
||||
# See if they left an exception for us to load
|
||||
test_app.get_received()
|
||||
raise RuntimeError("Daphne timed out handling request, no exception found.")
|
||||
# Check we got a good response code
|
||||
if response.status != 101:
|
||||
raise RuntimeError("WebSocket upgrade did not result in status code 101")
|
||||
# Prepare headers for subprotocol searching
|
||||
response_headers = dict(
|
||||
(n.lower(), v)
|
||||
for n, v in response.getheaders()
|
||||
)
|
||||
response.read()
|
||||
assert not response.closed
|
||||
# Return the raw socket and any subprotocol
|
||||
return conn.sock, response_headers.get("sec-websocket-protocol", None)
|
||||
|
||||
def websocket_send_frame(self, sock, value):
|
||||
"""
|
||||
Sends a WebSocket text or binary frame. Cannot handle long frames.
|
||||
"""
|
||||
# Header and text opcode
|
||||
if isinstance(value, str):
|
||||
frame = b"\x81"
|
||||
value = value.encode("utf8")
|
||||
else:
|
||||
frame = b"\x82"
|
||||
# Length plus masking signal bit
|
||||
frame += struct.pack("!B", len(value) | 0b10000000)
|
||||
# Mask badly
|
||||
frame += b"\0\0\0\0"
|
||||
# Payload
|
||||
frame += value
|
||||
print("sending %r" % frame)
|
||||
sock.sendall(frame)
|
||||
|
||||
def receive_from_socket(self, sock, length, timeout=1):
|
||||
"""
|
||||
Receives the given amount of bytes from the socket, or times out.
|
||||
"""
|
||||
buf = b""
|
||||
started = time.time()
|
||||
while len(buf) < length:
|
||||
buf += sock.recv(length - len(buf))
|
||||
time.sleep(0.001)
|
||||
if time.time() - started > timeout:
|
||||
raise ValueError("Timed out reading from socket")
|
||||
return buf
|
||||
|
||||
def websocket_receive_frame(self, sock):
|
||||
"""
|
||||
Receives a WebSocket frame. Cannot handle long frames.
|
||||
"""
|
||||
# Read header byte
|
||||
# TODO: Proper receive buffer handling
|
||||
opcode = self.receive_from_socket(sock, 1)
|
||||
if opcode in [b"\x81", b"\x82"]:
|
||||
# Read length
|
||||
length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
|
||||
# Read payload
|
||||
payload = self.receive_from_socket(sock, length)
|
||||
if opcode == b"\x81":
|
||||
payload = payload.decode("utf8")
|
||||
return payload
|
||||
else:
|
||||
raise ValueError("Unknown websocket opcode: %r" % opcode)
|
||||
|
||||
### Assertions and test management
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Ensures any storage files are cleared.
|
||||
"""
|
||||
TestApplication.clear_storage()
|
||||
TestApplication.delete_setup()
|
||||
TestApplication.delete_result()
|
||||
|
||||
def assert_is_ip_address(self, address):
|
||||
"""
|
||||
|
@ -179,85 +330,3 @@ class DaphneTestCase(unittest.TestCase):
|
|||
self.assertIsInstance(address, str)
|
||||
self.assert_is_ip_address(address)
|
||||
self.assertIsInstance(port, int)
|
||||
|
||||
|
||||
|
||||
# class ASGIWebSocketTestCase(ASGITestCaseBase):
|
||||
# """
|
||||
# Test case with helpers for verifying WebSocket channel messages
|
||||
# """
|
||||
|
||||
# def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
|
||||
# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
|
||||
# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
||||
# self.assertIn(body, response)
|
||||
# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
|
||||
|
||||
# def assert_websocket_denied(self, response):
|
||||
# self.assertIn(b"HTTP/1.1 403", response)
|
||||
|
||||
# def assert_valid_websocket_connect_message(
|
||||
# self, channel_message, request_path="/", request_params=None, request_headers=None):
|
||||
# """
|
||||
# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
|
||||
# """
|
||||
|
||||
# self.assertTrue(channel_message)
|
||||
|
||||
# self.assert_presence_of_message_keys(
|
||||
# channel_message.keys(),
|
||||
# {"reply_channel", "path", "headers", "order"},
|
||||
# {"scheme", "query_string", "root_path", "client", "server"})
|
||||
|
||||
# # == Assertions about required channel_message fields ==
|
||||
# self.assert_valid_reply_channel(channel_message["reply_channel"])
|
||||
# self.assert_valid_path(channel_message["path"], request_path)
|
||||
|
||||
# order = channel_message["order"]
|
||||
# self.assertIsInstance(order, int)
|
||||
# self.assertEqual(order, 0)
|
||||
|
||||
# # Ordering of header names is not important, but the order of values for a header
|
||||
# # name is. To assert whether that order is kept, we transform the request
|
||||
# # headers and the channel message headers into a set
|
||||
# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
|
||||
# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
|
||||
# # get one string per header field with values separated by comma.
|
||||
# transformed_request_headers = defaultdict(list)
|
||||
# for name, value in (request_headers or []):
|
||||
# expected_name = name.lower().strip().encode("ascii")
|
||||
# expected_value = value.strip().encode("ascii")
|
||||
# transformed_request_headers[expected_name].append(expected_value)
|
||||
# final_request_headers = {
|
||||
# (name, b",".join(value)) for name, value in transformed_request_headers.items()
|
||||
# }
|
||||
|
||||
# # Websockets carry a lot of additional header fields, so instead of verifying that
|
||||
# # headers look exactly like expected, we just check that the expected header fields
|
||||
# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
|
||||
# # and not tested for.
|
||||
# assert final_request_headers.issubset(set(channel_message["headers"]))
|
||||
|
||||
# # == Assertions about optional channel_message fields ==
|
||||
# scheme = channel_message.get("scheme")
|
||||
# if scheme:
|
||||
# self.assertIsInstance(scheme, six.text_type)
|
||||
# self.assertIn(scheme, ["ws", "wss"])
|
||||
|
||||
# query_string = channel_message.get("query_string")
|
||||
# if query_string:
|
||||
# # Assert that query_string is a byte string and still url encoded
|
||||
# self.assertIsInstance(query_string, six.binary_type)
|
||||
# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
|
||||
|
||||
# root_path = channel_message.get("root_path")
|
||||
# if root_path is not None:
|
||||
# self.assertIsInstance(root_path, six.text_type)
|
||||
|
||||
# client = channel_message.get("client")
|
||||
# if client is not None:
|
||||
# self.assert_valid_address_and_port(channel_message["client"])
|
||||
|
||||
# server = channel_message.get("server")
|
||||
# if server is not None:
|
||||
# self.assert_valid_address_and_port(channel_message["server"])
|
||||
|
|
|
@ -207,7 +207,7 @@ class TestHTTPRequestSpec(DaphneTestCase):
|
|||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
# Note that Daphne returns a list of tuples here, which is fine, because the spec
|
||||
# asks to treat them interchangeably.
|
||||
assert scope["headers"] == [[b"mycustomheader", b"foobar"]]
|
||||
assert [list(x) for x in scope["headers"]] == [[b"mycustomheader", b"foobar"]]
|
||||
|
||||
@given(daphne_path=http_strategies.http_path())
|
||||
@settings(max_examples=5, deadline=2000)
|
||||
|
|
239
tests/test_websocket.py
Normal file
239
tests/test_websocket.py
Normal file
|
@ -0,0 +1,239 @@
|
|||
# coding: utf8
|
||||
|
||||
import collections
|
||||
from urllib import parse
|
||||
|
||||
from hypothesis import given, settings
|
||||
|
||||
import http_strategies
|
||||
from http_base import DaphneTestCase, DaphneTestingInstance
|
||||
|
||||
|
||||
class TestWebsocket(DaphneTestCase):
|
||||
"""
|
||||
Tests which try to pour the HTTP request section of the ASGI spec into code.
|
||||
The heavy lifting is done by the assert_valid_http_request_message function,
|
||||
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
|
||||
edge cases.
|
||||
"""
|
||||
|
||||
def assert_valid_websocket_scope(
|
||||
self,
|
||||
scope,
|
||||
path="/",
|
||||
params=None,
|
||||
headers=None,
|
||||
scheme=None,
|
||||
subprotocols=None,
|
||||
):
|
||||
"""
|
||||
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
||||
and some urlencoding things.
|
||||
"""
|
||||
# Check overall keys
|
||||
self.assert_key_sets(
|
||||
required_keys={"type", "path", "query_string", "headers"},
|
||||
optional_keys={"scheme", "root_path", "client", "server", "subprotocols"},
|
||||
actual_keys=scope.keys(),
|
||||
)
|
||||
# Check that it is the right type
|
||||
self.assertEqual(scope["type"], "websocket")
|
||||
# Path
|
||||
self.assert_valid_path(scope["path"], path)
|
||||
# Scheme
|
||||
self.assertIn(scope.get("scheme", "ws"), ["ws", "wss"])
|
||||
if scheme:
|
||||
self.assertEqual(scheme, scope["scheme"])
|
||||
# Query string (byte string and still url encoded)
|
||||
query_string = scope["query_string"]
|
||||
self.assertIsInstance(query_string, bytes)
|
||||
if params:
|
||||
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
|
||||
# Ordering of header names is not important, but the order of values for a header
|
||||
# name is. To assert whether that order is kept, we transform both the request
|
||||
# headers and the channel message headers into a dictionary
|
||||
# {name: [value1, value2, ...]} and check if they're equal.
|
||||
transformed_scope_headers = collections.defaultdict(list)
|
||||
for name, value in scope["headers"]:
|
||||
transformed_scope_headers[name].append(value)
|
||||
transformed_request_headers = collections.defaultdict(list)
|
||||
for name, value in (headers or []):
|
||||
expected_name = name.lower().strip().encode("ascii")
|
||||
expected_value = value.strip().encode("ascii")
|
||||
transformed_request_headers[expected_name].append(expected_value)
|
||||
for name, value in transformed_request_headers.items():
|
||||
self.assertIn(name, transformed_scope_headers)
|
||||
self.assertEqual(value, transformed_scope_headers[name])
|
||||
# Root path
|
||||
self.assertIsInstance(scope.get("root_path", ""), str)
|
||||
# Client and server addresses
|
||||
client = scope.get("client")
|
||||
if client is not None:
|
||||
self.assert_valid_address_and_port(client)
|
||||
server = scope.get("server")
|
||||
if server is not None:
|
||||
self.assert_valid_address_and_port(server)
|
||||
# Subprotocols
|
||||
scope_subprotocols = scope.get("subprotocols", [])
|
||||
if scope_subprotocols:
|
||||
assert all(isinstance(x, str) for x in scope_subprotocols)
|
||||
if subprotocols:
|
||||
assert sorted(scope_subprotocols) == sorted(subprotocols)
|
||||
|
||||
def assert_valid_websocket_connect_message(self, message):
|
||||
"""
|
||||
Asserts that a message is a valid http.request message
|
||||
"""
|
||||
# Check overall keys
|
||||
self.assert_key_sets(
|
||||
required_keys={"type"},
|
||||
optional_keys=set(),
|
||||
actual_keys=message.keys(),
|
||||
)
|
||||
# Check that it is the right type
|
||||
self.assertEqual(message["type"], "websocket.connect")
|
||||
|
||||
def test_accept(self):
|
||||
"""
|
||||
Tests we can open and accept a socket.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
self.websocket_handshake(test_app)
|
||||
# Validate the scope and messages we got
|
||||
scope, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_scope(scope)
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
|
||||
def test_reject(self):
|
||||
"""
|
||||
Tests we can reject a socket and it won't complete the handshake.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.close",
|
||||
}
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.websocket_handshake(test_app)
|
||||
|
||||
def test_subprotocols(self):
|
||||
"""
|
||||
Tests that we can ask for subprotocols and then select one.
|
||||
"""
|
||||
subprotocols = ["proto1", "proto2"]
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
"subprotocol": "proto2",
|
||||
}
|
||||
])
|
||||
_, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
|
||||
# Validate the scope and messages we got
|
||||
assert subprotocol == "proto2"
|
||||
scope, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
|
||||
@given(
|
||||
request_path=http_strategies.http_path(),
|
||||
request_params=http_strategies.query_params(),
|
||||
request_headers=http_strategies.headers(),
|
||||
)
|
||||
@settings(max_examples=5, deadline=2000)
|
||||
def test_http_bits(
|
||||
self,
|
||||
request_path,
|
||||
request_params,
|
||||
request_headers,
|
||||
):
|
||||
"""
|
||||
Tests that various HTTP-level bits (query string params, path, headers)
|
||||
carry over into the scope.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
self.websocket_handshake(
|
||||
test_app,
|
||||
path=request_path,
|
||||
params=request_params,
|
||||
headers=request_headers,
|
||||
)
|
||||
# Validate the scope and messages we got
|
||||
scope, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_scope(
|
||||
scope,
|
||||
path=request_path,
|
||||
params=request_params,
|
||||
headers=request_headers,
|
||||
)
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
|
||||
def test_text_frames(self):
|
||||
"""
|
||||
Tests we can send and receive text frames.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
# Connect
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
sock, _ = self.websocket_handshake(test_app)
|
||||
_, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
# Prep frame for it to send
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.send",
|
||||
"text": "here be dragons 🐉",
|
||||
}
|
||||
])
|
||||
# Send it a frame
|
||||
self.websocket_send_frame(sock, "what is here? 🌍")
|
||||
# Receive a frame and make sure it's correct
|
||||
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
|
||||
# Make sure it got our frame
|
||||
_, messages = test_app.get_received()
|
||||
assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"}
|
||||
|
||||
def test_binary_frames(self):
|
||||
"""
|
||||
Tests we can send and receive binary frames with things that are very
|
||||
much not valid UTF-8.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
# Connect
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
sock, _ = self.websocket_handshake(test_app)
|
||||
_, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
# Prep frame for it to send
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.send",
|
||||
"bytes": b"here be \xe2 bytes",
|
||||
}
|
||||
])
|
||||
# Send it a frame
|
||||
self.websocket_send_frame(sock, b"what is here? \xe2")
|
||||
# Receive a frame and make sure it's correct
|
||||
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
|
||||
# Make sure it got our frame
|
||||
_, messages = test_app.get_received()
|
||||
assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}
|
Loading…
Reference in New Issue
Block a user