From 2bcec3fe94db1c78f1df1f65f5273032ac0cd48b Mon Sep 17 00:00:00 2001 From: Maik Hoepfel Date: Fri, 28 Apr 2017 23:45:07 +0200 Subject: [PATCH] Websockets test and unicode fix for Python 2 (#111) * Python 2 fix for host address This is a copy of https://github.com/django/daphne/pull/91/commits/57051a48cd485c2dbb4a4c09d8c47f294ba75f06 for the Websocket protocol. In Python 2, Twisted returns a byte string for the host address, while the spec requires a unicode string. A simple cast gives us consistency. * Test suite for websocket tests This commit * introduces some new helpers to test the Websocket protocol * renames the old ASGITestCase class to ASGIHTTPTestCase, and introduces a test case for testing Websockets * moves some helper methods that are shared between HTTP and Websockets into a mutual base class * uses the new helpers to simplfiy the existing tests * and adds a couple new tests. --- daphne/tests/factories.py | 13 ++ daphne/tests/http_strategies.py | 13 +- daphne/tests/test_http_request.py | 2 +- daphne/tests/test_http_response.py | 2 +- daphne/tests/test_ws.py | 239 ++++++++++++++++++++--------- daphne/tests/testcases.py | 153 ++++++++++++++---- daphne/ws_protocol.py | 8 +- 7 files changed, 321 insertions(+), 109 deletions(-) diff --git a/daphne/tests/factories.py b/daphne/tests/factories.py index 11ad298..8f260f1 100644 --- a/daphne/tests/factories.py +++ b/daphne/tests/factories.py @@ -84,6 +84,19 @@ def _build_request(method, path, params=None, headers=None, body=None): return request +def build_websocket_upgrade(path, params, headers): + ws_headers = [ + ('Host', 'somewhere.com'), + ('Upgrade', 'websocket'), + ('Connection', 'Upgrade'), + ('Sec-WebSocket-Key', 'x3JJHMbDL1EzLkh9GBhXDw=='), + ('Sec-WebSocket-Protocol', 'chat, superchat'), + ('Sec-WebSocket-Version', '13'), + ('Origin', 'http://example.com') + ] + return _build_request('GET', path, params, headers=headers + ws_headers, body=None) + + def header_line(name, value): """ Given a header name and value, returns the line to use in a HTTP request or response. diff --git a/daphne/tests/http_strategies.py b/daphne/tests/http_strategies.py index ec9b9cd..85f984b 100644 --- a/daphne/tests/http_strategies.py +++ b/daphne/tests/http_strategies.py @@ -17,12 +17,17 @@ def http_method(): return strategies.sampled_from(HTTP_METHODS) +def _http_path_portion(): + alphabet = string.ascii_letters + string.digits + '-._~' + return strategies.text(min_size=1, average_size=10, max_size=128, alphabet=alphabet) + + def http_path(): """ Returns a URL path (not encoded). """ - alphabet = string.ascii_letters + string.digits + '-._~/' - return strategies.text(min_size=0, max_size=255, alphabet=alphabet).map(lambda s: '/' + s) + return strategies.lists( + _http_path_portion(), min_size=0, max_size=10).map(lambda s: '/' + '/'.join(s)) def http_body(): @@ -33,6 +38,10 @@ def http_body(): return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500) +def binary_payload(): + return strategies.binary(min_size=0, average_size=600, max_size=1500) + + def valid_bidi(value): """ Rejects strings which nonsensical Unicode text direction flags. diff --git a/daphne/tests/test_http_request.py b/daphne/tests/test_http_request.py index e10553e..ee38ab9 100644 --- a/daphne/tests/test_http_request.py +++ b/daphne/tests/test_http_request.py @@ -16,7 +16,7 @@ from daphne.tests import testcases, http_strategies from daphne.tests.factories import message_for_request, content_length_header -class TestHTTPRequestSpec(testcases.ASGITestCase): +class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase): """ Tests which try to pour the HTTP request section of the ASGI spec into code. The heavy lifting is done by the assert_valid_http_request_message function, diff --git a/daphne/tests/test_http_response.py b/daphne/tests/test_http_response.py index d1433d2..a650deb 100644 --- a/daphne/tests/test_http_response.py +++ b/daphne/tests/test_http_response.py @@ -14,7 +14,7 @@ from daphne.http_protocol import HTTPFactory from . import factories, http_strategies, testcases -class TestHTTPResponseSpec(testcases.ASGITestCase): +class TestHTTPResponseSpec(testcases.ASGIHTTPTestCase): def test_minimal_response(self): """ diff --git a/daphne/tests/test_ws.py b/daphne/tests/test_ws.py index 6acb529..c611a95 100644 --- a/daphne/tests/test_ws.py +++ b/daphne/tests/test_ws.py @@ -1,27 +1,171 @@ # coding: utf8 from __future__ import unicode_literals -from unittest import TestCase -from asgiref.inmemory import ChannelLayer + +from hypothesis import assume, given, strategies from twisted.test import proto_helpers +from asgiref.inmemory import ChannelLayer from daphne.http_protocol import HTTPFactory +from daphne.tests import http_strategies, testcases, factories -class TestWebSocketProtocol(TestCase): +class WebSocketConnection(object): """ - Tests that the WebSocket protocol class correcly generates and parses messages. + Helper class that makes it easier to test Dahpne's WebSocket support. """ - def setUp(self): + def __init__(self): + self.last_message = None + self.channel_layer = ChannelLayer() self.factory = HTTPFactory(self.channel_layer, send_channel="test!") self.proto = self.factory.buildProtocol(('127.0.0.1', 0)) - self.tr = proto_helpers.StringTransport() - self.proto.makeConnection(self.tr) + self.transport = proto_helpers.StringTransport() + self.proto.makeConnection(self.transport) + + def receive(self, request): + """ + Low-level method to let Daphne handle HTTP/WebSocket data + """ + self.proto.dataReceived(request) + _, self.last_message = self.channel_layer.receive(['websocket.connect']) + return self.last_message + + def send(self, content): + """ + Method to respond with a channel message + """ + if self.last_message is None: + # Auto-connect for convenience. + self.connect() + self.factory.dispatch_reply(self.last_message['reply_channel'], content) + response = self.transport.value() + self.transport.clear() + return response + + def connect(self, path='/', params=None, headers=None): + """ + High-level method to perform the WebSocket handshake + """ + request = factories.build_websocket_upgrade(path, params, headers or []) + message = self.receive(request) + return message + + +class TestHandshake(testcases.ASGIWebSocketTestCase): + """ + Tests for the WebSocket handshake + """ + + def test_minimal(self): + message = WebSocketConnection().connect() + self.assert_valid_websocket_connect_message(message) + + @given( + path=http_strategies.http_path(), + params=http_strategies.query_params(), + headers=http_strategies.headers(), + ) + def test_connection(self, path, params, headers): + message = WebSocketConnection().connect(path, params, headers) + self.assert_valid_websocket_connect_message(message, path, params, headers) + + +class TestSendCloseAccept(testcases.ASGIWebSocketTestCase): + """ + Tests that, essentially, try to translate the send/close/accept section of the spec into code. + """ + + def test_empty_accept(self): + response = WebSocketConnection().send({'accept': True}) + self.assert_websocket_upgrade(response) + + @given(text=http_strategies.http_body()) + def test_accept_and_text(self, text): + response = WebSocketConnection().send({'accept': True, 'text': text}) + self.assert_websocket_upgrade(response, text.encode('ascii')) + + @given(data=http_strategies.binary_payload()) + def test_accept_and_bytes(self, data): + response = WebSocketConnection().send({'accept': True, 'bytes': data}) + self.assert_websocket_upgrade(response, data) + + def test_accept_false(self): + response = WebSocketConnection().send({'accept': False}) + self.assert_websocket_denied(response) + + def test_accept_false_with_text(self): + """ + Tests that even if text is given, the connection is denied. + + We can't easily use Hypothesis to generate data for this test because it's + hard to detect absence of the body if e.g. Hypothesis would generate a 'GET' + """ + text = 'foobar' + response = WebSocketConnection().send({'accept': False, 'text': text}) + self.assert_websocket_denied(response) + self.assertNotIn(text.encode('ascii'), response) + + def test_accept_false_with_bytes(self): + """ + Tests that even if data is given, the connection is denied. + + We can't easily use Hypothesis to generate data for this test because it's + hard to detect absence of the body if e.g. Hypothesis would generate a 'GET' + """ + data = b'foobar' + response = WebSocketConnection().send({'accept': False, 'bytes': data}) + self.assert_websocket_denied(response) + self.assertNotIn(data, response) + + @given(text=http_strategies.http_body()) + def test_just_text(self, text): + assume(len(text) > 0) + # If content is sent, accept=True is implied. + response = WebSocketConnection().send({'text': text}) + self.assert_websocket_upgrade(response, text.encode('ascii')) + + @given(data=http_strategies.binary_payload()) + def test_just_bytes(self, data): + assume(len(data) > 0) + # If content is sent, accept=True is implied. + response = WebSocketConnection().send({'bytes': data}) + self.assert_websocket_upgrade(response, data) + + def test_close_boolean(self): + response = WebSocketConnection().send({'close': True}) + self.assert_websocket_denied(response) + + @given(number=strategies.integers(min_value=1)) + def test_close_integer(self, number): + response = WebSocketConnection().send({'close': number}) + self.assert_websocket_denied(response) + + @given(text=http_strategies.http_body()) + def test_close_with_text(self, text): + assume(len(text) > 0) + response = WebSocketConnection().send({'close': True, 'text': text}) + self.assert_websocket_upgrade(response, text.encode('ascii'), expect_close=True) + + @given(data=http_strategies.binary_payload()) + def test_close_with_data(self, data): + assume(len(data) > 0) + response = WebSocketConnection().send({'close': True, 'bytes': data}) + self.assert_websocket_upgrade(response, data, expect_close=True) + + +class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase): + """ + Tests that the WebSocket protocol class correctly generates and parses messages. + """ + + def setUp(self): + self.connection = WebSocketConnection() def test_basic(self): - # Send a simple request to the protocol - self.proto.dataReceived( + # Send a simple request to the protocol and get the resulting message off + # of the channel layer. + message = self.connection.receive( b"GET /chat HTTP/1.1\r\n" b"Host: somewhere.com\r\n" b"Upgrade: websocket\r\n" @@ -32,8 +176,6 @@ class TestWebSocketProtocol(TestCase): b"Origin: http://example.com\r\n" b"\r\n" ) - # Get the resulting message off of the channel layer - _, message = self.channel_layer.receive(["websocket.connect"]) self.assertEqual(message['path'], "/chat") self.assertEqual(message['query_string'], "") self.assertEqual( @@ -46,53 +188,26 @@ class TestWebSocketProtocol(TestCase): (b'sec-websocket-version', b'13'), (b'upgrade', b'websocket')] ) - self.assertTrue(message['reply_channel'].startswith("test!")) + self.assert_valid_websocket_connect_message(message, '/chat') # Accept the connection - self.factory.dispatch_reply( - message['reply_channel'], - {'accept': True} - ) - - # Make sure that we get a 101 Switching Protocols back - response = self.tr.value() - self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response) - self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) - self.tr.clear() + response = self.connection.send({'accept': True}) + self.assert_websocket_upgrade(response) # Send some text - self.factory.dispatch_reply( - message['reply_channel'], - {'text': "Hello World!"} - ) - - response = self.tr.value() + response = self.connection.send({'text': "Hello World!"}) self.assertEqual(response, b"\x81\x0cHello World!") - self.tr.clear() # Send some bytes - self.factory.dispatch_reply( - message['reply_channel'], - {'bytes': b"\xaa\xbb\xcc\xdd"} - ) - - response = self.tr.value() + response = self.connection.send({'bytes': b"\xaa\xbb\xcc\xdd"}) self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd") - self.tr.clear() # Close the connection - self.factory.dispatch_reply( - message['reply_channel'], - {'close': True} - ) - - response = self.tr.value() + response = self.connection.send({'close': True}) self.assertEqual(response, b"\x88\x02\x03\xe8") - self.tr.clear() def test_connection_with_file_origin_is_accepted(self): - # Send a simple request to the protocol - self.proto.dataReceived( + message = self.connection.receive( b"GET /chat HTTP/1.1\r\n" b"Host: somewhere.com\r\n" b"Upgrade: websocket\r\n" @@ -103,26 +218,15 @@ class TestWebSocketProtocol(TestCase): b"Origin: file://\r\n" b"\r\n" ) - - # Get the resulting message off of the channel layer - _, message = self.channel_layer.receive(["websocket.connect"]) self.assertIn((b'origin', b'file://'), message['headers']) - self.assertTrue(message['reply_channel'].startswith("test!")) + self.assert_valid_websocket_connect_message(message, '/chat') # Accept the connection - self.factory.dispatch_reply( - message['reply_channel'], - {'accept': True} - ) - - # Make sure that we get a 101 Switching Protocols back - response = self.tr.value() - self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response) - self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) + response = self.connection.send({'accept': True}) + self.assert_websocket_upgrade(response) def test_connection_with_no_origin_is_accepted(self): - # Send a simple request to the protocol - self.proto.dataReceived( + message = self.connection.receive( b"GET /chat HTTP/1.1\r\n" b"Host: somewhere.com\r\n" b"Upgrade: websocket\r\n" @@ -133,18 +237,9 @@ class TestWebSocketProtocol(TestCase): b"\r\n" ) - # Get the resulting message off of the channel layer - _, message = self.channel_layer.receive(["websocket.connect"]) self.assertNotIn(b'origin', [header_tuple[0] for header_tuple in message['headers']]) - self.assertTrue(message['reply_channel'].startswith("test!")) + self.assert_valid_websocket_connect_message(message, '/chat') # Accept the connection - self.factory.dispatch_reply( - message['reply_channel'], - {'accept': True} - ) - - # Make sure that we get a 101 Switching Protocols back - response = self.tr.value() - self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response) - self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) + response = self.connection.send({'accept': True}) + self.assert_websocket_upgrade(response) diff --git a/daphne/tests/testcases.py b/daphne/tests/testcases.py index f1a372d..9e78795 100644 --- a/daphne/tests/testcases.py +++ b/daphne/tests/testcases.py @@ -12,9 +12,9 @@ import unittest from . import factories -class ASGITestCase(unittest.TestCase): +class ASGITestCaseBase(unittest.TestCase): """ - Test case with helpers for ASGI message verification + Base class for our test classes which contains shared method. """ def assert_is_ip_address(self, address): @@ -26,6 +26,35 @@ class ASGITestCase(unittest.TestCase): except socket.error: self.fail("'%s' is not a valid IP address." % address) + def assert_presence_of_message_keys(self, keys, required_keys, optional_keys): + present_keys = set(keys) + self.assertTrue(required_keys <= present_keys) + # Assert that no other keys are present + self.assertEqual(set(), present_keys - required_keys - optional_keys) + + def assert_valid_reply_channel(self, reply_channel): + self.assertIsInstance(reply_channel, six.text_type) + # The reply channel is decided by the server. + self.assertTrue(reply_channel.startswith('test!')) + + def assert_valid_path(self, path, request_path): + self.assertIsInstance(path, six.text_type) + self.assertEqual(path, request_path) + # Assert that it's already url decoded + self.assertEqual(path, parse.unquote(path)) + + def assert_valid_address_and_port(self, host): + address, port = host + self.assertIsInstance(address, six.text_type) + self.assert_is_ip_address(address) + self.assertIsInstance(port, int) + + +class ASGIHTTPTestCase(ASGITestCaseBase): + """ + Test case with helpers for verifying HTTP channel messages + """ + def assert_valid_http_request_message( self, channel_message, request_method, request_path, request_params=None, request_headers=None, request_body=None): @@ -35,22 +64,14 @@ class ASGITestCase(unittest.TestCase): self.assertTrue(channel_message) - # == General assertions about expected dictionary keys being present == - message_keys = set(channel_message.keys()) - required_message_keys = { - 'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers', - } - optional_message_keys = { - 'scheme', 'root_path', 'body', 'body_channel', 'client', 'server' - } - self.assertTrue(required_message_keys <= message_keys) - # Assert that no other keys are present - self.assertEqual(set(), message_keys - required_message_keys - optional_message_keys) + self.assert_presence_of_message_keys( + channel_message.keys(), + {'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers'}, + {'scheme', 'root_path', 'body', 'body_channel', 'client', 'server'}) # == Assertions about required channel_message fields == - reply_channel = channel_message['reply_channel'] - self.assertIsInstance(reply_channel, six.text_type) - self.assertTrue(reply_channel.startswith('test!')) + self.assert_valid_reply_channel(channel_message['reply_channel']) + self.assert_valid_path(channel_message['path'], request_path) http_version = channel_message['http_version'] self.assertIsInstance(http_version, six.text_type) @@ -61,12 +82,6 @@ class ASGITestCase(unittest.TestCase): self.assertTrue(method.isupper()) self.assertEqual(channel_message['method'], request_method) - path = channel_message['path'] - self.assertIsInstance(path, six.text_type) - self.assertEqual(path, request_path) - # Assert that it's already url decoded - self.assertEqual(path, parse.unquote(path)) - query_string = channel_message['query_string'] # Assert that query_string is a byte string and still url encoded self.assertIsInstance(query_string, six.binary_type) @@ -112,17 +127,11 @@ class ASGITestCase(unittest.TestCase): client = channel_message.get('client') if client is not None: - client_host, client_port = client - self.assertIsInstance(client_host, six.text_type) - self.assert_is_ip_address(client_host) - self.assertIsInstance(client_port, int) + self.assert_valid_address_and_port(channel_message['client']) server = channel_message.get('server') if server is not None: - server_host, server_port = channel_message['server'] - self.assertIsInstance(server_host, six.text_type) - self.assert_is_ip_address(server_host) - self.assertIsInstance(server_port, int) + self.assert_valid_address_and_port(channel_message['server']) def assert_valid_http_response_message(self, message, response): self.assertTrue(message) @@ -147,3 +156,87 @@ class ASGITestCase(unittest.TestCase): # altered casing. The approach below does this well enough. self.assertIn(expected_header.lower(), response.lower()) self.assertIn(value.encode('ascii'), response) + + +class ASGIWebSocketTestCase(ASGITestCaseBase): + """ + Test case with helpers for verifying WebSocket channel messages + """ + + def assert_websocket_upgrade(self, response, body=b'', expect_close=False): + self.assertIn(b"HTTP/1.1 101 Switching Protocols", response) + self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) + self.assertIn(body, response) + self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8")) + + def assert_websocket_denied(self, response): + self.assertIn(b'HTTP/1.1 403', response) + + def assert_valid_websocket_connect_message( + self, channel_message, request_path='/', request_params=None, request_headers=None): + """ + Asserts that a given channel message conforms to the HTTP request section of the ASGI spec. + """ + + self.assertTrue(channel_message) + + self.assert_presence_of_message_keys( + channel_message.keys(), + {'reply_channel', 'path', 'headers', 'order'}, + {'scheme', 'query_string', 'root_path', 'client', 'server'}) + + # == Assertions about required channel_message fields == + self.assert_valid_reply_channel(channel_message['reply_channel']) + self.assert_valid_path(channel_message['path'], request_path) + + order = channel_message['order'] + self.assertIsInstance(order, int) + self.assertEqual(order, 0) + + # Ordering of header names is not important, but the order of values for a header + # name is. To assert whether that order is kept, we transform the request + # headers and the channel message headers into a set + # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal. + # Note that unlike for HTTP, Daphne never gives out individual header values; instead we + # get one string per header field with values separated by comma. + transformed_request_headers = defaultdict(list) + for name, value in (request_headers or []): + expected_name = name.lower().strip().encode('ascii') + expected_value = value.strip().encode('ascii') + transformed_request_headers[expected_name].append(expected_value) + final_request_headers = { + (name, b','.join(value)) for name, value in transformed_request_headers.items() + } + + # Websockets carry a lot of additional header fields, so instead of verifying that + # headers look exactly like expected, we just check that the expected header fields + # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed + # and not tested for. + assert final_request_headers.issubset(set(channel_message['headers'])) + + # == Assertions about optional channel_message fields == + scheme = channel_message.get('scheme') + if scheme: + self.assertIsInstance(scheme, six.text_type) + self.assertIn(scheme, ['ws', 'wss']) + + query_string = channel_message.get('query_string') + if query_string: + # Assert that query_string is a byte string and still url encoded + # TODO: It's neither a byte string nor urlencoded + # Will fail until https://github.com/django/daphne/issues/110 is resolved. + #self.assertIsInstance(query_string, six.binary_type) + #self.assertEqual(query_string, parse.urlencode(request_params or []).encode('ascii')) + pass + + root_path = channel_message.get('root_path') + if root_path is not None: + self.assertIsInstance(root_path, six.text_type) + + client = channel_message.get('client') + if client is not None: + self.assert_valid_address_and_port(channel_message['client']) + + server = channel_message.get('server') + if server is not None: + self.assert_valid_address_and_port(channel_message['server']) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 5e26737..6b630a2 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -50,9 +50,11 @@ class WebSocketProtocol(WebSocketServerProtocol): # Tell main factory about it self.main_factory.reply_protocols[self.reply_channel] = self # Get client address if possible - if hasattr(self.transport.getPeer(), "host") and hasattr(self.transport.getPeer(), "port"): - self.client_addr = [self.transport.getPeer().host, self.transport.getPeer().port] - self.server_addr = [self.transport.getHost().host, self.transport.getHost().port] + peer = self.transport.getPeer() + host = self.transport.getHost() + if hasattr(peer, "host") and hasattr(peer, "port"): + self.client_addr = [six.text_type(peer.host), peer.port] + self.server_addr = [six.text_type(host.host), host.port] else: self.client_addr = None self.server_addr = None