From 567c27504d4ecb4937aa9a1a20a166494ffad810 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 27 Nov 2017 00:00:34 -0800 Subject: [PATCH] Add websocket tests to make sure everything important is covered. --- daphne/{test_utils.py => test_application.py} | 60 ++- daphne/tests/__init__.py | 0 daphne/tests/factories.py | 128 ------- daphne/tests/test_ws.py | 246 ------------- daphne/ws_protocol.py | 1 - tests/http_base.py | 341 +++++++++++------- tests/test_http_request.py | 2 +- tests/test_websocket.py | 239 ++++++++++++ 8 files changed, 485 insertions(+), 532 deletions(-) rename daphne/{test_utils.py => test_application.py} (56%) delete mode 100644 daphne/tests/__init__.py delete mode 100644 daphne/tests/factories.py delete mode 100644 daphne/tests/test_ws.py create mode 100644 tests/test_websocket.py diff --git a/daphne/test_utils.py b/daphne/test_application.py similarity index 56% rename from daphne/test_utils.py rename to daphne/test_application.py index ebfabbd..cbd2bdb 100644 --- a/daphne/test_utils.py +++ b/daphne/test_application.py @@ -1,3 +1,5 @@ +from concurrent.futures import CancelledError +import logging import os import pickle import tempfile @@ -17,21 +19,29 @@ class TestApplication: self.messages = [] async def __call__(self, send, receive): - # Load setup info - setup = self.load_setup() # Receive input and send output + logging.debug("test app coroutine alive") try: - for _ in range(setup["receive_messages"]): + while True: + # Receive a message and save it into the result store self.messages.append(await receive()) - for message in setup["response_messages"]: - await send(message) + logging.debug("test app received %r", self.messages[-1]) + self.save_result(self.scope, self.messages) + # See if there are any messages to send back + setup = self.load_setup() + self.delete_setup() + for message in setup["response_messages"]: + await send(message) + logging.debug("test app sent %r", message) except Exception as e: - self.save_exception(e) - else: - self.save_result() + if isinstance(e, CancelledError): + # Don't catch task-cancelled errors! + raise + else: + self.save_exception(e) @classmethod - def save_setup(cls, response_messages, receive_messages=1): + def save_setup(cls, response_messages): """ Stores setup information. """ @@ -39,7 +49,6 @@ class TestApplication: pickle.dump( { "response_messages": response_messages, - "receive_messages": receive_messages, }, fh, ) @@ -49,29 +58,34 @@ class TestApplication: """ Returns setup details. """ - with open(cls.setup_storage, "rb") as fh: - return pickle.load(fh) + try: + with open(cls.setup_storage, "rb") as fh: + return pickle.load(fh) + except FileNotFoundError: + return {"response_messages": []} - def save_result(self): + @classmethod + def save_result(cls, scope, messages): """ Saves details of what happened to the result storage. We could use pickle here, but that seems wrong, still, somehow. """ - with open(self.result_storage, "wb") as fh: + with open(cls.result_storage, "wb") as fh: pickle.dump( { - "scope": self.scope, - "messages": self.messages, + "scope": scope, + "messages": messages, }, fh, ) - def save_exception(self, exception): + @classmethod + def save_exception(cls, exception): """ Saves details of what happened to the result storage. We could use pickle here, but that seems wrong, still, somehow. """ - with open(self.result_storage, "wb") as fh: + with open(cls.result_storage, "wb") as fh: pickle.dump( { "exception": exception, @@ -88,14 +102,20 @@ class TestApplication: return pickle.load(fh) @classmethod - def clear_storage(cls): + def delete_setup(cls): """ - Clears storage files. + Clears setup storage files. """ try: os.unlink(cls.setup_storage) except OSError: pass + + @classmethod + def delete_result(cls): + """ + Clears result storage files. + """ try: os.unlink(cls.result_storage) except OSError: diff --git a/daphne/tests/__init__.py b/daphne/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/daphne/tests/factories.py b/daphne/tests/factories.py deleted file mode 100644 index 2b24f16..0000000 --- a/daphne/tests/factories.py +++ /dev/null @@ -1,128 +0,0 @@ -from __future__ import unicode_literals -import six -from six.moves.urllib import parse - -from asgiref.inmemory import ChannelLayer -from twisted.test import proto_helpers - -from daphne.http_protocol import HTTPFactory - - -def message_for_request(method, path, params=None, headers=None, body=None): - """ - Constructs a HTTP request according to the given parameters, runs - that through daphne and returns the emitted channel message. - """ - request = _build_request(method, path, params, headers, body) - message, factory, transport = _run_through_daphne(request, "http.request") - return message - - -def response_for_message(message): - """ - Returns the raw HTTP response that Daphne constructs when sending a reply - to a HTTP request. - - The current approach actually first builds a HTTP request (similar to - message_for_request) because we need a valid reply channel. I'm sure - this can be streamlined, but it works for now. - """ - request = _build_request("GET", "/") - request_message, factory, transport = _run_through_daphne(request, "http.request") - factory.dispatch_reply(request_message["reply_channel"], message) - return transport.value() - - -def _build_request(method, path, params=None, headers=None, body=None): - """ - Takes request parameters and returns a byte string of a valid HTTP/1.1 request. - - We really shouldn't manually build a HTTP request, and instead try to capture - what e.g. urllib or requests would do. But that is non-trivial, so meanwhile - we hope that our request building doesn't mask any errors. - - This code is messy, because urllib behaves rather different between Python 2 - and 3. Readability is further obstructed by the fact that Python 3.4 doesn't - support % formatting for bytes, so we need to concat everything. - If we run into more issues with this, the python-future library has a backport - of Python 3's urllib. - - :param method: ASCII string of HTTP method. - :param path: unicode string of URL path. - :param params: List of two-tuples of bytestrings, ready for consumption for - urlencode. Encode to utf8 if necessary. - :param headers: List of two-tuples ASCII strings of HTTP header, value. - :param body: ASCII string of request body. - - ASCII string is short for a unicode string containing only ASCII characters, - or a byte string with ASCII encoding. - """ - if headers is None: - headers = [] - else: - headers = headers[:] - - if six.PY3: - quoted_path = parse.quote(path) - if params: - quoted_path += "?" + parse.urlencode(params) - quoted_path = quoted_path.encode("ascii") - else: - quoted_path = parse.quote(path.encode("utf8")) - if params: - quoted_path += b"?" + parse.urlencode(params) - - request = method.encode("ascii") + b" " + quoted_path + b" HTTP/1.1\r\n" - for name, value in headers: - request += header_line(name, value) - - request += b"\r\n" - - if body: - request += body.encode("ascii") - - return request - - -def build_websocket_upgrade(path, params, headers): - ws_headers = [ - ("Host", "somewhere.com"), - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), - ("Sec-WebSocket-Protocol", "chat, superchat"), - ("Sec-WebSocket-Version", "13"), - ("Origin", "http://example.com") - ] - return _build_request("GET", path, params, headers=headers + ws_headers, body=None) - - -def header_line(name, value): - """ - Given a header name and value, returns the line to use in a HTTP request or response. - """ - return name.encode("ascii") + b": " + value.encode("ascii") + b"\r\n" - - -def _run_through_daphne(request, channel_name): - """ - Returns Daphne's channel message for a given request. - - This helper requires a fair bit of scaffolding and can certainly be improved, - but it works for now. - """ - channel_layer = ChannelLayer() - factory = HTTPFactory(channel_layer, send_channel="test!") - proto = factory.buildProtocol(("127.0.0.1", 0)) - transport = proto_helpers.StringTransport() - proto.makeConnection(transport) - proto.dataReceived(request) - _, message = channel_layer.receive([channel_name]) - return message, factory, transport - - -def content_length_header(body): - """ - Returns an appropriate Content-Length HTTP header for a given body. - """ - return "Content-Length", six.text_type(len(body)) diff --git a/daphne/tests/test_ws.py b/daphne/tests/test_ws.py deleted file mode 100644 index 311d4e9..0000000 --- a/daphne/tests/test_ws.py +++ /dev/null @@ -1,246 +0,0 @@ -# coding: utf8 -from __future__ import unicode_literals - -from hypothesis import assume, given, strategies, settings -from twisted.test import proto_helpers - -from asgiref.inmemory import ChannelLayer -from daphne.http_protocol import HTTPFactory -from daphne.tests import http_strategies, testcases, factories - - -class WebSocketConnection(object): - """ - Helper class that makes it easier to test Dahpne's WebSocket support. - """ - - def __init__(self): - self.last_message = None - - self.channel_layer = ChannelLayer() - self.factory = HTTPFactory(self.channel_layer, send_channel="test!") - self.proto = self.factory.buildProtocol(("127.0.0.1", 0)) - self.transport = proto_helpers.StringTransport() - self.proto.makeConnection(self.transport) - - def receive(self, request): - """ - Low-level method to let Daphne handle HTTP/WebSocket data - """ - self.proto.dataReceived(request) - _, self.last_message = self.channel_layer.receive(["websocket.connect"]) - return self.last_message - - def send(self, content): - """ - Method to respond with a channel message - """ - if self.last_message is None: - # Auto-connect for convenience. - self.connect() - self.factory.dispatch_reply(self.last_message["reply_channel"], content) - response = self.transport.value() - self.transport.clear() - return response - - def connect(self, path="/", params=None, headers=None): - """ - High-level method to perform the WebSocket handshake - """ - request = factories.build_websocket_upgrade(path, params, headers or []) - message = self.receive(request) - return message - - -class TestHandshake(testcases.ASGIWebSocketTestCase): - """ - Tests for the WebSocket handshake - """ - - def test_minimal(self): - message = WebSocketConnection().connect() - self.assert_valid_websocket_connect_message(message) - - @given( - path=http_strategies.http_path(), - params=http_strategies.query_params(), - headers=http_strategies.headers(), - ) - @settings(perform_health_check=False) - def test_connection(self, path, params, headers): - message = WebSocketConnection().connect(path, params, headers) - self.assert_valid_websocket_connect_message(message, path, params, headers) - - -class TestSendCloseAccept(testcases.ASGIWebSocketTestCase): - """ - Tests that, essentially, try to translate the send/close/accept section of the spec into code. - """ - - def test_empty_accept(self): - response = WebSocketConnection().send({"accept": True}) - self.assert_websocket_upgrade(response) - - @given(text=http_strategies.http_body()) - def test_accept_and_text(self, text): - response = WebSocketConnection().send({"accept": True, "text": text}) - self.assert_websocket_upgrade(response, text.encode("ascii")) - - @given(data=http_strategies.binary_payload()) - def test_accept_and_bytes(self, data): - response = WebSocketConnection().send({"accept": True, "bytes": data}) - self.assert_websocket_upgrade(response, data) - - def test_accept_false(self): - response = WebSocketConnection().send({"accept": False}) - self.assert_websocket_denied(response) - - def test_accept_false_with_text(self): - """ - Tests that even if text is given, the connection is denied. - - We can't easily use Hypothesis to generate data for this test because it's - hard to detect absence of the body if e.g. Hypothesis would generate a 'GET' - """ - text = "foobar" - response = WebSocketConnection().send({"accept": False, "text": text}) - self.assert_websocket_denied(response) - self.assertNotIn(text.encode("ascii"), response) - - def test_accept_false_with_bytes(self): - """ - Tests that even if data is given, the connection is denied. - - We can't easily use Hypothesis to generate data for this test because it's - hard to detect absence of the body if e.g. Hypothesis would generate a 'GET' - """ - data = b"foobar" - response = WebSocketConnection().send({"accept": False, "bytes": data}) - self.assert_websocket_denied(response) - self.assertNotIn(data, response) - - @given(text=http_strategies.http_body()) - def test_just_text(self, text): - assume(len(text) > 0) - # If content is sent, accept=True is implied. - response = WebSocketConnection().send({"text": text}) - self.assert_websocket_upgrade(response, text.encode("ascii")) - - @given(data=http_strategies.binary_payload()) - def test_just_bytes(self, data): - assume(len(data) > 0) - # If content is sent, accept=True is implied. - response = WebSocketConnection().send({"bytes": data}) - self.assert_websocket_upgrade(response, data) - - def test_close_boolean(self): - response = WebSocketConnection().send({"close": True}) - self.assert_websocket_denied(response) - - @given(number=strategies.integers(min_value=1)) - def test_close_integer(self, number): - response = WebSocketConnection().send({"close": number}) - self.assert_websocket_denied(response) - - @given(text=http_strategies.http_body()) - def test_close_with_text(self, text): - assume(len(text) > 0) - response = WebSocketConnection().send({"close": True, "text": text}) - self.assert_websocket_upgrade(response, text.encode("ascii"), expect_close=True) - - @given(data=http_strategies.binary_payload()) - def test_close_with_data(self, data): - assume(len(data) > 0) - response = WebSocketConnection().send({"close": True, "bytes": data}) - self.assert_websocket_upgrade(response, data, expect_close=True) - - -class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase): - """ - Tests that the WebSocket protocol class correctly generates and parses messages. - """ - - def setUp(self): - self.connection = WebSocketConnection() - - def test_basic(self): - # Send a simple request to the protocol and get the resulting message off - # of the channel layer. - message = self.connection.receive( - b"GET /chat HTTP/1.1\r\n" - b"Host: somewhere.com\r\n" - b"Upgrade: websocket\r\n" - b"Connection: Upgrade\r\n" - b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n" - b"Sec-WebSocket-Protocol: chat, superchat\r\n" - b"Sec-WebSocket-Version: 13\r\n" - b"Origin: http://example.com\r\n" - b"\r\n" - ) - self.assertEqual(message["path"], "/chat") - self.assertEqual(message["query_string"], b"") - self.assertEqual( - sorted(message["headers"]), - [(b"connection", b"Upgrade"), - (b"host", b"somewhere.com"), - (b"origin", b"http://example.com"), - (b"sec-websocket-key", b"x3JJHMbDL1EzLkh9GBhXDw=="), - (b"sec-websocket-protocol", b"chat, superchat"), - (b"sec-websocket-version", b"13"), - (b"upgrade", b"websocket")] - ) - self.assert_valid_websocket_connect_message(message, "/chat") - - # Accept the connection - response = self.connection.send({"accept": True}) - self.assert_websocket_upgrade(response) - - # Send some text - response = self.connection.send({"text": "Hello World!"}) - self.assertEqual(response, b"\x81\x0cHello World!") - - # Send some bytes - response = self.connection.send({"bytes": b"\xaa\xbb\xcc\xdd"}) - self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd") - - # Close the connection - response = self.connection.send({"close": True}) - self.assertEqual(response, b"\x88\x02\x03\xe8") - - def test_connection_with_file_origin_is_accepted(self): - message = self.connection.receive( - b"GET /chat HTTP/1.1\r\n" - b"Host: somewhere.com\r\n" - b"Upgrade: websocket\r\n" - b"Connection: Upgrade\r\n" - b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n" - b"Sec-WebSocket-Protocol: chat, superchat\r\n" - b"Sec-WebSocket-Version: 13\r\n" - b"Origin: file://\r\n" - b"\r\n" - ) - self.assertIn((b"origin", b"file://"), message["headers"]) - self.assert_valid_websocket_connect_message(message, "/chat") - - # Accept the connection - response = self.connection.send({"accept": True}) - self.assert_websocket_upgrade(response) - - def test_connection_with_no_origin_is_accepted(self): - message = self.connection.receive( - b"GET /chat HTTP/1.1\r\n" - b"Host: somewhere.com\r\n" - b"Upgrade: websocket\r\n" - b"Connection: Upgrade\r\n" - b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n" - b"Sec-WebSocket-Protocol: chat, superchat\r\n" - b"Sec-WebSocket-Version: 13\r\n" - b"\r\n" - ) - - self.assertNotIn(b"origin", [header_tuple[0] for header_tuple in message["headers"]]) - self.assert_valid_websocket_connect_message(message, "/chat") - - # Accept the connection - response = self.connection.send({"accept": True}) - self.assert_websocket_upgrade(response) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 60d53be..36b38aa 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -73,7 +73,6 @@ class WebSocketProtocol(WebSocketServerProtocol): "client": self.client_addr, "server": self.server_addr, "subprotocols": subprotocols, - "order": 0, }) except: # Exceptions here are not displayed right, just 500. diff --git a/tests/http_base.py b/tests/http_base.py index 95990d5..13267a6 100644 --- a/tests/http_base.py +++ b/tests/http_base.py @@ -1,22 +1,26 @@ -from urllib import parse from http.client import HTTPConnection +from urllib import parse import socket +import struct import subprocess import time import unittest -from daphne.test_utils import TestApplication +from daphne.test_application import TestApplication -class DaphneTestCase(unittest.TestCase): +class DaphneTestingInstance: """ - Base class for Daphne integration test cases. + Launches an instance of Daphne to test against, with an application + object you can read messages from and feed messages to. - Boots up a copy of Daphne on a test port and sends it a request, and - retrieves the response. Uses a custom ASGI application and temporary files - to store/retrieve the request/response messages. + Works as a context manager. """ + def __init__(self, xff=False): + self.xff = xff + self.host = "127.0.0.1" + def port_in_use(self, port): """ Tests if a port is in use on the local machine. @@ -34,39 +38,90 @@ class DaphneTestCase(unittest.TestCase): finally: s.close() - def run_daphne(self, method, path, params, body, responses, headers=None, timeout=1, xff=False): + def find_free_port(self): + """ + Finds an unused port to test stuff on + """ + for i in range(11200, 11300): + if not self.port_in_use(i): + return i + raise RuntimeError("Cannot find a free port to test on") + + def __enter__(self): + # Clear result storage + TestApplication.delete_setup() + TestApplication.delete_result() + # Find a port to listen on + self.port = self.find_free_port() + daphne_args = ["daphne", "-p", str(self.port), "-v", "0"] + # Optionally enable X-Forwarded-For support. + if self.xff: + daphne_args += ["--proxy-headers"] + # Start up process and make sure it begins listening. + self.process = subprocess.Popen(daphne_args + ["daphne.test_application:TestApplication"]) + for _ in range(100): + time.sleep(0.1) + if self.port_in_use(self.port): + return self + # Daphne didn't start up. Sadface. + self.process.terminate() + raise RuntimeError("Daphne never came up.") + + def __exit__(self, exc_type, exc_value, traceback): + # Shut down the process + self.process.terminate() + del self.process + + def get_received(self): + """ + Returns the scope and messages the test application has received + so far. Note you'll get all messages since scope start, not just any + new ones since the last call. + + Also checks for any exceptions in the application. If there are, + raises them. + """ + try: + inner_result = TestApplication.load_result() + except FileNotFoundError: + raise ValueError("No results available yet.") + # Check for exception + if "exception" in inner_result: + raise inner_result["exception"] + return inner_result["scope"], inner_result["messages"] + + def add_send_messages(self, messages): + """ + Adds messages for the application to send back. + The next time it receives an incoming message, it will reply with these. + """ + TestApplication.save_setup( + response_messages=messages, + ) + + +class DaphneTestCase(unittest.TestCase): + """ + Base class for Daphne integration test cases. + + Boots up a copy of Daphne on a test port and sends it a request, and + retrieves the response. Uses a custom ASGI application and temporary files + to store/retrieve the request/response messages. + """ + + ### Plain HTTP helpers + + def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False): """ Runs Daphne with the given request callback (given the base URL) and response messages. """ - # Store setup info - TestApplication.clear_storage() - TestApplication.save_setup( - response_messages=responses, - ) - # Find a free port - for i in range(11200, 11300): - if not self.port_in_use(i): - port = i - break - else: - raise RuntimeError("Cannot find a free port to test on") - # Launch daphne on that port - daphne_args = ["daphne", "-p", str(port), "-v", "0"] - if xff: - # Optionally enable X-Forwarded-For support. - daphne_args += ["--proxy-headers"] - process = subprocess.Popen(daphne_args + ["daphne.test_utils:TestApplication"]) - try: - for _ in range(100): - time.sleep(0.1) - if self.port_in_use(port): - break - else: - raise RuntimeError("Daphne never came up.") + with DaphneTestingInstance(xff=xff) as test_app: + # Add the response messages + test_app.add_send_messages(responses) # Send it the request. We have to do this the long way to allow # duplicate headers. - conn = HTTPConnection("127.0.0.1", port, timeout=timeout) + conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout) # Make sure path is urlquoted and add any params path = parse.quote(path) if params: @@ -86,29 +141,17 @@ class DaphneTestCase(unittest.TestCase): response = conn.getresponse() except socket.timeout: # See if they left an exception for us to load - try: - exception_result = TestApplication.load_result() - except OSError: - raise RuntimeError("Daphne timed out handling request, no result file") - else: - if "exception" in exception_result: - raise exception_result["exception"] - else: - raise RuntimeError("Daphne timed out handling request, no exception found: %r" % exception_result) - finally: - # Shut down daphne - process.terminate() - # Load the information - inner_result = TestApplication.load_result() - # Return the inner result and the response - return inner_result, response + test_app.get_received() + raise RuntimeError("Daphne timed out handling request, no exception found.") + # Return scope, messages, response + return test_app.get_received() + (response, ) def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False): """ Convenience method for just testing request handling. Returns (scope, messages) """ - inner_result, _ = self.run_daphne( + scope, messages, _ = self.run_daphne_http( method=method, path=path, params=params, @@ -117,14 +160,14 @@ class DaphneTestCase(unittest.TestCase): xff=xff, responses=[{"type": "http.response", "status": 200, "content": b"OK"}], ) - return inner_result["scope"], inner_result["messages"] + return scope, messages def run_daphne_response(self, response_messages): """ Convenience method for just testing response handling. Returns (scope, messages) """ - _, response = self.run_daphne( + _, _, response = self.run_daphne_http( method="GET", path="/", params={}, @@ -133,11 +176,119 @@ class DaphneTestCase(unittest.TestCase): ) return response + ### WebSocket helpers + + def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1): + """ + Runs a WebSocket handshake negotiation and returns the raw socket + object & the selected subprotocol. + + You'll need to inject an accept or reject message before this + to let it complete. + """ + # Send it the request. We have to do this the long way to allow + # duplicate headers. + conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout) + # Make sure path is urlquoted and add any params + path = parse.quote(path) + if params: + path += "?" + parse.urlencode(params, doseq=True) + conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True) + # Do WebSocket handshake headers + any other headers + if headers is None: + headers = [] + headers.extend([ + ("Host", "example.com"), + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), + ("Sec-WebSocket-Version", "13"), + ("Origin", "http://example.com") + ]) + if subprotocols: + headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols))) + if headers: + for header_name, header_value in headers: + conn.putheader(header_name.encode("utf8"), header_value.encode("utf8")) + conn.endheaders() + # Read out the response + try: + response = conn.getresponse() + except socket.timeout: + # See if they left an exception for us to load + test_app.get_received() + raise RuntimeError("Daphne timed out handling request, no exception found.") + # Check we got a good response code + if response.status != 101: + raise RuntimeError("WebSocket upgrade did not result in status code 101") + # Prepare headers for subprotocol searching + response_headers = dict( + (n.lower(), v) + for n, v in response.getheaders() + ) + response.read() + assert not response.closed + # Return the raw socket and any subprotocol + return conn.sock, response_headers.get("sec-websocket-protocol", None) + + def websocket_send_frame(self, sock, value): + """ + Sends a WebSocket text or binary frame. Cannot handle long frames. + """ + # Header and text opcode + if isinstance(value, str): + frame = b"\x81" + value = value.encode("utf8") + else: + frame = b"\x82" + # Length plus masking signal bit + frame += struct.pack("!B", len(value) | 0b10000000) + # Mask badly + frame += b"\0\0\0\0" + # Payload + frame += value + print("sending %r" % frame) + sock.sendall(frame) + + def receive_from_socket(self, sock, length, timeout=1): + """ + Receives the given amount of bytes from the socket, or times out. + """ + buf = b"" + started = time.time() + while len(buf) < length: + buf += sock.recv(length - len(buf)) + time.sleep(0.001) + if time.time() - started > timeout: + raise ValueError("Timed out reading from socket") + return buf + + def websocket_receive_frame(self, sock): + """ + Receives a WebSocket frame. Cannot handle long frames. + """ + # Read header byte + # TODO: Proper receive buffer handling + opcode = self.receive_from_socket(sock, 1) + if opcode in [b"\x81", b"\x82"]: + # Read length + length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0] + # Read payload + payload = self.receive_from_socket(sock, length) + if opcode == b"\x81": + payload = payload.decode("utf8") + return payload + else: + raise ValueError("Unknown websocket opcode: %r" % opcode) + + ### Assertions and test management + def tearDown(self): """ Ensures any storage files are cleared. """ - TestApplication.clear_storage() + TestApplication.delete_setup() + TestApplication.delete_result() def assert_is_ip_address(self, address): """ @@ -179,85 +330,3 @@ class DaphneTestCase(unittest.TestCase): self.assertIsInstance(address, str) self.assert_is_ip_address(address) self.assertIsInstance(port, int) - - - -# class ASGIWebSocketTestCase(ASGITestCaseBase): -# """ -# Test case with helpers for verifying WebSocket channel messages -# """ - -# def assert_websocket_upgrade(self, response, body=b"", expect_close=False): -# self.assertIn(b"HTTP/1.1 101 Switching Protocols", response) -# self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) -# self.assertIn(body, response) -# self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8")) - -# def assert_websocket_denied(self, response): -# self.assertIn(b"HTTP/1.1 403", response) - -# def assert_valid_websocket_connect_message( -# self, channel_message, request_path="/", request_params=None, request_headers=None): -# """ -# Asserts that a given channel message conforms to the HTTP request section of the ASGI spec. -# """ - -# self.assertTrue(channel_message) - -# self.assert_presence_of_message_keys( -# channel_message.keys(), -# {"reply_channel", "path", "headers", "order"}, -# {"scheme", "query_string", "root_path", "client", "server"}) - -# # == Assertions about required channel_message fields == -# self.assert_valid_reply_channel(channel_message["reply_channel"]) -# self.assert_valid_path(channel_message["path"], request_path) - -# order = channel_message["order"] -# self.assertIsInstance(order, int) -# self.assertEqual(order, 0) - -# # Ordering of header names is not important, but the order of values for a header -# # name is. To assert whether that order is kept, we transform the request -# # headers and the channel message headers into a set -# # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal. -# # Note that unlike for HTTP, Daphne never gives out individual header values; instead we -# # get one string per header field with values separated by comma. -# transformed_request_headers = defaultdict(list) -# for name, value in (request_headers or []): -# expected_name = name.lower().strip().encode("ascii") -# expected_value = value.strip().encode("ascii") -# transformed_request_headers[expected_name].append(expected_value) -# final_request_headers = { -# (name, b",".join(value)) for name, value in transformed_request_headers.items() -# } - -# # Websockets carry a lot of additional header fields, so instead of verifying that -# # headers look exactly like expected, we just check that the expected header fields -# # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed -# # and not tested for. -# assert final_request_headers.issubset(set(channel_message["headers"])) - -# # == Assertions about optional channel_message fields == -# scheme = channel_message.get("scheme") -# if scheme: -# self.assertIsInstance(scheme, six.text_type) -# self.assertIn(scheme, ["ws", "wss"]) - -# query_string = channel_message.get("query_string") -# if query_string: -# # Assert that query_string is a byte string and still url encoded -# self.assertIsInstance(query_string, six.binary_type) -# self.assertEqual(query_string, parse.urlencode(request_params or []).encode("ascii")) - -# root_path = channel_message.get("root_path") -# if root_path is not None: -# self.assertIsInstance(root_path, six.text_type) - -# client = channel_message.get("client") -# if client is not None: -# self.assert_valid_address_and_port(channel_message["client"]) - -# server = channel_message.get("server") -# if server is not None: -# self.assert_valid_address_and_port(channel_message["server"]) diff --git a/tests/test_http_request.py b/tests/test_http_request.py index 0b79a87..5dea29b 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -207,7 +207,7 @@ class TestHTTPRequestSpec(DaphneTestCase): self.assert_valid_http_request_message(messages[0], body=b"") # Note that Daphne returns a list of tuples here, which is fine, because the spec # asks to treat them interchangeably. - assert scope["headers"] == [[b"mycustomheader", b"foobar"]] + assert [list(x) for x in scope["headers"]] == [[b"mycustomheader", b"foobar"]] @given(daphne_path=http_strategies.http_path()) @settings(max_examples=5, deadline=2000) diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 0000000..bf2af70 --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,239 @@ +# coding: utf8 + +import collections +from urllib import parse + +from hypothesis import given, settings + +import http_strategies +from http_base import DaphneTestCase, DaphneTestingInstance + + +class TestWebsocket(DaphneTestCase): + """ + Tests which try to pour the HTTP request section of the ASGI spec into code. + The heavy lifting is done by the assert_valid_http_request_message function, + the tests mostly serve to wire up hypothesis so that it exercise it's power to find + edge cases. + """ + + def assert_valid_websocket_scope( + self, + scope, + path="/", + params=None, + headers=None, + scheme=None, + subprotocols=None, + ): + """ + Checks that the passed scope is a valid ASGI HTTP scope regarding types + and some urlencoding things. + """ + # Check overall keys + self.assert_key_sets( + required_keys={"type", "path", "query_string", "headers"}, + optional_keys={"scheme", "root_path", "client", "server", "subprotocols"}, + actual_keys=scope.keys(), + ) + # Check that it is the right type + self.assertEqual(scope["type"], "websocket") + # Path + self.assert_valid_path(scope["path"], path) + # Scheme + self.assertIn(scope.get("scheme", "ws"), ["ws", "wss"]) + if scheme: + self.assertEqual(scheme, scope["scheme"]) + # Query string (byte string and still url encoded) + query_string = scope["query_string"] + self.assertIsInstance(query_string, bytes) + if params: + self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) + # Ordering of header names is not important, but the order of values for a header + # name is. To assert whether that order is kept, we transform both the request + # headers and the channel message headers into a dictionary + # {name: [value1, value2, ...]} and check if they're equal. + transformed_scope_headers = collections.defaultdict(list) + for name, value in scope["headers"]: + transformed_scope_headers[name].append(value) + transformed_request_headers = collections.defaultdict(list) + for name, value in (headers or []): + expected_name = name.lower().strip().encode("ascii") + expected_value = value.strip().encode("ascii") + transformed_request_headers[expected_name].append(expected_value) + for name, value in transformed_request_headers.items(): + self.assertIn(name, transformed_scope_headers) + self.assertEqual(value, transformed_scope_headers[name]) + # Root path + self.assertIsInstance(scope.get("root_path", ""), str) + # Client and server addresses + client = scope.get("client") + if client is not None: + self.assert_valid_address_and_port(client) + server = scope.get("server") + if server is not None: + self.assert_valid_address_and_port(server) + # Subprotocols + scope_subprotocols = scope.get("subprotocols", []) + if scope_subprotocols: + assert all(isinstance(x, str) for x in scope_subprotocols) + if subprotocols: + assert sorted(scope_subprotocols) == sorted(subprotocols) + + def assert_valid_websocket_connect_message(self, message): + """ + Asserts that a message is a valid http.request message + """ + # Check overall keys + self.assert_key_sets( + required_keys={"type"}, + optional_keys=set(), + actual_keys=message.keys(), + ) + # Check that it is the right type + self.assertEqual(message["type"], "websocket.connect") + + def test_accept(self): + """ + Tests we can open and accept a socket. + """ + with DaphneTestingInstance() as test_app: + test_app.add_send_messages([ + { + "type": "websocket.accept", + } + ]) + self.websocket_handshake(test_app) + # Validate the scope and messages we got + scope, messages = test_app.get_received() + self.assert_valid_websocket_scope(scope) + self.assert_valid_websocket_connect_message(messages[0]) + + def test_reject(self): + """ + Tests we can reject a socket and it won't complete the handshake. + """ + with DaphneTestingInstance() as test_app: + test_app.add_send_messages([ + { + "type": "websocket.close", + } + ]) + with self.assertRaises(RuntimeError): + self.websocket_handshake(test_app) + + def test_subprotocols(self): + """ + Tests that we can ask for subprotocols and then select one. + """ + subprotocols = ["proto1", "proto2"] + with DaphneTestingInstance() as test_app: + test_app.add_send_messages([ + { + "type": "websocket.accept", + "subprotocol": "proto2", + } + ]) + _, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols) + # Validate the scope and messages we got + assert subprotocol == "proto2" + scope, messages = test_app.get_received() + self.assert_valid_websocket_scope(scope, subprotocols=subprotocols) + self.assert_valid_websocket_connect_message(messages[0]) + + @given( + request_path=http_strategies.http_path(), + request_params=http_strategies.query_params(), + request_headers=http_strategies.headers(), + ) + @settings(max_examples=5, deadline=2000) + def test_http_bits( + self, + request_path, + request_params, + request_headers, + ): + """ + Tests that various HTTP-level bits (query string params, path, headers) + carry over into the scope. + """ + with DaphneTestingInstance() as test_app: + test_app.add_send_messages([ + { + "type": "websocket.accept", + } + ]) + self.websocket_handshake( + test_app, + path=request_path, + params=request_params, + headers=request_headers, + ) + # Validate the scope and messages we got + scope, messages = test_app.get_received() + self.assert_valid_websocket_scope( + scope, + path=request_path, + params=request_params, + headers=request_headers, + ) + self.assert_valid_websocket_connect_message(messages[0]) + + def test_text_frames(self): + """ + Tests we can send and receive text frames. + """ + with DaphneTestingInstance() as test_app: + # Connect + test_app.add_send_messages([ + { + "type": "websocket.accept", + } + ]) + sock, _ = self.websocket_handshake(test_app) + _, messages = test_app.get_received() + self.assert_valid_websocket_connect_message(messages[0]) + # Prep frame for it to send + test_app.add_send_messages([ + { + "type": "websocket.send", + "text": "here be dragons 🐉", + } + ]) + # Send it a frame + self.websocket_send_frame(sock, "what is here? 🌍") + # Receive a frame and make sure it's correct + assert self.websocket_receive_frame(sock) == "here be dragons 🐉" + # Make sure it got our frame + _, messages = test_app.get_received() + assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"} + + def test_binary_frames(self): + """ + Tests we can send and receive binary frames with things that are very + much not valid UTF-8. + """ + with DaphneTestingInstance() as test_app: + # Connect + test_app.add_send_messages([ + { + "type": "websocket.accept", + } + ]) + sock, _ = self.websocket_handshake(test_app) + _, messages = test_app.get_received() + self.assert_valid_websocket_connect_message(messages[0]) + # Prep frame for it to send + test_app.add_send_messages([ + { + "type": "websocket.send", + "bytes": b"here be \xe2 bytes", + } + ]) + # Send it a frame + self.websocket_send_frame(sock, b"what is here? \xe2") + # Receive a frame and make sure it's correct + assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes" + # Make sure it got our frame + _, messages = test_app.get_received() + assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}