mirror of
https://github.com/django/daphne.git
synced 2024-11-25 09:13:44 +03:00
Websockets test and unicode fix for Python 2 (#111)
* Python 2 fix for host address
This is a copy of
57051a48cd
for the Websocket protocol.
In Python 2, Twisted returns a byte string for the host address, while
the spec requires a unicode string. A simple cast gives us consistency.
* Test suite for websocket tests
This commit
* introduces some new helpers to test the Websocket protocol
* renames the old ASGITestCase class to ASGIHTTPTestCase, and
introduces a test case for testing Websockets
* moves some helper methods that are shared between HTTP and Websockets
into a mutual base class
* uses the new helpers to simplfiy the existing tests
* and adds a couple new tests.
This commit is contained in:
parent
bd03fabce6
commit
2bcec3fe94
|
@ -84,6 +84,19 @@ def _build_request(method, path, params=None, headers=None, body=None):
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
def build_websocket_upgrade(path, params, headers):
|
||||||
|
ws_headers = [
|
||||||
|
('Host', 'somewhere.com'),
|
||||||
|
('Upgrade', 'websocket'),
|
||||||
|
('Connection', 'Upgrade'),
|
||||||
|
('Sec-WebSocket-Key', 'x3JJHMbDL1EzLkh9GBhXDw=='),
|
||||||
|
('Sec-WebSocket-Protocol', 'chat, superchat'),
|
||||||
|
('Sec-WebSocket-Version', '13'),
|
||||||
|
('Origin', 'http://example.com')
|
||||||
|
]
|
||||||
|
return _build_request('GET', path, params, headers=headers + ws_headers, body=None)
|
||||||
|
|
||||||
|
|
||||||
def header_line(name, value):
|
def header_line(name, value):
|
||||||
"""
|
"""
|
||||||
Given a header name and value, returns the line to use in a HTTP request or response.
|
Given a header name and value, returns the line to use in a HTTP request or response.
|
||||||
|
|
|
@ -17,12 +17,17 @@ def http_method():
|
||||||
return strategies.sampled_from(HTTP_METHODS)
|
return strategies.sampled_from(HTTP_METHODS)
|
||||||
|
|
||||||
|
|
||||||
|
def _http_path_portion():
|
||||||
|
alphabet = string.ascii_letters + string.digits + '-._~'
|
||||||
|
return strategies.text(min_size=1, average_size=10, max_size=128, alphabet=alphabet)
|
||||||
|
|
||||||
|
|
||||||
def http_path():
|
def http_path():
|
||||||
"""
|
"""
|
||||||
Returns a URL path (not encoded).
|
Returns a URL path (not encoded).
|
||||||
"""
|
"""
|
||||||
alphabet = string.ascii_letters + string.digits + '-._~/'
|
return strategies.lists(
|
||||||
return strategies.text(min_size=0, max_size=255, alphabet=alphabet).map(lambda s: '/' + s)
|
_http_path_portion(), min_size=0, max_size=10).map(lambda s: '/' + '/'.join(s))
|
||||||
|
|
||||||
|
|
||||||
def http_body():
|
def http_body():
|
||||||
|
@ -33,6 +38,10 @@ def http_body():
|
||||||
return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
|
return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
|
||||||
|
|
||||||
|
|
||||||
|
def binary_payload():
|
||||||
|
return strategies.binary(min_size=0, average_size=600, max_size=1500)
|
||||||
|
|
||||||
|
|
||||||
def valid_bidi(value):
|
def valid_bidi(value):
|
||||||
"""
|
"""
|
||||||
Rejects strings which nonsensical Unicode text direction flags.
|
Rejects strings which nonsensical Unicode text direction flags.
|
||||||
|
|
|
@ -16,7 +16,7 @@ from daphne.tests import testcases, http_strategies
|
||||||
from daphne.tests.factories import message_for_request, content_length_header
|
from daphne.tests.factories import message_for_request, content_length_header
|
||||||
|
|
||||||
|
|
||||||
class TestHTTPRequestSpec(testcases.ASGITestCase):
|
class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase):
|
||||||
"""
|
"""
|
||||||
Tests which try to pour the HTTP request section of the ASGI spec into code.
|
Tests which try to pour the HTTP request section of the ASGI spec into code.
|
||||||
The heavy lifting is done by the assert_valid_http_request_message function,
|
The heavy lifting is done by the assert_valid_http_request_message function,
|
||||||
|
|
|
@ -14,7 +14,7 @@ from daphne.http_protocol import HTTPFactory
|
||||||
from . import factories, http_strategies, testcases
|
from . import factories, http_strategies, testcases
|
||||||
|
|
||||||
|
|
||||||
class TestHTTPResponseSpec(testcases.ASGITestCase):
|
class TestHTTPResponseSpec(testcases.ASGIHTTPTestCase):
|
||||||
|
|
||||||
def test_minimal_response(self):
|
def test_minimal_response(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,27 +1,171 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
from unittest import TestCase
|
|
||||||
from asgiref.inmemory import ChannelLayer
|
from hypothesis import assume, given, strategies
|
||||||
from twisted.test import proto_helpers
|
from twisted.test import proto_helpers
|
||||||
|
|
||||||
|
from asgiref.inmemory import ChannelLayer
|
||||||
from daphne.http_protocol import HTTPFactory
|
from daphne.http_protocol import HTTPFactory
|
||||||
|
from daphne.tests import http_strategies, testcases, factories
|
||||||
|
|
||||||
|
|
||||||
class TestWebSocketProtocol(TestCase):
|
class WebSocketConnection(object):
|
||||||
"""
|
"""
|
||||||
Tests that the WebSocket protocol class correcly generates and parses messages.
|
Helper class that makes it easier to test Dahpne's WebSocket support.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def __init__(self):
|
||||||
|
self.last_message = None
|
||||||
|
|
||||||
self.channel_layer = ChannelLayer()
|
self.channel_layer = ChannelLayer()
|
||||||
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
|
self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
|
||||||
self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
|
self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
|
||||||
self.tr = proto_helpers.StringTransport()
|
self.transport = proto_helpers.StringTransport()
|
||||||
self.proto.makeConnection(self.tr)
|
self.proto.makeConnection(self.transport)
|
||||||
|
|
||||||
|
def receive(self, request):
|
||||||
|
"""
|
||||||
|
Low-level method to let Daphne handle HTTP/WebSocket data
|
||||||
|
"""
|
||||||
|
self.proto.dataReceived(request)
|
||||||
|
_, self.last_message = self.channel_layer.receive(['websocket.connect'])
|
||||||
|
return self.last_message
|
||||||
|
|
||||||
|
def send(self, content):
|
||||||
|
"""
|
||||||
|
Method to respond with a channel message
|
||||||
|
"""
|
||||||
|
if self.last_message is None:
|
||||||
|
# Auto-connect for convenience.
|
||||||
|
self.connect()
|
||||||
|
self.factory.dispatch_reply(self.last_message['reply_channel'], content)
|
||||||
|
response = self.transport.value()
|
||||||
|
self.transport.clear()
|
||||||
|
return response
|
||||||
|
|
||||||
|
def connect(self, path='/', params=None, headers=None):
|
||||||
|
"""
|
||||||
|
High-level method to perform the WebSocket handshake
|
||||||
|
"""
|
||||||
|
request = factories.build_websocket_upgrade(path, params, headers or [])
|
||||||
|
message = self.receive(request)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandshake(testcases.ASGIWebSocketTestCase):
|
||||||
|
"""
|
||||||
|
Tests for the WebSocket handshake
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_minimal(self):
|
||||||
|
message = WebSocketConnection().connect()
|
||||||
|
self.assert_valid_websocket_connect_message(message)
|
||||||
|
|
||||||
|
@given(
|
||||||
|
path=http_strategies.http_path(),
|
||||||
|
params=http_strategies.query_params(),
|
||||||
|
headers=http_strategies.headers(),
|
||||||
|
)
|
||||||
|
def test_connection(self, path, params, headers):
|
||||||
|
message = WebSocketConnection().connect(path, params, headers)
|
||||||
|
self.assert_valid_websocket_connect_message(message, path, params, headers)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSendCloseAccept(testcases.ASGIWebSocketTestCase):
|
||||||
|
"""
|
||||||
|
Tests that, essentially, try to translate the send/close/accept section of the spec into code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_empty_accept(self):
|
||||||
|
response = WebSocketConnection().send({'accept': True})
|
||||||
|
self.assert_websocket_upgrade(response)
|
||||||
|
|
||||||
|
@given(text=http_strategies.http_body())
|
||||||
|
def test_accept_and_text(self, text):
|
||||||
|
response = WebSocketConnection().send({'accept': True, 'text': text})
|
||||||
|
self.assert_websocket_upgrade(response, text.encode('ascii'))
|
||||||
|
|
||||||
|
@given(data=http_strategies.binary_payload())
|
||||||
|
def test_accept_and_bytes(self, data):
|
||||||
|
response = WebSocketConnection().send({'accept': True, 'bytes': data})
|
||||||
|
self.assert_websocket_upgrade(response, data)
|
||||||
|
|
||||||
|
def test_accept_false(self):
|
||||||
|
response = WebSocketConnection().send({'accept': False})
|
||||||
|
self.assert_websocket_denied(response)
|
||||||
|
|
||||||
|
def test_accept_false_with_text(self):
|
||||||
|
"""
|
||||||
|
Tests that even if text is given, the connection is denied.
|
||||||
|
|
||||||
|
We can't easily use Hypothesis to generate data for this test because it's
|
||||||
|
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
|
||||||
|
"""
|
||||||
|
text = 'foobar'
|
||||||
|
response = WebSocketConnection().send({'accept': False, 'text': text})
|
||||||
|
self.assert_websocket_denied(response)
|
||||||
|
self.assertNotIn(text.encode('ascii'), response)
|
||||||
|
|
||||||
|
def test_accept_false_with_bytes(self):
|
||||||
|
"""
|
||||||
|
Tests that even if data is given, the connection is denied.
|
||||||
|
|
||||||
|
We can't easily use Hypothesis to generate data for this test because it's
|
||||||
|
hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
|
||||||
|
"""
|
||||||
|
data = b'foobar'
|
||||||
|
response = WebSocketConnection().send({'accept': False, 'bytes': data})
|
||||||
|
self.assert_websocket_denied(response)
|
||||||
|
self.assertNotIn(data, response)
|
||||||
|
|
||||||
|
@given(text=http_strategies.http_body())
|
||||||
|
def test_just_text(self, text):
|
||||||
|
assume(len(text) > 0)
|
||||||
|
# If content is sent, accept=True is implied.
|
||||||
|
response = WebSocketConnection().send({'text': text})
|
||||||
|
self.assert_websocket_upgrade(response, text.encode('ascii'))
|
||||||
|
|
||||||
|
@given(data=http_strategies.binary_payload())
|
||||||
|
def test_just_bytes(self, data):
|
||||||
|
assume(len(data) > 0)
|
||||||
|
# If content is sent, accept=True is implied.
|
||||||
|
response = WebSocketConnection().send({'bytes': data})
|
||||||
|
self.assert_websocket_upgrade(response, data)
|
||||||
|
|
||||||
|
def test_close_boolean(self):
|
||||||
|
response = WebSocketConnection().send({'close': True})
|
||||||
|
self.assert_websocket_denied(response)
|
||||||
|
|
||||||
|
@given(number=strategies.integers(min_value=1))
|
||||||
|
def test_close_integer(self, number):
|
||||||
|
response = WebSocketConnection().send({'close': number})
|
||||||
|
self.assert_websocket_denied(response)
|
||||||
|
|
||||||
|
@given(text=http_strategies.http_body())
|
||||||
|
def test_close_with_text(self, text):
|
||||||
|
assume(len(text) > 0)
|
||||||
|
response = WebSocketConnection().send({'close': True, 'text': text})
|
||||||
|
self.assert_websocket_upgrade(response, text.encode('ascii'), expect_close=True)
|
||||||
|
|
||||||
|
@given(data=http_strategies.binary_payload())
|
||||||
|
def test_close_with_data(self, data):
|
||||||
|
assume(len(data) > 0)
|
||||||
|
response = WebSocketConnection().send({'close': True, 'bytes': data})
|
||||||
|
self.assert_websocket_upgrade(response, data, expect_close=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase):
|
||||||
|
"""
|
||||||
|
Tests that the WebSocket protocol class correctly generates and parses messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.connection = WebSocketConnection()
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
# Send a simple request to the protocol
|
# Send a simple request to the protocol and get the resulting message off
|
||||||
self.proto.dataReceived(
|
# of the channel layer.
|
||||||
|
message = self.connection.receive(
|
||||||
b"GET /chat HTTP/1.1\r\n"
|
b"GET /chat HTTP/1.1\r\n"
|
||||||
b"Host: somewhere.com\r\n"
|
b"Host: somewhere.com\r\n"
|
||||||
b"Upgrade: websocket\r\n"
|
b"Upgrade: websocket\r\n"
|
||||||
|
@ -32,8 +176,6 @@ class TestWebSocketProtocol(TestCase):
|
||||||
b"Origin: http://example.com\r\n"
|
b"Origin: http://example.com\r\n"
|
||||||
b"\r\n"
|
b"\r\n"
|
||||||
)
|
)
|
||||||
# Get the resulting message off of the channel layer
|
|
||||||
_, message = self.channel_layer.receive(["websocket.connect"])
|
|
||||||
self.assertEqual(message['path'], "/chat")
|
self.assertEqual(message['path'], "/chat")
|
||||||
self.assertEqual(message['query_string'], "")
|
self.assertEqual(message['query_string'], "")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -46,53 +188,26 @@ class TestWebSocketProtocol(TestCase):
|
||||||
(b'sec-websocket-version', b'13'),
|
(b'sec-websocket-version', b'13'),
|
||||||
(b'upgrade', b'websocket')]
|
(b'upgrade', b'websocket')]
|
||||||
)
|
)
|
||||||
self.assertTrue(message['reply_channel'].startswith("test!"))
|
self.assert_valid_websocket_connect_message(message, '/chat')
|
||||||
|
|
||||||
# Accept the connection
|
# Accept the connection
|
||||||
self.factory.dispatch_reply(
|
response = self.connection.send({'accept': True})
|
||||||
message['reply_channel'],
|
self.assert_websocket_upgrade(response)
|
||||||
{'accept': True}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make sure that we get a 101 Switching Protocols back
|
|
||||||
response = self.tr.value()
|
|
||||||
self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
|
|
||||||
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
|
||||||
self.tr.clear()
|
|
||||||
|
|
||||||
# Send some text
|
# Send some text
|
||||||
self.factory.dispatch_reply(
|
response = self.connection.send({'text': "Hello World!"})
|
||||||
message['reply_channel'],
|
|
||||||
{'text': "Hello World!"}
|
|
||||||
)
|
|
||||||
|
|
||||||
response = self.tr.value()
|
|
||||||
self.assertEqual(response, b"\x81\x0cHello World!")
|
self.assertEqual(response, b"\x81\x0cHello World!")
|
||||||
self.tr.clear()
|
|
||||||
|
|
||||||
# Send some bytes
|
# Send some bytes
|
||||||
self.factory.dispatch_reply(
|
response = self.connection.send({'bytes': b"\xaa\xbb\xcc\xdd"})
|
||||||
message['reply_channel'],
|
|
||||||
{'bytes': b"\xaa\xbb\xcc\xdd"}
|
|
||||||
)
|
|
||||||
|
|
||||||
response = self.tr.value()
|
|
||||||
self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
|
self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
|
||||||
self.tr.clear()
|
|
||||||
|
|
||||||
# Close the connection
|
# Close the connection
|
||||||
self.factory.dispatch_reply(
|
response = self.connection.send({'close': True})
|
||||||
message['reply_channel'],
|
|
||||||
{'close': True}
|
|
||||||
)
|
|
||||||
|
|
||||||
response = self.tr.value()
|
|
||||||
self.assertEqual(response, b"\x88\x02\x03\xe8")
|
self.assertEqual(response, b"\x88\x02\x03\xe8")
|
||||||
self.tr.clear()
|
|
||||||
|
|
||||||
def test_connection_with_file_origin_is_accepted(self):
|
def test_connection_with_file_origin_is_accepted(self):
|
||||||
# Send a simple request to the protocol
|
message = self.connection.receive(
|
||||||
self.proto.dataReceived(
|
|
||||||
b"GET /chat HTTP/1.1\r\n"
|
b"GET /chat HTTP/1.1\r\n"
|
||||||
b"Host: somewhere.com\r\n"
|
b"Host: somewhere.com\r\n"
|
||||||
b"Upgrade: websocket\r\n"
|
b"Upgrade: websocket\r\n"
|
||||||
|
@ -103,26 +218,15 @@ class TestWebSocketProtocol(TestCase):
|
||||||
b"Origin: file://\r\n"
|
b"Origin: file://\r\n"
|
||||||
b"\r\n"
|
b"\r\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the resulting message off of the channel layer
|
|
||||||
_, message = self.channel_layer.receive(["websocket.connect"])
|
|
||||||
self.assertIn((b'origin', b'file://'), message['headers'])
|
self.assertIn((b'origin', b'file://'), message['headers'])
|
||||||
self.assertTrue(message['reply_channel'].startswith("test!"))
|
self.assert_valid_websocket_connect_message(message, '/chat')
|
||||||
|
|
||||||
# Accept the connection
|
# Accept the connection
|
||||||
self.factory.dispatch_reply(
|
response = self.connection.send({'accept': True})
|
||||||
message['reply_channel'],
|
self.assert_websocket_upgrade(response)
|
||||||
{'accept': True}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make sure that we get a 101 Switching Protocols back
|
|
||||||
response = self.tr.value()
|
|
||||||
self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
|
|
||||||
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
|
||||||
|
|
||||||
def test_connection_with_no_origin_is_accepted(self):
|
def test_connection_with_no_origin_is_accepted(self):
|
||||||
# Send a simple request to the protocol
|
message = self.connection.receive(
|
||||||
self.proto.dataReceived(
|
|
||||||
b"GET /chat HTTP/1.1\r\n"
|
b"GET /chat HTTP/1.1\r\n"
|
||||||
b"Host: somewhere.com\r\n"
|
b"Host: somewhere.com\r\n"
|
||||||
b"Upgrade: websocket\r\n"
|
b"Upgrade: websocket\r\n"
|
||||||
|
@ -133,18 +237,9 @@ class TestWebSocketProtocol(TestCase):
|
||||||
b"\r\n"
|
b"\r\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the resulting message off of the channel layer
|
|
||||||
_, message = self.channel_layer.receive(["websocket.connect"])
|
|
||||||
self.assertNotIn(b'origin', [header_tuple[0] for header_tuple in message['headers']])
|
self.assertNotIn(b'origin', [header_tuple[0] for header_tuple in message['headers']])
|
||||||
self.assertTrue(message['reply_channel'].startswith("test!"))
|
self.assert_valid_websocket_connect_message(message, '/chat')
|
||||||
|
|
||||||
# Accept the connection
|
# Accept the connection
|
||||||
self.factory.dispatch_reply(
|
response = self.connection.send({'accept': True})
|
||||||
message['reply_channel'],
|
self.assert_websocket_upgrade(response)
|
||||||
{'accept': True}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make sure that we get a 101 Switching Protocols back
|
|
||||||
response = self.tr.value()
|
|
||||||
self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
|
|
||||||
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
|
||||||
|
|
|
@ -12,9 +12,9 @@ import unittest
|
||||||
from . import factories
|
from . import factories
|
||||||
|
|
||||||
|
|
||||||
class ASGITestCase(unittest.TestCase):
|
class ASGITestCaseBase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case with helpers for ASGI message verification
|
Base class for our test classes which contains shared method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def assert_is_ip_address(self, address):
|
def assert_is_ip_address(self, address):
|
||||||
|
@ -26,6 +26,35 @@ class ASGITestCase(unittest.TestCase):
|
||||||
except socket.error:
|
except socket.error:
|
||||||
self.fail("'%s' is not a valid IP address." % address)
|
self.fail("'%s' is not a valid IP address." % address)
|
||||||
|
|
||||||
|
def assert_presence_of_message_keys(self, keys, required_keys, optional_keys):
|
||||||
|
present_keys = set(keys)
|
||||||
|
self.assertTrue(required_keys <= present_keys)
|
||||||
|
# Assert that no other keys are present
|
||||||
|
self.assertEqual(set(), present_keys - required_keys - optional_keys)
|
||||||
|
|
||||||
|
def assert_valid_reply_channel(self, reply_channel):
|
||||||
|
self.assertIsInstance(reply_channel, six.text_type)
|
||||||
|
# The reply channel is decided by the server.
|
||||||
|
self.assertTrue(reply_channel.startswith('test!'))
|
||||||
|
|
||||||
|
def assert_valid_path(self, path, request_path):
|
||||||
|
self.assertIsInstance(path, six.text_type)
|
||||||
|
self.assertEqual(path, request_path)
|
||||||
|
# Assert that it's already url decoded
|
||||||
|
self.assertEqual(path, parse.unquote(path))
|
||||||
|
|
||||||
|
def assert_valid_address_and_port(self, host):
|
||||||
|
address, port = host
|
||||||
|
self.assertIsInstance(address, six.text_type)
|
||||||
|
self.assert_is_ip_address(address)
|
||||||
|
self.assertIsInstance(port, int)
|
||||||
|
|
||||||
|
|
||||||
|
class ASGIHTTPTestCase(ASGITestCaseBase):
|
||||||
|
"""
|
||||||
|
Test case with helpers for verifying HTTP channel messages
|
||||||
|
"""
|
||||||
|
|
||||||
def assert_valid_http_request_message(
|
def assert_valid_http_request_message(
|
||||||
self, channel_message, request_method, request_path,
|
self, channel_message, request_method, request_path,
|
||||||
request_params=None, request_headers=None, request_body=None):
|
request_params=None, request_headers=None, request_body=None):
|
||||||
|
@ -35,22 +64,14 @@ class ASGITestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertTrue(channel_message)
|
self.assertTrue(channel_message)
|
||||||
|
|
||||||
# == General assertions about expected dictionary keys being present ==
|
self.assert_presence_of_message_keys(
|
||||||
message_keys = set(channel_message.keys())
|
channel_message.keys(),
|
||||||
required_message_keys = {
|
{'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers'},
|
||||||
'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers',
|
{'scheme', 'root_path', 'body', 'body_channel', 'client', 'server'})
|
||||||
}
|
|
||||||
optional_message_keys = {
|
|
||||||
'scheme', 'root_path', 'body', 'body_channel', 'client', 'server'
|
|
||||||
}
|
|
||||||
self.assertTrue(required_message_keys <= message_keys)
|
|
||||||
# Assert that no other keys are present
|
|
||||||
self.assertEqual(set(), message_keys - required_message_keys - optional_message_keys)
|
|
||||||
|
|
||||||
# == Assertions about required channel_message fields ==
|
# == Assertions about required channel_message fields ==
|
||||||
reply_channel = channel_message['reply_channel']
|
self.assert_valid_reply_channel(channel_message['reply_channel'])
|
||||||
self.assertIsInstance(reply_channel, six.text_type)
|
self.assert_valid_path(channel_message['path'], request_path)
|
||||||
self.assertTrue(reply_channel.startswith('test!'))
|
|
||||||
|
|
||||||
http_version = channel_message['http_version']
|
http_version = channel_message['http_version']
|
||||||
self.assertIsInstance(http_version, six.text_type)
|
self.assertIsInstance(http_version, six.text_type)
|
||||||
|
@ -61,12 +82,6 @@ class ASGITestCase(unittest.TestCase):
|
||||||
self.assertTrue(method.isupper())
|
self.assertTrue(method.isupper())
|
||||||
self.assertEqual(channel_message['method'], request_method)
|
self.assertEqual(channel_message['method'], request_method)
|
||||||
|
|
||||||
path = channel_message['path']
|
|
||||||
self.assertIsInstance(path, six.text_type)
|
|
||||||
self.assertEqual(path, request_path)
|
|
||||||
# Assert that it's already url decoded
|
|
||||||
self.assertEqual(path, parse.unquote(path))
|
|
||||||
|
|
||||||
query_string = channel_message['query_string']
|
query_string = channel_message['query_string']
|
||||||
# Assert that query_string is a byte string and still url encoded
|
# Assert that query_string is a byte string and still url encoded
|
||||||
self.assertIsInstance(query_string, six.binary_type)
|
self.assertIsInstance(query_string, six.binary_type)
|
||||||
|
@ -112,17 +127,11 @@ class ASGITestCase(unittest.TestCase):
|
||||||
|
|
||||||
client = channel_message.get('client')
|
client = channel_message.get('client')
|
||||||
if client is not None:
|
if client is not None:
|
||||||
client_host, client_port = client
|
self.assert_valid_address_and_port(channel_message['client'])
|
||||||
self.assertIsInstance(client_host, six.text_type)
|
|
||||||
self.assert_is_ip_address(client_host)
|
|
||||||
self.assertIsInstance(client_port, int)
|
|
||||||
|
|
||||||
server = channel_message.get('server')
|
server = channel_message.get('server')
|
||||||
if server is not None:
|
if server is not None:
|
||||||
server_host, server_port = channel_message['server']
|
self.assert_valid_address_and_port(channel_message['server'])
|
||||||
self.assertIsInstance(server_host, six.text_type)
|
|
||||||
self.assert_is_ip_address(server_host)
|
|
||||||
self.assertIsInstance(server_port, int)
|
|
||||||
|
|
||||||
def assert_valid_http_response_message(self, message, response):
|
def assert_valid_http_response_message(self, message, response):
|
||||||
self.assertTrue(message)
|
self.assertTrue(message)
|
||||||
|
@ -147,3 +156,87 @@ class ASGITestCase(unittest.TestCase):
|
||||||
# altered casing. The approach below does this well enough.
|
# altered casing. The approach below does this well enough.
|
||||||
self.assertIn(expected_header.lower(), response.lower())
|
self.assertIn(expected_header.lower(), response.lower())
|
||||||
self.assertIn(value.encode('ascii'), response)
|
self.assertIn(value.encode('ascii'), response)
|
||||||
|
|
||||||
|
|
||||||
|
class ASGIWebSocketTestCase(ASGITestCaseBase):
|
||||||
|
"""
|
||||||
|
Test case with helpers for verifying WebSocket channel messages
|
||||||
|
"""
|
||||||
|
|
||||||
|
def assert_websocket_upgrade(self, response, body=b'', expect_close=False):
|
||||||
|
self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
|
||||||
|
self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
|
||||||
|
self.assertIn(body, response)
|
||||||
|
self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
|
||||||
|
|
||||||
|
def assert_websocket_denied(self, response):
|
||||||
|
self.assertIn(b'HTTP/1.1 403', response)
|
||||||
|
|
||||||
|
def assert_valid_websocket_connect_message(
|
||||||
|
self, channel_message, request_path='/', request_params=None, request_headers=None):
|
||||||
|
"""
|
||||||
|
Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertTrue(channel_message)
|
||||||
|
|
||||||
|
self.assert_presence_of_message_keys(
|
||||||
|
channel_message.keys(),
|
||||||
|
{'reply_channel', 'path', 'headers', 'order'},
|
||||||
|
{'scheme', 'query_string', 'root_path', 'client', 'server'})
|
||||||
|
|
||||||
|
# == Assertions about required channel_message fields ==
|
||||||
|
self.assert_valid_reply_channel(channel_message['reply_channel'])
|
||||||
|
self.assert_valid_path(channel_message['path'], request_path)
|
||||||
|
|
||||||
|
order = channel_message['order']
|
||||||
|
self.assertIsInstance(order, int)
|
||||||
|
self.assertEqual(order, 0)
|
||||||
|
|
||||||
|
# Ordering of header names is not important, but the order of values for a header
|
||||||
|
# name is. To assert whether that order is kept, we transform the request
|
||||||
|
# headers and the channel message headers into a set
|
||||||
|
# {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
|
||||||
|
# Note that unlike for HTTP, Daphne never gives out individual header values; instead we
|
||||||
|
# get one string per header field with values separated by comma.
|
||||||
|
transformed_request_headers = defaultdict(list)
|
||||||
|
for name, value in (request_headers or []):
|
||||||
|
expected_name = name.lower().strip().encode('ascii')
|
||||||
|
expected_value = value.strip().encode('ascii')
|
||||||
|
transformed_request_headers[expected_name].append(expected_value)
|
||||||
|
final_request_headers = {
|
||||||
|
(name, b','.join(value)) for name, value in transformed_request_headers.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Websockets carry a lot of additional header fields, so instead of verifying that
|
||||||
|
# headers look exactly like expected, we just check that the expected header fields
|
||||||
|
# and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
|
||||||
|
# and not tested for.
|
||||||
|
assert final_request_headers.issubset(set(channel_message['headers']))
|
||||||
|
|
||||||
|
# == Assertions about optional channel_message fields ==
|
||||||
|
scheme = channel_message.get('scheme')
|
||||||
|
if scheme:
|
||||||
|
self.assertIsInstance(scheme, six.text_type)
|
||||||
|
self.assertIn(scheme, ['ws', 'wss'])
|
||||||
|
|
||||||
|
query_string = channel_message.get('query_string')
|
||||||
|
if query_string:
|
||||||
|
# Assert that query_string is a byte string and still url encoded
|
||||||
|
# TODO: It's neither a byte string nor urlencoded
|
||||||
|
# Will fail until https://github.com/django/daphne/issues/110 is resolved.
|
||||||
|
#self.assertIsInstance(query_string, six.binary_type)
|
||||||
|
#self.assertEqual(query_string, parse.urlencode(request_params or []).encode('ascii'))
|
||||||
|
pass
|
||||||
|
|
||||||
|
root_path = channel_message.get('root_path')
|
||||||
|
if root_path is not None:
|
||||||
|
self.assertIsInstance(root_path, six.text_type)
|
||||||
|
|
||||||
|
client = channel_message.get('client')
|
||||||
|
if client is not None:
|
||||||
|
self.assert_valid_address_and_port(channel_message['client'])
|
||||||
|
|
||||||
|
server = channel_message.get('server')
|
||||||
|
if server is not None:
|
||||||
|
self.assert_valid_address_and_port(channel_message['server'])
|
||||||
|
|
|
@ -50,9 +50,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
# Tell main factory about it
|
# Tell main factory about it
|
||||||
self.main_factory.reply_protocols[self.reply_channel] = self
|
self.main_factory.reply_protocols[self.reply_channel] = self
|
||||||
# Get client address if possible
|
# Get client address if possible
|
||||||
if hasattr(self.transport.getPeer(), "host") and hasattr(self.transport.getPeer(), "port"):
|
peer = self.transport.getPeer()
|
||||||
self.client_addr = [self.transport.getPeer().host, self.transport.getPeer().port]
|
host = self.transport.getHost()
|
||||||
self.server_addr = [self.transport.getHost().host, self.transport.getHost().port]
|
if hasattr(peer, "host") and hasattr(peer, "port"):
|
||||||
|
self.client_addr = [six.text_type(peer.host), peer.port]
|
||||||
|
self.server_addr = [six.text_type(host.host), host.port]
|
||||||
else:
|
else:
|
||||||
self.client_addr = None
|
self.client_addr = None
|
||||||
self.server_addr = None
|
self.server_addr = None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user