mirror of
https://github.com/django/daphne.git
synced 2024-11-21 23:46:33 +03:00
Add websocket tests to make sure everything important is covered.
This commit is contained in:
parent
1ca1c67032
commit
567c27504d
|
@ -1,3 +1,5 @@
|
||||||
|
from concurrent.futures import CancelledError
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import tempfile
|
import tempfile
|
||||||
|
@ -17,21 +19,29 @@ class TestApplication:
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
async def __call__(self, send, receive):
|
async def __call__(self, send, receive):
|
||||||
# Load setup info
|
|
||||||
setup = self.load_setup()
|
|
||||||
# Receive input and send output
|
# Receive input and send output
|
||||||
|
logging.debug("test app coroutine alive")
|
||||||
try:
|
try:
|
||||||
for _ in range(setup["receive_messages"]):
|
while True:
|
||||||
|
# Receive a message and save it into the result store
|
||||||
self.messages.append(await receive())
|
self.messages.append(await receive())
|
||||||
for message in setup["response_messages"]:
|
logging.debug("test app received %r", self.messages[-1])
|
||||||
await send(message)
|
self.save_result(self.scope, self.messages)
|
||||||
|
# See if there are any messages to send back
|
||||||
|
setup = self.load_setup()
|
||||||
|
self.delete_setup()
|
||||||
|
for message in setup["response_messages"]:
|
||||||
|
await send(message)
|
||||||
|
logging.debug("test app sent %r", message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.save_exception(e)
|
if isinstance(e, CancelledError):
|
||||||
else:
|
# Don't catch task-cancelled errors!
|
||||||
self.save_result()
|
raise
|
||||||
|
else:
|
||||||
|
self.save_exception(e)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save_setup(cls, response_messages, receive_messages=1):
|
def save_setup(cls, response_messages):
|
||||||
"""
|
"""
|
||||||
Stores setup information.
|
Stores setup information.
|
||||||
"""
|
"""
|
||||||
|
@ -39,7 +49,6 @@ class TestApplication:
|
||||||
pickle.dump(
|
pickle.dump(
|
||||||
{
|
{
|
||||||
"response_messages": response_messages,
|
"response_messages": response_messages,
|
||||||
"receive_messages": receive_messages,
|
|
||||||
},
|
},
|
||||||
fh,
|
fh,
|
||||||
)
|
)
|
||||||
|
@ -49,29 +58,34 @@ class TestApplication:
|
||||||
"""
|
"""
|
||||||
Returns setup details.
|
Returns setup details.
|
||||||
"""
|
"""
|
||||||
with open(cls.setup_storage, "rb") as fh:
|
try:
|
||||||
return pickle.load(fh)
|
with open(cls.setup_storage, "rb") as fh:
|
||||||
|
return pickle.load(fh)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return {"response_messages": []}
|
||||||
|
|
||||||
def save_result(self):
|
@classmethod
|
||||||
|
def save_result(cls, scope, messages):
|
||||||
"""
|
"""
|
||||||
Saves details of what happened to the result storage.
|
Saves details of what happened to the result storage.
|
||||||
We could use pickle here, but that seems wrong, still, somehow.
|
We could use pickle here, but that seems wrong, still, somehow.
|
||||||
"""
|
"""
|
||||||
with open(self.result_storage, "wb") as fh:
|
with open(cls.result_storage, "wb") as fh:
|
||||||
pickle.dump(
|
pickle.dump(
|
||||||
{
|
{
|
||||||
"scope": self.scope,
|
"scope": scope,
|
||||||
"messages": self.messages,
|
"messages": messages,
|
||||||
},
|
},
|
||||||
fh,
|
fh,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_exception(self, exception):
|
@classmethod
|
||||||
|
def save_exception(cls, exception):
|
||||||
"""
|
"""
|
||||||
Saves details of what happened to the result storage.
|
Saves details of what happened to the result storage.
|
||||||
We could use pickle here, but that seems wrong, still, somehow.
|
We could use pickle here, but that seems wrong, still, somehow.
|
||||||
"""
|
"""
|
||||||
with open(self.result_storage, "wb") as fh:
|
with open(cls.result_storage, "wb") as fh:
|
||||||
pickle.dump(
|
pickle.dump(
|
||||||
{
|
{
|
||||||
"exception": exception,
|
"exception": exception,
|
||||||
|
@ -88,14 +102,20 @@ class TestApplication:
|
||||||
return pickle.load(fh)
|
return pickle.load(fh)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clear_storage(cls):
|
def delete_setup(cls):
|
||||||
"""
|
"""
|
||||||
Clears storage files.
|
Clears setup storage files.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
os.unlink(cls.setup_storage)
|
os.unlink(cls.setup_storage)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_result(cls):
|
||||||
|
"""
|
||||||
|
Clears result storage files.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
os.unlink(cls.result_storage)
|
os.unlink(cls.result_storage)
|
||||||
except OSError:
|
except OSError:
|
|
@ -1,128 +0,0 @@
|
||||||
from __future__ import unicode_literals
|
|
||||||
import six
|
|
||||||
from six.moves.urllib import parse
|
|
||||||
|
|
||||||
from asgiref.inmemory import ChannelLayer
|
|
||||||
from twisted.test import proto_helpers
|
|
||||||
|
|
||||||
from daphne.http_protocol import HTTPFactory
|
|
||||||
|
|
||||||
|
|
||||||
def message_for_request(method, path, params=None, headers=None, body=None):
|
|
||||||
"""
|
|
||||||
Constructs a HTTP request according to the given parameters, runs
|
|
||||||
that through daphne and returns the emitted channel message.
|
|
||||||
"""
|
|
||||||
request = _build_request(method, path, params, headers, body)
|
|
||||||
message, factory, transport = _run_through_daphne(request, "http.request")
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def response_for_message(message):
|
|
||||||
"""
|
|
||||||
Returns the raw HTTP response that Daphne constructs when sending a reply
|
|
||||||
to a HTTP request.
|
|
||||||
|
|
||||||
The current approach actually first builds a HTTP request (similar to
|
|
||||||
message_for_request) because we need a valid reply channel. I'm sure
|
|
||||||
this can be streamlined, but it works for now.
|
|
||||||
"""
|
|
||||||
request = _build_request("GET", "/")
|
|
||||||
request_message, factory, transport = _run_through_daphne(request, "http.request")
|
|
||||||
factory.dispatch_reply(request_message["reply_channel"], message)
|
|
||||||
return transport.value()
|
|
||||||
|
|
||||||
|
|
||||||
def _build_request(method, path, params=None, headers=None, body=None):
|
|
||||||
"""
|
|
||||||
Takes request parameters and returns a byte string of a valid HTTP/1.1 request.
|
|
||||||
|
|
||||||
We really shouldn't manually build a HTTP request, and instead try to capture
|
|
||||||
what e.g. urllib or requests would do. But that is non-trivial, so meanwhile
|
|
||||||
we hope that our request building doesn't mask any errors.
|
|
||||||
|
|
||||||
This code is messy, because urllib behaves rather different between Python 2
|
|
||||||
and 3. Readability is further obstructed by the fact that Python 3.4 doesn't
|
|
||||||
support % formatting for bytes, so we need to concat everything.
|
|
||||||
If we run into more issues with this, the python-future library has a backport
|
|
||||||
of Python 3's urllib.
|
|
||||||
|
|
||||||
:param method: ASCII string of HTTP method.
|
|
||||||
:param path: unicode string of URL path.
|
|
||||||
:param params: List of two-tuples of bytestrings, ready for consumption for
|
|
||||||
urlencode. Encode to utf8 if necessary.
|
|
||||||
:param headers: List of two-tuples ASCII strings of HTTP header, value.
|
|
||||||
:param body: ASCII string of request body.
|
|
||||||
|
|
||||||
ASCII string is short for a unicode string containing only ASCII characters,
|
|
||||||
or a byte string with ASCII encoding.
|
|
||||||
"""
|
|
||||||
if headers is None:
|
|
||||||
headers = []
|
|
||||||
else:
|
|
||||||
headers = headers[:]
|
|
||||||
|
|
||||||
if six.PY3:
|
|
||||||
quoted_path = parse.quote(path)
|
|
||||||
if params:
|
|
||||||
quoted_path += "?" + parse.urlencode(params)
|
|
||||||
quoted_path = quoted_path.encode("ascii")
|
|
||||||
else:
|
|
||||||
quoted_path = parse.quote(path.encode("utf8"))
|
|
||||||
if params:
|
|
||||||
quoted_path += b"?" + parse.urlencode(params)
|
|
||||||
|
|
||||||
request = method.encode("ascii") + b" " + quoted_path + b" HTTP/1.1\r\n"
|
|
||||||
for name, value in headers:
|
|
||||||
request += header_line(name, value)
|
|
||||||
|
|
||||||
request += b"\r\n"
|
|
||||||
|
|
||||||
if body:
|
|
||||||
request += body.encode("ascii")
|
|
||||||
|
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
def build_websocket_upgrade(path, params, headers):
|
|
||||||
ws_headers = [
|
|
||||||
("Host", "somewhere.com"),
|
|
||||||
("Upgrade", "websocket"),
|
|
||||||
("Connection", "Upgrade"),
|
|
||||||
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
|
||||||
("Sec-WebSocket-Protocol", "chat, superchat"),
|
|
||||||
("Sec-WebSocket-Version", "13"),
|
|
||||||
("Origin", "http://example.com")
|
|
||||||
]
|
|
||||||
return _build_request("GET", path, params, headers=headers + ws_headers, body=None)
|
|
||||||
|
|
||||||
|
|
||||||
def header_line(name, value):
|
|
||||||
"""
|
|
||||||
Given a header name and value, returns the line to use in a HTTP request or response.
|
|
||||||
"""
|
|
||||||
return name.encode("ascii") + b": " + value.encode("ascii") + b"\r\n"
|
|
||||||
|
|
||||||
|
|
||||||
def _run_through_daphne(request, channel_name):
|
|
||||||
"""
|
|
||||||
Returns Daphne's channel message for a given request.
|
|
||||||
|
|
||||||
This helper requires a fair bit of scaffolding and can certainly be improved,
|
|
||||||
but it works for now.
|
|
||||||
"""
|
|
||||||
channel_layer = ChannelLayer()
|
|
||||||
factory = HTTPFactory(channel_layer, send_channel="test!")
|
|
||||||
proto = factory.buildProtocol(("127.0.0.1", 0))
|
|
||||||
transport = proto_helpers.StringTransport()
|
|
||||||
proto.makeConnection(transport)
|
|
||||||
proto.dataReceived(request)
|
|
||||||
_, message = channel_layer.receive([channel_name])
|
|
||||||
return message, factory, transport
|
|
||||||
|
|
||||||
|
|
||||||
def content_length_header(body):
|
|
||||||
"""
|
|
||||||
Returns an appropriate Content-Length HTTP header for a given body.
|
|
||||||
"""
|
|
||||||
return "Content-Length", six.text_type(len(body))
|
|
|
@ -1,246 +0,0 @@
|
||||||
# coding: utf8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
from hypothesis import assume, given, strategies, settings
|
|
||||||
from twisted.test import proto_helpers
|
|
||||||
|
|
||||||
from asgiref.inmemory import ChannelLayer
|
|
||||||
from daphne.http_protocol import HTTPFactory
|
|
||||||
from daphne.tests import http_strategies, testcases, factories
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketConnection(object):
|
|
||||||
"""
|
|
||||||
Helper class that makes it easier to test Dahpne's WebSocket support.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.last_message = None
|
|
||||||
|
|
||||||
self.channel_layer = ChannelLayer()
|
|
||||||
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
|
|
||||||
self.proto = self.factory.buildProtocol(("127.0.0.1", 0))
|
|
||||||
self.transport = proto_helpers.StringTransport()
|
|
||||||
self.proto.makeConnection(self.transport)
|
|
||||||
|
|
||||||
def receive(self, request):
|
|
||||||
"""
|
|
||||||
Low-level method to let Daphne handle HTTP/WebSocket data
|
|
||||||
"""
|
|
||||||
self.proto.dataReceived(request)
|
|
||||||
_, self.last_message = self.channel_layer.receive(["websocket.connect"])
|
|
||||||
return self.last_message
|
|
||||||
|
|
||||||
def send(self, content):
|
|
||||||
"""
|
|
||||||
Method to respond with a channel message
|
|
||||||
"""
|
|
||||||
if self.last_message is None:
|
|
||||||
# Auto-connect for convenience.
|
|
||||||
self.connect()
|
|
||||||
self.factory.dispatch_reply(self.last_message["reply_channel"], content)
|
|
||||||
response = self.transport.value()
|
|
||||||
self.transport.clear()
|
|
||||||
return response
|
|
||||||
|
|
||||||
def connect(self, path="/", params=None, headers=None):
|
|
||||||
"""
|
|
||||||
High-level method to perform the WebSocket handshake
|
|
||||||
"""
|
|
||||||
request = factories.build_websocket_upgrade(path, params, headers or [])
|
|
||||||
message = self.receive(request)
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
class TestHandshake(testcases.ASGIWebSocketTestCase):
|
|
||||||
"""
|
|
||||||
Tests for the WebSocket handshake
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_minimal(self):
|
|
||||||
message = WebSocketConnection().connect()
|
|
||||||
self.assert_valid_websocket_connect_message(message)
|
|
||||||
|
|
||||||
@given(
|
|
||||||
path=http_strategies.http_path(),
|
|
||||||
params=http_strategies.query_params(),
|
|
||||||
headers=http_strategies.headers(),
|
|
||||||
)
|
|
||||||
@settings(perform_health_check=False)
|
|
||||||
def test_connection(self, path, params, headers):
|
|
||||||
message = WebSocketConnection().connect(path, params, headers)
|
|
||||||
self.assert_valid_websocket_connect_message(message, path, params, headers)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSendCloseAccept(testcases.ASGIWebSocketTestCase):
|
|
||||||
"""
|
|
||||||
Tests that, essentially, try to translate the send/close/accept section of the spec into code.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_empty_accept(self):
|
|
||||||
response = WebSocketConnection().send({"accept": True})
|
|
||||||
self.assert_websocket_upgrade(response)
|
|
||||||
|
|
||||||
@given(text=http_strategies.http_body())
|
|
||||||
def test_accept_and_text(self, text):
|
|
||||||
response = WebSocketConnection().send({"accept": True, "text": text})
|
|
||||||
self.assert_websocket_upgrade(response, text.encode("ascii"))
|
|
||||||
|
|
||||||
@given(data=http_strategies.binary_payload())
|
|
||||||
def test_accept_and_bytes(self, data):
|
|
||||||
response = WebSocketConnection().send({"accept": True, "bytes": data})
|
|
||||||
self.assert_websocket_upgrade(response, data)
|
|
||||||
|
|
||||||
def test_accept_false(self):
|
|
||||||
response = WebSocketConnection().send({"accept": False})
|
|
||||||
self.assert_websocket_denied(response)
|
|
||||||
|
|
||||||
def test_accept_false_with_text(self):
|
|
||||||
"""
|
|
||||||
Tests that even if text is given, the connection is denied.
|
|
||||||
|
|
||||||
We can't easily use Hypothesis to generate data for this test because it's
|
|
||||||
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
|
|
||||||
"""
|
|
||||||
text = "foobar"
|
|
||||||
response = WebSocketConnection().send({"accept": False, "text": text})
|
|
||||||
self.assert_websocket_denied(response)
|
|
||||||
self.assertNotIn(text.encode("ascii"), response)
|
|
||||||
|
|
||||||
def test_accept_false_with_bytes(self):
|
|
||||||
"""
|
|
||||||
Tests that even if data is given, the connection is denied.
|
|
||||||
|
|
||||||
We can't easily use Hypothesis to generate data for this test because it's
|
|
||||||
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
|
|
||||||
"""
|
|
||||||
data = b"foobar"
|
|
||||||
response = WebSocketConnection().send({"accept": False, "bytes": data})
|
|
||||||
self.assert_websocket_denied(response)
|
|
||||||
self.assertNotIn(data, response)
|
|
||||||
|
|
||||||
@given(text=http_strategies.http_body())
|
|
||||||
def test_just_text(self, text):
|
|
||||||
assume(len(text) > 0)
|
|
||||||
# If content is sent, accept=True is implied.
|
|
||||||
response = WebSocketConnection().send({"text": text})
|
|
||||||
self.assert_websocket_upgrade(response, text.encode("ascii"))
|
|
||||||
|
|
||||||
@given(data=http_strategies.binary_payload())
|
|
||||||
def test_just_bytes(self, data):
|
|
||||||
assume(len(data) > 0)
|
|
||||||
# If content is sent, accept=True is implied.
|
|
||||||
response = WebSocketConnection().send({"bytes": data})
|
|
||||||
self.assert_websocket_upgrade(response, data)
|
|
||||||
|
|
||||||
def test_close_boolean(self):
|
|
||||||
response = WebSocketConnection().send({"close": True})
|
|
||||||
self.assert_websocket_denied(response)
|
|
||||||
|
|
||||||
@given(number=strategies.integers(min_value=1))
|
|
||||||
def test_close_integer(self, number):
|
|
||||||
response = WebSocketConnection().send({"close": number})
|
|
||||||
self.assert_websocket_denied(response)
|
|
||||||
|
|
||||||
@given(text=http_strategies.http_body())
|
|
||||||
def test_close_with_text(self, text):
|
|
||||||
assume(len(text) > 0)
|
|
||||||
response = WebSocketConnection().send({"close": True, "text": text})
|
|
||||||
self.assert_websocket_upgrade(response, text.encode("ascii"), expect_close=True)
|
|
||||||
|
|
||||||
@given(data=http_strategies.binary_payload())
|
|
||||||
def test_close_with_data(self, data):
|
|
||||||
assume(len(data) > 0)
|
|
||||||
response = WebSocketConnection().send({"close": True, "bytes": data})
|
|
||||||
self.assert_websocket_upgrade(response, data, expect_close=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase):
|
|
||||||
"""
|
|
||||||
Tests that the WebSocket protocol class correctly generates and parses messages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.connection = WebSocketConnection()
|
|
||||||
|
|
||||||
def test_basic(self):
|
|
||||||
# Send a simple request to the protocol and get the resulting message off
|
|
||||||
# of the channel layer.
|
|
||||||
message = self.connection.receive(
|
|
||||||
b"GET /chat HTTP/1.1\r\n"
|
|
||||||
b"Host: somewhere.com\r\n"
|
|
||||||
b"Upgrade: websocket\r\n"
|
|
||||||
b"Connection: Upgrade\r\n"
|
|
||||||
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
|
|
||||||
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
|
|
||||||
b"Sec-WebSocket-Version: 13\r\n"
|
|
||||||
b"Origin: http://example.com\r\n"
|
|
||||||
b"\r\n"
|
|
||||||
)
|
|
||||||
self.assertEqual(message["path"], "/chat")
|
|
||||||
self.assertEqual(message["query_string"], b"")
|
|
||||||
self.assertEqual(
|
|
||||||
sorted(message["headers"]),
|
|
||||||
[(b"connection", b"Upgrade"),
|
|
||||||
(b"host", b"somewhere.com"),
|
|
||||||
(b"origin", b"http://example.com"),
|
|
||||||
(b"sec-websocket-key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
|
|
||||||
(b"sec-websocket-protocol", b"chat, superchat"),
|
|
||||||
(b"sec-websocket-version", b"13"),
|
|
||||||
(b"upgrade", b"websocket")]
|
|
||||||
)
|
|
||||||
self.assert_valid_websocket_connect_message(message, "/chat")
|
|
||||||
|
|
||||||
# Accept the connection
|
|
||||||
response = self.connection.send({"accept": True})
|
|
||||||
self.assert_websocket_upgrade(response)
|
|
||||||
|
|
||||||
# Send some text
|
|
||||||
response = self.connection.send({"text": "Hello World!"})
|
|
||||||
self.assertEqual(response, b"\x81\x0cHello World!")
|
|
||||||
|
|
||||||
# Send some bytes
|
|
||||||
response = self.connection.send({"bytes": b"\xaa\xbb\xcc\xdd"})
|
|
||||||
self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
|
|
||||||
|
|
||||||
# Close the connection
|
|
||||||
response = self.connection.send({"close": True})
|
|
||||||
self.assertEqual(response, b"\x88\x02\x03\xe8")
|
|
||||||
|
|
||||||
def test_connection_with_file_origin_is_accepted(self):
|
|
||||||
message = self.connection.receive(
|
|
||||||
b"GET /chat HTTP/1.1\r\n"
|
|
||||||
b"Host: somewhere.com\r\n"
|
|
||||||
b"Upgrade: websocket\r\n"
|
|
||||||
b"Connection: Upgrade\r\n"
|
|
||||||
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
|
|
||||||
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
|
|
||||||
b"Sec-WebSocket-Version: 13\r\n"
|
|
||||||
b"Origin: file://\r\n"
|
|
||||||
b"\r\n"
|
|
||||||
)
|
|
||||||
self.assertIn((b"origin", b"file://"), message["headers"])
|
|
||||||
self.assert_valid_websocket_connect_message(message, "/chat")
|
|
||||||
|
|
||||||
# Accept the connection
|
|
||||||
response = self.connection.send({"accept": True})
|
|
||||||
self.assert_websocket_upgrade(response)
|
|
||||||
|
|
||||||
def test_connection_with_no_origin_is_accepted(self):
|
|
||||||
message = self.connection.receive(
|
|
||||||
b"GET /chat HTTP/1.1\r\n"
|
|
||||||
b"Host: somewhere.com\r\n"
|
|
||||||
b"Upgrade: websocket\r\n"
|
|
||||||
b"Connection: Upgrade\r\n"
|
|
||||||
b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
|
|
||||||
b"Sec-WebSocket-Protocol: chat, superchat\r\n"
|
|
||||||
b"Sec-WebSocket-Version: 13\r\n"
|
|
||||||
b"\r\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertNotIn(b"origin", [header_tuple[0] for header_tuple in message["headers"]])
|
|
||||||
self.assert_valid_websocket_connect_message(message, "/chat")
|
|
||||||
|
|
||||||
# Accept the connection
|
|
||||||
response = self.connection.send({"accept": True})
|
|
||||||
self.assert_websocket_upgrade(response)
|
|
|
@ -73,7 +73,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
"client": self.client_addr,
|
"client": self.client_addr,
|
||||||
"server": self.server_addr,
|
"server": self.server_addr,
|
||||||
"subprotocols": subprotocols,
|
"subprotocols": subprotocols,
|
||||||
"order": 0,
|
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
# Exceptions here are not displayed right, just 500.
|
# Exceptions here are not displayed right, just 500.
|
||||||
|
|
|
@ -1,22 +1,26 @@
|
||||||
from urllib import parse
|
|
||||||
from http.client import HTTPConnection
|
from http.client import HTTPConnection
|
||||||
|
from urllib import parse
|
||||||
import socket
|
import socket
|
||||||
|
import struct
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from daphne.test_utils import TestApplication
|
from daphne.test_application import TestApplication
|
||||||
|
|
||||||
|
|
||||||
class DaphneTestCase(unittest.TestCase):
|
class DaphneTestingInstance:
|
||||||
"""
|
"""
|
||||||
Base class for Daphne integration test cases.
|
Launches an instance of Daphne to test against, with an application
|
||||||
|
object you can read messages from and feed messages to.
|
||||||
|
|
||||||
Boots up a copy of Daphne on a test port and sends it a request, and
|
Works as a context manager.
|
||||||
retrieves the response. Uses a custom ASGI application and temporary files
|
|
||||||
to store/retrieve the request/response messages.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, xff=False):
|
||||||
|
self.xff = xff
|
||||||
|
self.host = "127.0.0.1"
|
||||||
|
|
||||||
def port_in_use(self, port):
|
def port_in_use(self, port):
|
||||||
"""
|
"""
|
||||||
Tests if a port is in use on the local machine.
|
Tests if a port is in use on the local machine.
|
||||||
|
@ -34,39 +38,90 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
finally:
|
finally:
|
||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
def run_daphne(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
|
def find_free_port(self):
|
||||||
|
"""
|
||||||
|
Finds an unused port to test stuff on
|
||||||
|
"""
|
||||||
|
for i in range(11200, 11300):
|
||||||
|
if not self.port_in_use(i):
|
||||||
|
return i
|
||||||
|
raise RuntimeError("Cannot find a free port to test on")
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
# Clear result storage
|
||||||
|
TestApplication.delete_setup()
|
||||||
|
TestApplication.delete_result()
|
||||||
|
# Find a port to listen on
|
||||||
|
self.port = self.find_free_port()
|
||||||
|
daphne_args = ["daphne", "-p", str(self.port), "-v", "0"]
|
||||||
|
# Optionally enable X-Forwarded-For support.
|
||||||
|
if self.xff:
|
||||||
|
daphne_args += ["--proxy-headers"]
|
||||||
|
# Start up process and make sure it begins listening.
|
||||||
|
self.process = subprocess.Popen(daphne_args + ["daphne.test_application:TestApplication"])
|
||||||
|
for _ in range(100):
|
||||||
|
time.sleep(0.1)
|
||||||
|
if self.port_in_use(self.port):
|
||||||
|
return self
|
||||||
|
# Daphne didn't start up. Sadface.
|
||||||
|
self.process.terminate()
|
||||||
|
raise RuntimeError("Daphne never came up.")
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
# Shut down the process
|
||||||
|
self.process.terminate()
|
||||||
|
del self.process
|
||||||
|
|
||||||
|
def get_received(self):
|
||||||
|
"""
|
||||||
|
Returns the scope and messages the test application has received
|
||||||
|
so far. Note you'll get all messages since scope start, not just any
|
||||||
|
new ones since the last call.
|
||||||
|
|
||||||
|
Also checks for any exceptions in the application. If there are,
|
||||||
|
raises them.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
inner_result = TestApplication.load_result()
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise ValueError("No results available yet.")
|
||||||
|
# Check for exception
|
||||||
|
if "exception" in inner_result:
|
||||||
|
raise inner_result["exception"]
|
||||||
|
return inner_result["scope"], inner_result["messages"]
|
||||||
|
|
||||||
|
def add_send_messages(self, messages):
|
||||||
|
"""
|
||||||
|
Adds messages for the application to send back.
|
||||||
|
The next time it receives an incoming message, it will reply with these.
|
||||||
|
"""
|
||||||
|
TestApplication.save_setup(
|
||||||
|
response_messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DaphneTestCase(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Base class for Daphne integration test cases.
|
||||||
|
|
||||||
|
Boots up a copy of Daphne on a test port and sends it a request, and
|
||||||
|
retrieves the response. Uses a custom ASGI application and temporary files
|
||||||
|
to store/retrieve the request/response messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
### Plain HTTP helpers
|
||||||
|
|
||||||
|
def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
|
||||||
"""
|
"""
|
||||||
Runs Daphne with the given request callback (given the base URL)
|
Runs Daphne with the given request callback (given the base URL)
|
||||||
and response messages.
|
and response messages.
|
||||||
"""
|
"""
|
||||||
# Store setup info
|
with DaphneTestingInstance(xff=xff) as test_app:
|
||||||
TestApplication.clear_storage()
|
# Add the response messages
|
||||||
TestApplication.save_setup(
|
test_app.add_send_messages(responses)
|
||||||
response_messages=responses,
|
|
||||||
)
|
|
||||||
# Find a free port
|
|
||||||
for i in range(11200, 11300):
|
|
||||||
if not self.port_in_use(i):
|
|
||||||
port = i
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Cannot find a free port to test on")
|
|
||||||
# Launch daphne on that port
|
|
||||||
daphne_args = ["daphne", "-p", str(port), "-v", "0"]
|
|
||||||
if xff:
|
|
||||||
# Optionally enable X-Forwarded-For support.
|
|
||||||
daphne_args += ["--proxy-headers"]
|
|
||||||
process = subprocess.Popen(daphne_args + ["daphne.test_utils:TestApplication"])
|
|
||||||
try:
|
|
||||||
for _ in range(100):
|
|
||||||
time.sleep(0.1)
|
|
||||||
if self.port_in_use(port):
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Daphne never came up.")
|
|
||||||
# Send it the request. We have to do this the long way to allow
|
# Send it the request. We have to do this the long way to allow
|
||||||
# duplicate headers.
|
# duplicate headers.
|
||||||
conn = HTTPConnection("127.0.0.1", port, timeout=timeout)
|
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
|
||||||
# Make sure path is urlquoted and add any params
|
# Make sure path is urlquoted and add any params
|
||||||
path = parse.quote(path)
|
path = parse.quote(path)
|
||||||
if params:
|
if params:
|
||||||
|
@ -86,29 +141,17 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
response = conn.getresponse()
|
response = conn.getresponse()
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
# See if they left an exception for us to load
|
# See if they left an exception for us to load
|
||||||
try:
|
test_app.get_received()
|
||||||
exception_result = TestApplication.load_result()
|
raise RuntimeError("Daphne timed out handling request, no exception found.")
|
||||||
except OSError:
|
# Return scope, messages, response
|
||||||
raise RuntimeError("Daphne timed out handling request, no result file")
|
return test_app.get_received() + (response, )
|
||||||
else:
|
|
||||||
if "exception" in exception_result:
|
|
||||||
raise exception_result["exception"]
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Daphne timed out handling request, no exception found: %r" % exception_result)
|
|
||||||
finally:
|
|
||||||
# Shut down daphne
|
|
||||||
process.terminate()
|
|
||||||
# Load the information
|
|
||||||
inner_result = TestApplication.load_result()
|
|
||||||
# Return the inner result and the response
|
|
||||||
return inner_result, response
|
|
||||||
|
|
||||||
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
|
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
|
||||||
"""
|
"""
|
||||||
Convenience method for just testing request handling.
|
Convenience method for just testing request handling.
|
||||||
Returns (scope, messages)
|
Returns (scope, messages)
|
||||||
"""
|
"""
|
||||||
inner_result, _ = self.run_daphne(
|
scope, messages, _ = self.run_daphne_http(
|
||||||
method=method,
|
method=method,
|
||||||
path=path,
|
path=path,
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -117,14 +160,14 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
xff=xff,
|
xff=xff,
|
||||||
responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
|
responses=[{"type": "http.response", "status": 200, "content": b"OK"}],
|
||||||
)
|
)
|
||||||
return inner_result["scope"], inner_result["messages"]
|
return scope, messages
|
||||||
|
|
||||||
def run_daphne_response(self, response_messages):
|
def run_daphne_response(self, response_messages):
|
||||||
"""
|
"""
|
||||||
Convenience method for just testing response handling.
|
Convenience method for just testing response handling.
|
||||||
Returns (scope, messages)
|
Returns (scope, messages)
|
||||||
"""
|
"""
|
||||||
_, response = self.run_daphne(
|
_, _, response = self.run_daphne_http(
|
||||||
method="GET",
|
method="GET",
|
||||||
path="/",
|
path="/",
|
||||||
params={},
|
params={},
|
||||||
|
@ -133,11 +176,119 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
### WebSocket helpers
|
||||||
|
|
||||||
|
def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
|
||||||
|
"""
|
||||||
|
Runs a WebSocket handshake negotiation and returns the raw socket
|
||||||
|
object & the selected subprotocol.
|
||||||
|
|
||||||
|
You'll need to inject an accept or reject message before this
|
||||||
|
to let it complete.
|
||||||
|
"""
|
||||||
|
# Send it the request. We have to do this the long way to allow
|
||||||
|
# duplicate headers.
|
||||||
|
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
|
||||||
|
# Make sure path is urlquoted and add any params
|
||||||
|
path = parse.quote(path)
|
||||||
|
if params:
|
||||||
|
path += "?" + parse.urlencode(params, doseq=True)
|
||||||
|
conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
|
||||||
|
# Do WebSocket handshake headers + any other headers
|
||||||
|
if headers is None:
|
||||||
|
headers = []
|
||||||
|
headers.extend([
|
||||||
|
("Host", "example.com"),
|
||||||
|
("Upgrade", "websocket"),
|
||||||
|
("Connection", "Upgrade"),
|
||||||
|
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
||||||
|
("Sec-WebSocket-Version", "13"),
|
||||||
|
("Origin", "http://example.com")
|
||||||
|
])
|
||||||
|
if subprotocols:
|
||||||
|
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
|
||||||
|
if headers:
|
||||||
|
for header_name, header_value in headers:
|
||||||
|
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
|
||||||
|
conn.endheaders()
|
||||||
|
# Read out the response
|
||||||
|
try:
|
||||||
|
response = conn.getresponse()
|
||||||
|
except socket.timeout:
|
||||||
|
# See if they left an exception for us to load
|
||||||
|
test_app.get_received()
|
||||||
|
raise RuntimeError("Daphne timed out handling request, no exception found.")
|
||||||
|
# Check we got a good response code
|
||||||
|
if response.status != 101:
|
||||||
|
raise RuntimeError("WebSocket upgrade did not result in status code 101")
|
||||||
|
# Prepare headers for subprotocol searching
|
||||||
|
response_headers = dict(
|
||||||
|
(n.lower(), v)
|
||||||
|
for n, v in response.getheaders()
|
||||||
|
)
|
||||||
|
response.read()
|
||||||
|
assert not response.closed
|
||||||
|
# Return the raw socket and any subprotocol
|
||||||
|
return conn.sock, response_headers.get("sec-websocket-protocol", None)
|
||||||
|
|
||||||
|
def websocket_send_frame(self, sock, value):
|
||||||
|
"""
|
||||||
|
Sends a WebSocket text or binary frame. Cannot handle long frames.
|
||||||
|
"""
|
||||||
|
# Header and text opcode
|
||||||
|
if isinstance(value, str):
|
||||||
|
frame = b"\x81"
|
||||||
|
value = value.encode("utf8")
|
||||||
|
else:
|
||||||
|
frame = b"\x82"
|
||||||
|
# Length plus masking signal bit
|
||||||
|
frame += struct.pack("!B", len(value) | 0b10000000)
|
||||||
|
# Mask badly
|
||||||
|
frame += b"\0\0\0\0"
|
||||||
|
# Payload
|
||||||
|
frame += value
|
||||||
|
print("sending %r" % frame)
|
||||||
|
sock.sendall(frame)
|
||||||
|
|
||||||
|
def receive_from_socket(self, sock, length, timeout=1):
|
||||||
|
"""
|
||||||
|
Receives the given amount of bytes from the socket, or times out.
|
||||||
|
"""
|
||||||
|
buf = b""
|
||||||
|
started = time.time()
|
||||||
|
while len(buf) < length:
|
||||||
|
buf += sock.recv(length - len(buf))
|
||||||
|
time.sleep(0.001)
|
||||||
|
if time.time() - started > timeout:
|
||||||
|
raise ValueError("Timed out reading from socket")
|
||||||
|
return buf
|
||||||
|
|
||||||
|
def websocket_receive_frame(self, sock):
|
||||||
|
"""
|
||||||
|
Receives a WebSocket frame. Cannot handle long frames.
|
||||||
|
"""
|
||||||
|
# Read header byte
|
||||||
|
# TODO: Proper receive buffer handling
|
||||||
|
opcode = self.receive_from_socket(sock, 1)
|
||||||
|
if opcode in [b"\x81", b"\x82"]:
|
||||||
|
# Read length
|
||||||
|
length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
|
||||||
|
# Read payload
|
||||||
|
payload = self.receive_from_socket(sock, length)
|
||||||
|
if opcode == b"\x81":
|
||||||
|
payload = payload.decode("utf8")
|
||||||
|
return payload
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown websocket opcode: %r" % opcode)
|
||||||
|
|
||||||
|
### Assertions and test management
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""
|
"""
|
||||||
Ensures any storage files are cleared.
|
Ensures any storage files are cleared.
|
||||||
"""
|
"""
|
||||||
TestApplication.clear_storage()
|
TestApplication.delete_setup()
|
||||||
|
TestApplication.delete_result()
|
||||||
|
|
||||||
def assert_is_ip_address(self, address):
|
def assert_is_ip_address(self, address):
|
||||||
"""
|
"""
|
||||||
|
@ -179,85 +330,3 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
self.assertIsInstance(address, str)
|
self.assertIsInstance(address, str)
|
||||||
self.assert_is_ip_address(address)
|
self.assert_is_ip_address(address)
|
||||||
self.assertIsInstance(port, int)
|
self.assertIsInstance(port, int)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# class ASGIWebSocketTestCase(ASGITestCaseBase):
|
|
||||||
# """
|
|
||||||
# Test case with helpers for verifying WebSocket channel messages
|
|
||||||
# """
|
|
||||||
|
|
||||||
# def assert_websocket_upgrade(self, response, body=b"", expect_close=False):
|
|
||||||
# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
|
|
||||||
# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
|
||||||
# self.assertIn(body, response)
|
|
||||||
# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
|
|
||||||
|
|
||||||
# def assert_websocket_denied(self, response):
|
|
||||||
# self.assertIn(b"HTTP/1.1 403", response)
|
|
||||||
|
|
||||||
# def assert_valid_websocket_connect_message(
|
|
||||||
# self, channel_message, request_path="/", request_params=None, request_headers=None):
|
|
||||||
# """
|
|
||||||
# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
|
|
||||||
# """
|
|
||||||
|
|
||||||
# self.assertTrue(channel_message)
|
|
||||||
|
|
||||||
# self.assert_presence_of_message_keys(
|
|
||||||
# channel_message.keys(),
|
|
||||||
# {"reply_channel", "path", "headers", "order"},
|
|
||||||
# {"scheme", "query_string", "root_path", "client", "server"})
|
|
||||||
|
|
||||||
# # == Assertions about required channel_message fields ==
|
|
||||||
# self.assert_valid_reply_channel(channel_message["reply_channel"])
|
|
||||||
# self.assert_valid_path(channel_message["path"], request_path)
|
|
||||||
|
|
||||||
# order = channel_message["order"]
|
|
||||||
# self.assertIsInstance(order, int)
|
|
||||||
# self.assertEqual(order, 0)
|
|
||||||
|
|
||||||
# # Ordering of header names is not important, but the order of values for a header
|
|
||||||
# # name is. To assert whether that order is kept, we transform the request
|
|
||||||
# # headers and the channel message headers into a set
|
|
||||||
# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
|
|
||||||
# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
|
|
||||||
# # get one string per header field with values separated by comma.
|
|
||||||
# transformed_request_headers = defaultdict(list)
|
|
||||||
# for name, value in (request_headers or []):
|
|
||||||
# expected_name = name.lower().strip().encode("ascii")
|
|
||||||
# expected_value = value.strip().encode("ascii")
|
|
||||||
# transformed_request_headers[expected_name].append(expected_value)
|
|
||||||
# final_request_headers = {
|
|
||||||
# (name, b",".join(value)) for name, value in transformed_request_headers.items()
|
|
||||||
# }
|
|
||||||
|
|
||||||
# # Websockets carry a lot of additional header fields, so instead of verifying that
|
|
||||||
# # headers look exactly like expected, we just check that the expected header fields
|
|
||||||
# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
|
|
||||||
# # and not tested for.
|
|
||||||
# assert final_request_headers.issubset(set(channel_message["headers"]))
|
|
||||||
|
|
||||||
# # == Assertions about optional channel_message fields ==
|
|
||||||
# scheme = channel_message.get("scheme")
|
|
||||||
# if scheme:
|
|
||||||
# self.assertIsInstance(scheme, six.text_type)
|
|
||||||
# self.assertIn(scheme, ["ws", "wss"])
|
|
||||||
|
|
||||||
# query_string = channel_message.get("query_string")
|
|
||||||
# if query_string:
|
|
||||||
# # Assert that query_string is a byte string and still url encoded
|
|
||||||
# self.assertIsInstance(query_string, six.binary_type)
|
|
||||||
# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii"))
|
|
||||||
|
|
||||||
# root_path = channel_message.get("root_path")
|
|
||||||
# if root_path is not None:
|
|
||||||
# self.assertIsInstance(root_path, six.text_type)
|
|
||||||
|
|
||||||
# client = channel_message.get("client")
|
|
||||||
# if client is not None:
|
|
||||||
# self.assert_valid_address_and_port(channel_message["client"])
|
|
||||||
|
|
||||||
# server = channel_message.get("server")
|
|
||||||
# if server is not None:
|
|
||||||
# self.assert_valid_address_and_port(channel_message["server"])
|
|
||||||
|
|
|
@ -207,7 +207,7 @@ class TestHTTPRequestSpec(DaphneTestCase):
|
||||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||||
# Note that Daphne returns a list of tuples here, which is fine, because the spec
|
# Note that Daphne returns a list of tuples here, which is fine, because the spec
|
||||||
# asks to treat them interchangeably.
|
# asks to treat them interchangeably.
|
||||||
assert scope["headers"] == [[b"mycustomheader", b"foobar"]]
|
assert [list(x) for x in scope["headers"]] == [[b"mycustomheader", b"foobar"]]
|
||||||
|
|
||||||
@given(daphne_path=http_strategies.http_path())
|
@given(daphne_path=http_strategies.http_path())
|
||||||
@settings(max_examples=5, deadline=2000)
|
@settings(max_examples=5, deadline=2000)
|
||||||
|
|
239
tests/test_websocket.py
Normal file
239
tests/test_websocket.py
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
import collections
|
||||||
|
from urllib import parse
|
||||||
|
|
||||||
|
from hypothesis import given, settings
|
||||||
|
|
||||||
|
import http_strategies
|
||||||
|
from http_base import DaphneTestCase, DaphneTestingInstance
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebsocket(DaphneTestCase):
|
||||||
|
"""
|
||||||
|
Tests which try to pour the HTTP request section of the ASGI spec into code.
|
||||||
|
The heavy lifting is done by the assert_valid_http_request_message function,
|
||||||
|
the tests mostly serve to wire up hypothesis so that it exercise it's power to find
|
||||||
|
edge cases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def assert_valid_websocket_scope(
|
||||||
|
self,
|
||||||
|
scope,
|
||||||
|
path="/",
|
||||||
|
params=None,
|
||||||
|
headers=None,
|
||||||
|
scheme=None,
|
||||||
|
subprotocols=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
||||||
|
and some urlencoding things.
|
||||||
|
"""
|
||||||
|
# Check overall keys
|
||||||
|
self.assert_key_sets(
|
||||||
|
required_keys={"type", "path", "query_string", "headers"},
|
||||||
|
optional_keys={"scheme", "root_path", "client", "server", "subprotocols"},
|
||||||
|
actual_keys=scope.keys(),
|
||||||
|
)
|
||||||
|
# Check that it is the right type
|
||||||
|
self.assertEqual(scope["type"], "websocket")
|
||||||
|
# Path
|
||||||
|
self.assert_valid_path(scope["path"], path)
|
||||||
|
# Scheme
|
||||||
|
self.assertIn(scope.get("scheme", "ws"), ["ws", "wss"])
|
||||||
|
if scheme:
|
||||||
|
self.assertEqual(scheme, scope["scheme"])
|
||||||
|
# Query string (byte string and still url encoded)
|
||||||
|
query_string = scope["query_string"]
|
||||||
|
self.assertIsInstance(query_string, bytes)
|
||||||
|
if params:
|
||||||
|
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
|
||||||
|
# Ordering of header names is not important, but the order of values for a header
|
||||||
|
# name is. To assert whether that order is kept, we transform both the request
|
||||||
|
# headers and the channel message headers into a dictionary
|
||||||
|
# {name: [value1, value2, ...]} and check if they're equal.
|
||||||
|
transformed_scope_headers = collections.defaultdict(list)
|
||||||
|
for name, value in scope["headers"]:
|
||||||
|
transformed_scope_headers[name].append(value)
|
||||||
|
transformed_request_headers = collections.defaultdict(list)
|
||||||
|
for name, value in (headers or []):
|
||||||
|
expected_name = name.lower().strip().encode("ascii")
|
||||||
|
expected_value = value.strip().encode("ascii")
|
||||||
|
transformed_request_headers[expected_name].append(expected_value)
|
||||||
|
for name, value in transformed_request_headers.items():
|
||||||
|
self.assertIn(name, transformed_scope_headers)
|
||||||
|
self.assertEqual(value, transformed_scope_headers[name])
|
||||||
|
# Root path
|
||||||
|
self.assertIsInstance(scope.get("root_path", ""), str)
|
||||||
|
# Client and server addresses
|
||||||
|
client = scope.get("client")
|
||||||
|
if client is not None:
|
||||||
|
self.assert_valid_address_and_port(client)
|
||||||
|
server = scope.get("server")
|
||||||
|
if server is not None:
|
||||||
|
self.assert_valid_address_and_port(server)
|
||||||
|
# Subprotocols
|
||||||
|
scope_subprotocols = scope.get("subprotocols", [])
|
||||||
|
if scope_subprotocols:
|
||||||
|
assert all(isinstance(x, str) for x in scope_subprotocols)
|
||||||
|
if subprotocols:
|
||||||
|
assert sorted(scope_subprotocols) == sorted(subprotocols)
|
||||||
|
|
||||||
|
def assert_valid_websocket_connect_message(self, message):
|
||||||
|
"""
|
||||||
|
Asserts that a message is a valid http.request message
|
||||||
|
"""
|
||||||
|
# Check overall keys
|
||||||
|
self.assert_key_sets(
|
||||||
|
required_keys={"type"},
|
||||||
|
optional_keys=set(),
|
||||||
|
actual_keys=message.keys(),
|
||||||
|
)
|
||||||
|
# Check that it is the right type
|
||||||
|
self.assertEqual(message["type"], "websocket.connect")
|
||||||
|
|
||||||
|
def test_accept(self):
|
||||||
|
"""
|
||||||
|
Tests we can open and accept a socket.
|
||||||
|
"""
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.accept",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
self.websocket_handshake(test_app)
|
||||||
|
# Validate the scope and messages we got
|
||||||
|
scope, messages = test_app.get_received()
|
||||||
|
self.assert_valid_websocket_scope(scope)
|
||||||
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
|
||||||
|
def test_reject(self):
|
||||||
|
"""
|
||||||
|
Tests we can reject a socket and it won't complete the handshake.
|
||||||
|
"""
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.close",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
self.websocket_handshake(test_app)
|
||||||
|
|
||||||
|
def test_subprotocols(self):
|
||||||
|
"""
|
||||||
|
Tests that we can ask for subprotocols and then select one.
|
||||||
|
"""
|
||||||
|
subprotocols = ["proto1", "proto2"]
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.accept",
|
||||||
|
"subprotocol": "proto2",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
_, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
|
||||||
|
# Validate the scope and messages we got
|
||||||
|
assert subprotocol == "proto2"
|
||||||
|
scope, messages = test_app.get_received()
|
||||||
|
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
|
||||||
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
|
||||||
|
@given(
|
||||||
|
request_path=http_strategies.http_path(),
|
||||||
|
request_params=http_strategies.query_params(),
|
||||||
|
request_headers=http_strategies.headers(),
|
||||||
|
)
|
||||||
|
@settings(max_examples=5, deadline=2000)
|
||||||
|
def test_http_bits(
|
||||||
|
self,
|
||||||
|
request_path,
|
||||||
|
request_params,
|
||||||
|
request_headers,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Tests that various HTTP-level bits (query string params, path, headers)
|
||||||
|
carry over into the scope.
|
||||||
|
"""
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.accept",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
self.websocket_handshake(
|
||||||
|
test_app,
|
||||||
|
path=request_path,
|
||||||
|
params=request_params,
|
||||||
|
headers=request_headers,
|
||||||
|
)
|
||||||
|
# Validate the scope and messages we got
|
||||||
|
scope, messages = test_app.get_received()
|
||||||
|
self.assert_valid_websocket_scope(
|
||||||
|
scope,
|
||||||
|
path=request_path,
|
||||||
|
params=request_params,
|
||||||
|
headers=request_headers,
|
||||||
|
)
|
||||||
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
|
||||||
|
def test_text_frames(self):
|
||||||
|
"""
|
||||||
|
Tests we can send and receive text frames.
|
||||||
|
"""
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
# Connect
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.accept",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
sock, _ = self.websocket_handshake(test_app)
|
||||||
|
_, messages = test_app.get_received()
|
||||||
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
# Prep frame for it to send
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.send",
|
||||||
|
"text": "here be dragons 🐉",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
# Send it a frame
|
||||||
|
self.websocket_send_frame(sock, "what is here? 🌍")
|
||||||
|
# Receive a frame and make sure it's correct
|
||||||
|
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
|
||||||
|
# Make sure it got our frame
|
||||||
|
_, messages = test_app.get_received()
|
||||||
|
assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"}
|
||||||
|
|
||||||
|
def test_binary_frames(self):
|
||||||
|
"""
|
||||||
|
Tests we can send and receive binary frames with things that are very
|
||||||
|
much not valid UTF-8.
|
||||||
|
"""
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
# Connect
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.accept",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
sock, _ = self.websocket_handshake(test_app)
|
||||||
|
_, messages = test_app.get_received()
|
||||||
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
# Prep frame for it to send
|
||||||
|
test_app.add_send_messages([
|
||||||
|
{
|
||||||
|
"type": "websocket.send",
|
||||||
|
"bytes": b"here be \xe2 bytes",
|
||||||
|
}
|
||||||
|
])
|
||||||
|
# Send it a frame
|
||||||
|
self.websocket_send_frame(sock, b"what is here? \xe2")
|
||||||
|
# Receive a frame and make sure it's correct
|
||||||
|
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
|
||||||
|
# Make sure it got our frame
|
||||||
|
_, messages = test_app.get_received()
|
||||||
|
assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}
|
Loading…
Reference in New Issue
Block a user