Websockets test and unicode fix for Python 2 (#111)

* Python 2 fix for host address

This is a copy of
57051a48cd
for the Websocket protocol.

In Python 2, Twisted returns a byte string for the host address, while
the spec requires a unicode string. A simple cast gives us consistency.

* Test suite for websocket tests

This commit

* introduces some new helpers to test the Websocket protocol
* renames the old ASGITestCase class to ASGIHTTPTestCase, and
  introduces a test case for testing Websockets
* moves some helper methods that are shared between HTTP and Websockets
  into a mutual base class
* uses the new helpers to simplfiy the existing tests
* and adds a couple new tests.
This commit is contained in:
Maik Hoepfel 2017-04-28 23:45:07 +02:00 committed by Andrew Godwin
parent bd03fabce6
commit 2bcec3fe94
7 changed files with 321 additions and 109 deletions

View File

@ -84,6 +84,19 @@ def _build_request(method, path, params=None, headers=None, body=None):
return request
def build_websocket_upgrade(path, params, headers):
ws_headers = [
('Host', 'somewhere.com'),
('Upgrade', 'websocket'),
('Connection', 'Upgrade'),
('Sec-WebSocket-Key', 'x3JJHMbDL1EzLkh9GBhXDw=='),
('Sec-WebSocket-Protocol', 'chat, superchat'),
('Sec-WebSocket-Version', '13'),
('Origin', 'http://example.com')
]
return _build_request('GET', path, params, headers=headers + ws_headers, body=None)
def header_line(name, value):
"""
Given a header name and value, returns the line to use in a HTTP request or response.

View File

@ -17,12 +17,17 @@ def http_method():
return strategies.sampled_from(HTTP_METHODS)
def _http_path_portion():
alphabet = string.ascii_letters + string.digits + '-._~'
return strategies.text(min_size=1, average_size=10, max_size=128, alphabet=alphabet)
def http_path():
"""
Returns a URL path (not encoded).
"""
alphabet = string.ascii_letters + string.digits + '-._~/'
return strategies.text(min_size=0, max_size=255, alphabet=alphabet).map(lambda s: '/' + s)
return strategies.lists(
_http_path_portion(), min_size=0, max_size=10).map(lambda s: '/' + '/'.join(s))
def http_body():
@ -33,6 +38,10 @@ def http_body():
return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
def binary_payload():
return strategies.binary(min_size=0, average_size=600, max_size=1500)
def valid_bidi(value):
"""
Rejects strings which nonsensical Unicode text direction flags.

View File

@ -16,7 +16,7 @@ from daphne.tests import testcases, http_strategies
from daphne.tests.factories import message_for_request, content_length_header
class TestHTTPRequestSpec(testcases.ASGITestCase):
class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase):
"""
Tests which try to pour the HTTP request section of the ASGI spec into code.
The heavy lifting is done by the assert_valid_http_request_message function,

View File

@ -14,7 +14,7 @@ from daphne.http_protocol import HTTPFactory
from . import factories, http_strategies, testcases
class TestHTTPResponseSpec(testcases.ASGITestCase):
class TestHTTPResponseSpec(testcases.ASGIHTTPTestCase):
def test_minimal_response(self):
"""

View File

@ -1,27 +1,171 @@
# coding: utf8
from __future__ import unicode_literals
from unittest import TestCase
from asgiref.inmemory import ChannelLayer
from hypothesis import assume, given, strategies
from twisted.test import proto_helpers
from asgiref.inmemory import ChannelLayer
from daphne.http_protocol import HTTPFactory
from daphne.tests import http_strategies, testcases, factories
class TestWebSocketProtocol(TestCase):
class WebSocketConnection(object):
"""
Tests that the WebSocket protocol class correcly generates and parses messages.
Helper class that makes it easier to test Dahpne's WebSocket support.
"""
def setUp(self):
def __init__(self):
self.last_message = None
self.channel_layer = ChannelLayer()
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)
self.transport = proto_helpers.StringTransport()
self.proto.makeConnection(self.transport)
def receive(self, request):
"""
Low-level method to let Daphne handle HTTP/WebSocket data
"""
self.proto.dataReceived(request)
_, self.last_message = self.channel_layer.receive(['websocket.connect'])
return self.last_message
def send(self, content):
"""
Method to respond with a channel message
"""
if self.last_message is None:
# Auto-connect for convenience.
self.connect()
self.factory.dispatch_reply(self.last_message['reply_channel'], content)
response = self.transport.value()
self.transport.clear()
return response
def connect(self, path='/', params=None, headers=None):
"""
High-level method to perform the WebSocket handshake
"""
request = factories.build_websocket_upgrade(path, params, headers or [])
message = self.receive(request)
return message
class TestHandshake(testcases.ASGIWebSocketTestCase):
"""
Tests for the WebSocket handshake
"""
def test_minimal(self):
message = WebSocketConnection().connect()
self.assert_valid_websocket_connect_message(message)
@given(
path=http_strategies.http_path(),
params=http_strategies.query_params(),
headers=http_strategies.headers(),
)
def test_connection(self, path, params, headers):
message = WebSocketConnection().connect(path, params, headers)
self.assert_valid_websocket_connect_message(message, path, params, headers)
class TestSendCloseAccept(testcases.ASGIWebSocketTestCase):
"""
Tests that, essentially, try to translate the send/close/accept section of the spec into code.
"""
def test_empty_accept(self):
response = WebSocketConnection().send({'accept': True})
self.assert_websocket_upgrade(response)
@given(text=http_strategies.http_body())
def test_accept_and_text(self, text):
response = WebSocketConnection().send({'accept': True, 'text': text})
self.assert_websocket_upgrade(response, text.encode('ascii'))
@given(data=http_strategies.binary_payload())
def test_accept_and_bytes(self, data):
response = WebSocketConnection().send({'accept': True, 'bytes': data})
self.assert_websocket_upgrade(response, data)
def test_accept_false(self):
response = WebSocketConnection().send({'accept': False})
self.assert_websocket_denied(response)
def test_accept_false_with_text(self):
"""
Tests that even if text is given, the connection is denied.
We can't easily use Hypothesis to generate data for this test because it's
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
"""
text = 'foobar'
response = WebSocketConnection().send({'accept': False, 'text': text})
self.assert_websocket_denied(response)
self.assertNotIn(text.encode('ascii'), response)
def test_accept_false_with_bytes(self):
"""
Tests that even if data is given, the connection is denied.
We can't easily use Hypothesis to generate data for this test because it's
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
"""
data = b'foobar'
response = WebSocketConnection().send({'accept': False, 'bytes': data})
self.assert_websocket_denied(response)
self.assertNotIn(data, response)
@given(text=http_strategies.http_body())
def test_just_text(self, text):
assume(len(text) > 0)
# If content is sent, accept=True is implied.
response = WebSocketConnection().send({'text': text})
self.assert_websocket_upgrade(response, text.encode('ascii'))
@given(data=http_strategies.binary_payload())
def test_just_bytes(self, data):
assume(len(data) > 0)
# If content is sent, accept=True is implied.
response = WebSocketConnection().send({'bytes': data})
self.assert_websocket_upgrade(response, data)
def test_close_boolean(self):
response = WebSocketConnection().send({'close': True})
self.assert_websocket_denied(response)
@given(number=strategies.integers(min_value=1))
def test_close_integer(self, number):
response = WebSocketConnection().send({'close': number})
self.assert_websocket_denied(response)
@given(text=http_strategies.http_body())
def test_close_with_text(self, text):
assume(len(text) > 0)
response = WebSocketConnection().send({'close': True, 'text': text})
self.assert_websocket_upgrade(response, text.encode('ascii'), expect_close=True)
@given(data=http_strategies.binary_payload())
def test_close_with_data(self, data):
assume(len(data) > 0)
response = WebSocketConnection().send({'close': True, 'bytes': data})
self.assert_websocket_upgrade(response, data, expect_close=True)
class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase):
"""
Tests that the WebSocket protocol class correctly generates and parses messages.
"""
def setUp(self):
self.connection = WebSocketConnection()
def test_basic(self):
# Send a simple request to the protocol
self.proto.dataReceived(
# Send a simple request to the protocol and get the resulting message off
# of the channel layer.
message = self.connection.receive(
b"GET /chat HTTP/1.1\r\n"
b"Host: somewhere.com\r\n"
b"Upgrade: websocket\r\n"
@ -32,8 +176,6 @@ class TestWebSocketProtocol(TestCase):
b"Origin: http://example.com\r\n"
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["websocket.connect"])
self.assertEqual(message['path'], "/chat")
self.assertEqual(message['query_string'], "")
self.assertEqual(
@ -46,53 +188,26 @@ class TestWebSocketProtocol(TestCase):
(b'sec-websocket-version', b'13'),
(b'upgrade', b'websocket')]
)
self.assertTrue(message['reply_channel'].startswith("test!"))
self.assert_valid_websocket_connect_message(message, '/chat')
# Accept the connection
self.factory.dispatch_reply(
message['reply_channel'],
{'accept': True}
)
# Make sure that we get a 101 Switching Protocols back
response = self.tr.value()
self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
self.tr.clear()
response = self.connection.send({'accept': True})
self.assert_websocket_upgrade(response)
# Send some text
self.factory.dispatch_reply(
message['reply_channel'],
{'text': "Hello World!"}
)
response = self.tr.value()
response = self.connection.send({'text': "Hello World!"})
self.assertEqual(response, b"\x81\x0cHello World!")
self.tr.clear()
# Send some bytes
self.factory.dispatch_reply(
message['reply_channel'],
{'bytes': b"\xaa\xbb\xcc\xdd"}
)
response = self.tr.value()
response = self.connection.send({'bytes': b"\xaa\xbb\xcc\xdd"})
self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
self.tr.clear()
# Close the connection
self.factory.dispatch_reply(
message['reply_channel'],
{'close': True}
)
response = self.tr.value()
response = self.connection.send({'close': True})
self.assertEqual(response, b"\x88\x02\x03\xe8")
self.tr.clear()
def test_connection_with_file_origin_is_accepted(self):
# Send a simple request to the protocol
self.proto.dataReceived(
message = self.connection.receive(
b"GET /chat HTTP/1.1\r\n"
b"Host: somewhere.com\r\n"
b"Upgrade: websocket\r\n"
@ -103,26 +218,15 @@ class TestWebSocketProtocol(TestCase):
b"Origin: file://\r\n"
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["websocket.connect"])
self.assertIn((b'origin', b'file://'), message['headers'])
self.assertTrue(message['reply_channel'].startswith("test!"))
self.assert_valid_websocket_connect_message(message, '/chat')
# Accept the connection
self.factory.dispatch_reply(
message['reply_channel'],
{'accept': True}
)
# Make sure that we get a 101 Switching Protocols back
response = self.tr.value()
self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
response = self.connection.send({'accept': True})
self.assert_websocket_upgrade(response)
def test_connection_with_no_origin_is_accepted(self):
# Send a simple request to the protocol
self.proto.dataReceived(
message = self.connection.receive(
b"GET /chat HTTP/1.1\r\n"
b"Host: somewhere.com\r\n"
b"Upgrade: websocket\r\n"
@ -133,18 +237,9 @@ class TestWebSocketProtocol(TestCase):
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive(["websocket.connect"])
self.assertNotIn(b'origin', [header_tuple[0] for header_tuple in message['headers']])
self.assertTrue(message['reply_channel'].startswith("test!"))
self.assert_valid_websocket_connect_message(message, '/chat')
# Accept the connection
self.factory.dispatch_reply(
message['reply_channel'],
{'accept': True}
)
# Make sure that we get a 101 Switching Protocols back
response = self.tr.value()
self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
response = self.connection.send({'accept': True})
self.assert_websocket_upgrade(response)

View File

@ -12,9 +12,9 @@ import unittest
from . import factories
class ASGITestCase(unittest.TestCase):
class ASGITestCaseBase(unittest.TestCase):
"""
Test case with helpers for ASGI message verification
Base class for our test classes which contains shared method.
"""
def assert_is_ip_address(self, address):
@ -26,6 +26,35 @@ class ASGITestCase(unittest.TestCase):
except socket.error:
self.fail("'%s' is not a valid IP address." % address)
def assert_presence_of_message_keys(self, keys, required_keys, optional_keys):
present_keys = set(keys)
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(set(), present_keys - required_keys - optional_keys)
def assert_valid_reply_channel(self, reply_channel):
self.assertIsInstance(reply_channel, six.text_type)
# The reply channel is decided by the server.
self.assertTrue(reply_channel.startswith('test!'))
def assert_valid_path(self, path, request_path):
self.assertIsInstance(path, six.text_type)
self.assertEqual(path, request_path)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
address, port = host
self.assertIsInstance(address, six.text_type)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
class ASGIHTTPTestCase(ASGITestCaseBase):
"""
Test case with helpers for verifying HTTP channel messages
"""
def assert_valid_http_request_message(
self, channel_message, request_method, request_path,
request_params=None, request_headers=None, request_body=None):
@ -35,22 +64,14 @@ class ASGITestCase(unittest.TestCase):
self.assertTrue(channel_message)
# == General assertions about expected dictionary keys being present ==
message_keys = set(channel_message.keys())
required_message_keys = {
'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers',
}
optional_message_keys = {
'scheme', 'root_path', 'body', 'body_channel', 'client', 'server'
}
self.assertTrue(required_message_keys <= message_keys)
# Assert that no other keys are present
self.assertEqual(set(), message_keys - required_message_keys - optional_message_keys)
self.assert_presence_of_message_keys(
channel_message.keys(),
{'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers'},
{'scheme', 'root_path', 'body', 'body_channel', 'client', 'server'})
# == Assertions about required channel_message fields ==
reply_channel = channel_message['reply_channel']
self.assertIsInstance(reply_channel, six.text_type)
self.assertTrue(reply_channel.startswith('test!'))
self.assert_valid_reply_channel(channel_message['reply_channel'])
self.assert_valid_path(channel_message['path'], request_path)
http_version = channel_message['http_version']
self.assertIsInstance(http_version, six.text_type)
@ -61,12 +82,6 @@ class ASGITestCase(unittest.TestCase):
self.assertTrue(method.isupper())
self.assertEqual(channel_message['method'], request_method)
path = channel_message['path']
self.assertIsInstance(path, six.text_type)
self.assertEqual(path, request_path)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
query_string = channel_message['query_string']
# Assert that query_string is a byte string and still url encoded
self.assertIsInstance(query_string, six.binary_type)
@ -112,17 +127,11 @@ class ASGITestCase(unittest.TestCase):
client = channel_message.get('client')
if client is not None:
client_host, client_port = client
self.assertIsInstance(client_host, six.text_type)
self.assert_is_ip_address(client_host)
self.assertIsInstance(client_port, int)
self.assert_valid_address_and_port(channel_message['client'])
server = channel_message.get('server')
if server is not None:
server_host, server_port = channel_message['server']
self.assertIsInstance(server_host, six.text_type)
self.assert_is_ip_address(server_host)
self.assertIsInstance(server_port, int)
self.assert_valid_address_and_port(channel_message['server'])
def assert_valid_http_response_message(self, message, response):
self.assertTrue(message)
@ -147,3 +156,87 @@ class ASGITestCase(unittest.TestCase):
# altered casing. The approach below does this well enough.
self.assertIn(expected_header.lower(), response.lower())
self.assertIn(value.encode('ascii'), response)
class ASGIWebSocketTestCase(ASGITestCaseBase):
"""
Test case with helpers for verifying WebSocket channel messages
"""
def assert_websocket_upgrade(self, response, body=b'', expect_close=False):
self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
self.assertIn(body, response)
self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
def assert_websocket_denied(self, response):
self.assertIn(b'HTTP/1.1 403', response)
def assert_valid_websocket_connect_message(
self, channel_message, request_path='/', request_params=None, request_headers=None):
"""
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
"""
self.assertTrue(channel_message)
self.assert_presence_of_message_keys(
channel_message.keys(),
{'reply_channel', 'path', 'headers', 'order'},
{'scheme', 'query_string', 'root_path', 'client', 'server'})
# == Assertions about required channel_message fields ==
self.assert_valid_reply_channel(channel_message['reply_channel'])
self.assert_valid_path(channel_message['path'], request_path)
order = channel_message['order']
self.assertIsInstance(order, int)
self.assertEqual(order, 0)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform the request
# headers and the channel message headers into a set
# {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
# Note that unlike for HTTP, Daphne never gives out individual header values; instead we
# get one string per header field with values separated by comma.
transformed_request_headers = defaultdict(list)
for name, value in (request_headers or []):
expected_name = name.lower().strip().encode('ascii')
expected_value = value.strip().encode('ascii')
transformed_request_headers[expected_name].append(expected_value)
final_request_headers = {
(name, b','.join(value)) for name, value in transformed_request_headers.items()
}
# Websockets carry a lot of additional header fields, so instead of verifying that
# headers look exactly like expected, we just check that the expected header fields
# and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
# and not tested for.
assert final_request_headers.issubset(set(channel_message['headers']))
# == Assertions about optional channel_message fields ==
scheme = channel_message.get('scheme')
if scheme:
self.assertIsInstance(scheme, six.text_type)
self.assertIn(scheme, ['ws', 'wss'])
query_string = channel_message.get('query_string')
if query_string:
# Assert that query_string is a byte string and still url encoded
# TODO: It's neither a byte string nor urlencoded
# Will fail until https://github.com/django/daphne/issues/110 is resolved.
#self.assertIsInstance(query_string, six.binary_type)
#self.assertEqual(query_string, parse.urlencode(request_params or []).encode('ascii'))
pass
root_path = channel_message.get('root_path')
if root_path is not None:
self.assertIsInstance(root_path, six.text_type)
client = channel_message.get('client')
if client is not None:
self.assert_valid_address_and_port(channel_message['client'])
server = channel_message.get('server')
if server is not None:
self.assert_valid_address_and_port(channel_message['server'])

View File

@ -50,9 +50,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
# Tell main factory about it
self.main_factory.reply_protocols[self.reply_channel] = self
# Get client address if possible
if hasattr(self.transport.getPeer(), "host") and hasattr(self.transport.getPeer(), "port"):
self.client_addr = [self.transport.getPeer().host, self.transport.getPeer().port]
self.server_addr = [self.transport.getHost().host, self.transport.getHost().port]
peer = self.transport.getPeer()
host = self.transport.getHost()
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [six.text_type(peer.host), peer.port]
self.server_addr = [six.text_type(host.host), host.port]
else:
self.client_addr = None
self.server_addr = None