mirror of
				https://github.com/django/daphne.git
				synced 2025-11-04 01:27:33 +03:00 
			
		
		
		
	Websockets test and unicode fix for Python 2 (#111)
* Python 2 fix for host address
This is a copy of
57051a48cd
for the Websocket protocol.
In Python 2, Twisted returns a byte string for the host address, while
the spec requires a unicode string. A simple cast gives us consistency.
* Test suite for websocket tests
This commit
* introduces some new helpers to test the Websocket protocol
* renames the old ASGITestCase class to ASGIHTTPTestCase, and
  introduces a test case for testing Websockets
* moves some helper methods that are shared between HTTP and Websockets
  into a mutual base class
* uses the new helpers to simplfiy the existing tests
* and adds a couple new tests.
			
			
This commit is contained in:
		
							parent
							
								
									bd03fabce6
								
							
						
					
					
						commit
						2bcec3fe94
					
				| 
						 | 
					@ -84,6 +84,19 @@ def _build_request(method, path, params=None, headers=None, body=None):
 | 
				
			||||||
    return request
 | 
					    return request
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def build_websocket_upgrade(path, params, headers):
 | 
				
			||||||
 | 
					    ws_headers = [
 | 
				
			||||||
 | 
					        ('Host', 'somewhere.com'),
 | 
				
			||||||
 | 
					        ('Upgrade', 'websocket'),
 | 
				
			||||||
 | 
					        ('Connection', 'Upgrade'),
 | 
				
			||||||
 | 
					        ('Sec-WebSocket-Key', 'x3JJHMbDL1EzLkh9GBhXDw=='),
 | 
				
			||||||
 | 
					        ('Sec-WebSocket-Protocol', 'chat, superchat'),
 | 
				
			||||||
 | 
					        ('Sec-WebSocket-Version', '13'),
 | 
				
			||||||
 | 
					        ('Origin', 'http://example.com')
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    return _build_request('GET', path, params, headers=headers + ws_headers, body=None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def header_line(name, value):
 | 
					def header_line(name, value):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Given a header name and value, returns the line to use in a HTTP request or response.
 | 
					    Given a header name and value, returns the line to use in a HTTP request or response.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,12 +17,17 @@ def http_method():
 | 
				
			||||||
    return strategies.sampled_from(HTTP_METHODS)
 | 
					    return strategies.sampled_from(HTTP_METHODS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _http_path_portion():
 | 
				
			||||||
 | 
					    alphabet = string.ascii_letters + string.digits + '-._~'
 | 
				
			||||||
 | 
					    return strategies.text(min_size=1, average_size=10, max_size=128, alphabet=alphabet)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def http_path():
 | 
					def http_path():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Returns a URL path (not encoded).
 | 
					    Returns a URL path (not encoded).
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    alphabet = string.ascii_letters + string.digits + '-._~/'
 | 
					    return strategies.lists(
 | 
				
			||||||
    return strategies.text(min_size=0, max_size=255, alphabet=alphabet).map(lambda s: '/' + s)
 | 
					        _http_path_portion(), min_size=0, max_size=10).map(lambda s: '/' + '/'.join(s))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def http_body():
 | 
					def http_body():
 | 
				
			||||||
| 
						 | 
					@ -33,6 +38,10 @@ def http_body():
 | 
				
			||||||
    return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
 | 
					    return strategies.text(alphabet=string.printable, min_size=0, average_size=600, max_size=1500)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def binary_payload():
 | 
				
			||||||
 | 
					    return strategies.binary(min_size=0, average_size=600, max_size=1500)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def valid_bidi(value):
 | 
					def valid_bidi(value):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Rejects strings which nonsensical Unicode text direction flags.
 | 
					    Rejects strings which nonsensical Unicode text direction flags.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,7 +16,7 @@ from daphne.tests import testcases, http_strategies
 | 
				
			||||||
from daphne.tests.factories import message_for_request, content_length_header
 | 
					from daphne.tests.factories import message_for_request, content_length_header
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestHTTPRequestSpec(testcases.ASGITestCase):
 | 
					class TestHTTPRequestSpec(testcases.ASGIHTTPTestCase):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Tests which try to pour the HTTP request section of the ASGI spec into code.
 | 
					    Tests which try to pour the HTTP request section of the ASGI spec into code.
 | 
				
			||||||
    The heavy lifting is done by the assert_valid_http_request_message function,
 | 
					    The heavy lifting is done by the assert_valid_http_request_message function,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,7 +14,7 @@ from daphne.http_protocol import HTTPFactory
 | 
				
			||||||
from . import factories, http_strategies, testcases
 | 
					from . import factories, http_strategies, testcases
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestHTTPResponseSpec(testcases.ASGITestCase):
 | 
					class TestHTTPResponseSpec(testcases.ASGIHTTPTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_minimal_response(self):
 | 
					    def test_minimal_response(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,27 +1,171 @@
 | 
				
			||||||
# coding: utf8
 | 
					# coding: utf8
 | 
				
			||||||
from __future__ import unicode_literals
 | 
					from __future__ import unicode_literals
 | 
				
			||||||
from unittest import TestCase
 | 
					
 | 
				
			||||||
from asgiref.inmemory import ChannelLayer
 | 
					from hypothesis import assume, given, strategies
 | 
				
			||||||
from twisted.test import proto_helpers
 | 
					from twisted.test import proto_helpers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from asgiref.inmemory import ChannelLayer
 | 
				
			||||||
from daphne.http_protocol import HTTPFactory
 | 
					from daphne.http_protocol import HTTPFactory
 | 
				
			||||||
 | 
					from daphne.tests import http_strategies, testcases, factories
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestWebSocketProtocol(TestCase):
 | 
					class WebSocketConnection(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Tests that the WebSocket protocol class correcly generates and parses messages.
 | 
					    Helper class that makes it easier to test Dahpne's WebSocket support.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self):
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.last_message = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.channel_layer = ChannelLayer()
 | 
					        self.channel_layer = ChannelLayer()
 | 
				
			||||||
        self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
 | 
					        self.factory = HTTPFactory(self.channel_layer, send_channel="test!")
 | 
				
			||||||
        self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
 | 
					        self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
 | 
				
			||||||
        self.tr = proto_helpers.StringTransport()
 | 
					        self.transport = proto_helpers.StringTransport()
 | 
				
			||||||
        self.proto.makeConnection(self.tr)
 | 
					        self.proto.makeConnection(self.transport)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def receive(self, request):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Low-level method to let Daphne handle HTTP/WebSocket data
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.proto.dataReceived(request)
 | 
				
			||||||
 | 
					        _, self.last_message = self.channel_layer.receive(['websocket.connect'])
 | 
				
			||||||
 | 
					        return self.last_message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send(self, content):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Method to respond with a channel message
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.last_message is None:
 | 
				
			||||||
 | 
					            # Auto-connect for convenience.
 | 
				
			||||||
 | 
					            self.connect()
 | 
				
			||||||
 | 
					        self.factory.dispatch_reply(self.last_message['reply_channel'], content)
 | 
				
			||||||
 | 
					        response = self.transport.value()
 | 
				
			||||||
 | 
					        self.transport.clear()
 | 
				
			||||||
 | 
					        return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connect(self, path='/', params=None, headers=None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        High-level method to perform the WebSocket handshake
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        request = factories.build_websocket_upgrade(path, params, headers or [])
 | 
				
			||||||
 | 
					        message = self.receive(request)
 | 
				
			||||||
 | 
					        return message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestHandshake(testcases.ASGIWebSocketTestCase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Tests for the WebSocket handshake
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_minimal(self):
 | 
				
			||||||
 | 
					        message = WebSocketConnection().connect()
 | 
				
			||||||
 | 
					        self.assert_valid_websocket_connect_message(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(
 | 
				
			||||||
 | 
					        path=http_strategies.http_path(),
 | 
				
			||||||
 | 
					        params=http_strategies.query_params(),
 | 
				
			||||||
 | 
					        headers=http_strategies.headers(),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_connection(self, path, params, headers):
 | 
				
			||||||
 | 
					        message = WebSocketConnection().connect(path, params, headers)
 | 
				
			||||||
 | 
					        self.assert_valid_websocket_connect_message(message, path, params, headers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestSendCloseAccept(testcases.ASGIWebSocketTestCase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Tests that, essentially, try to translate the send/close/accept section of the spec into code.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_empty_accept(self):
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'accept': True})
 | 
				
			||||||
 | 
					        self.assert_websocket_upgrade(response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(text=http_strategies.http_body())
 | 
				
			||||||
 | 
					    def test_accept_and_text(self, text):
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'accept': True, 'text': text})
 | 
				
			||||||
 | 
					        self.assert_websocket_upgrade(response, text.encode('ascii'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(data=http_strategies.binary_payload())
 | 
				
			||||||
 | 
					    def test_accept_and_bytes(self, data):
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'accept': True, 'bytes': data})
 | 
				
			||||||
 | 
					        self.assert_websocket_upgrade(response, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_accept_false(self):
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'accept': False})
 | 
				
			||||||
 | 
					        self.assert_websocket_denied(response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_accept_false_with_text(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Tests that even if text is given, the connection is denied.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        We can't easily use Hypothesis to generate data for this test because it's
 | 
				
			||||||
 | 
					        hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        text = 'foobar'
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'accept': False, 'text': text})
 | 
				
			||||||
 | 
					        self.assert_websocket_denied(response)
 | 
				
			||||||
 | 
					        self.assertNotIn(text.encode('ascii'), response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_accept_false_with_bytes(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Tests that even if data is given, the connection is denied.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        We can't easily use Hypothesis to generate data for this test because it's
 | 
				
			||||||
 | 
					        hard to detect absence of the body if e.g. Hypothesis would generate a 'GET'
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        data = b'foobar'
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'accept': False, 'bytes': data})
 | 
				
			||||||
 | 
					        self.assert_websocket_denied(response)
 | 
				
			||||||
 | 
					        self.assertNotIn(data, response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(text=http_strategies.http_body())
 | 
				
			||||||
 | 
					    def test_just_text(self, text):
 | 
				
			||||||
 | 
					        assume(len(text) > 0)
 | 
				
			||||||
 | 
					        # If content is sent, accept=True is implied.
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'text': text})
 | 
				
			||||||
 | 
					        self.assert_websocket_upgrade(response, text.encode('ascii'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(data=http_strategies.binary_payload())
 | 
				
			||||||
 | 
					    def test_just_bytes(self, data):
 | 
				
			||||||
 | 
					        assume(len(data) > 0)
 | 
				
			||||||
 | 
					        # If content is sent, accept=True is implied.
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'bytes': data})
 | 
				
			||||||
 | 
					        self.assert_websocket_upgrade(response, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_close_boolean(self):
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'close': True})
 | 
				
			||||||
 | 
					        self.assert_websocket_denied(response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(number=strategies.integers(min_value=1))
 | 
				
			||||||
 | 
					    def test_close_integer(self, number):
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'close': number})
 | 
				
			||||||
 | 
					        self.assert_websocket_denied(response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(text=http_strategies.http_body())
 | 
				
			||||||
 | 
					    def test_close_with_text(self, text):
 | 
				
			||||||
 | 
					        assume(len(text) > 0)
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'close': True, 'text': text})
 | 
				
			||||||
 | 
					        self.assert_websocket_upgrade(response, text.encode('ascii'), expect_close=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @given(data=http_strategies.binary_payload())
 | 
				
			||||||
 | 
					    def test_close_with_data(self, data):
 | 
				
			||||||
 | 
					        assume(len(data) > 0)
 | 
				
			||||||
 | 
					        response = WebSocketConnection().send({'close': True, 'bytes': data})
 | 
				
			||||||
 | 
					        self.assert_websocket_upgrade(response, data, expect_close=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestWebSocketProtocol(testcases.ASGIWebSocketTestCase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Tests that the WebSocket protocol class correctly generates and parses messages.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.connection = WebSocketConnection()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_basic(self):
 | 
					    def test_basic(self):
 | 
				
			||||||
        # Send a simple request to the protocol
 | 
					        # Send a simple request to the protocol and get the resulting message off
 | 
				
			||||||
        self.proto.dataReceived(
 | 
					        # of the channel layer.
 | 
				
			||||||
 | 
					        message = self.connection.receive(
 | 
				
			||||||
            b"GET /chat HTTP/1.1\r\n"
 | 
					            b"GET /chat HTTP/1.1\r\n"
 | 
				
			||||||
            b"Host: somewhere.com\r\n"
 | 
					            b"Host: somewhere.com\r\n"
 | 
				
			||||||
            b"Upgrade: websocket\r\n"
 | 
					            b"Upgrade: websocket\r\n"
 | 
				
			||||||
| 
						 | 
					@ -32,8 +176,6 @@ class TestWebSocketProtocol(TestCase):
 | 
				
			||||||
            b"Origin: http://example.com\r\n"
 | 
					            b"Origin: http://example.com\r\n"
 | 
				
			||||||
            b"\r\n"
 | 
					            b"\r\n"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        # Get the resulting message off of the channel layer
 | 
					 | 
				
			||||||
        _, message = self.channel_layer.receive(["websocket.connect"])
 | 
					 | 
				
			||||||
        self.assertEqual(message['path'], "/chat")
 | 
					        self.assertEqual(message['path'], "/chat")
 | 
				
			||||||
        self.assertEqual(message['query_string'], "")
 | 
					        self.assertEqual(message['query_string'], "")
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
| 
						 | 
					@ -46,53 +188,26 @@ class TestWebSocketProtocol(TestCase):
 | 
				
			||||||
             (b'sec-websocket-version', b'13'),
 | 
					             (b'sec-websocket-version', b'13'),
 | 
				
			||||||
             (b'upgrade', b'websocket')]
 | 
					             (b'upgrade', b'websocket')]
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertTrue(message['reply_channel'].startswith("test!"))
 | 
					        self.assert_valid_websocket_connect_message(message, '/chat')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Accept the connection
 | 
					        # Accept the connection
 | 
				
			||||||
        self.factory.dispatch_reply(
 | 
					        response = self.connection.send({'accept': True})
 | 
				
			||||||
            message['reply_channel'],
 | 
					        self.assert_websocket_upgrade(response)
 | 
				
			||||||
            {'accept': True}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Make sure that we get a 101 Switching Protocols back
 | 
					 | 
				
			||||||
        response = self.tr.value()
 | 
					 | 
				
			||||||
        self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
 | 
					 | 
				
			||||||
        self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
 | 
					 | 
				
			||||||
        self.tr.clear()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Send some text
 | 
					        # Send some text
 | 
				
			||||||
        self.factory.dispatch_reply(
 | 
					        response = self.connection.send({'text': "Hello World!"})
 | 
				
			||||||
            message['reply_channel'],
 | 
					 | 
				
			||||||
            {'text': "Hello World!"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = self.tr.value()
 | 
					 | 
				
			||||||
        self.assertEqual(response, b"\x81\x0cHello World!")
 | 
					        self.assertEqual(response, b"\x81\x0cHello World!")
 | 
				
			||||||
        self.tr.clear()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Send some bytes
 | 
					        # Send some bytes
 | 
				
			||||||
        self.factory.dispatch_reply(
 | 
					        response = self.connection.send({'bytes': b"\xaa\xbb\xcc\xdd"})
 | 
				
			||||||
            message['reply_channel'],
 | 
					 | 
				
			||||||
            {'bytes': b"\xaa\xbb\xcc\xdd"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = self.tr.value()
 | 
					 | 
				
			||||||
        self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
 | 
					        self.assertEqual(response, b"\x82\x04\xaa\xbb\xcc\xdd")
 | 
				
			||||||
        self.tr.clear()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Close the connection
 | 
					        # Close the connection
 | 
				
			||||||
        self.factory.dispatch_reply(
 | 
					        response = self.connection.send({'close': True})
 | 
				
			||||||
            message['reply_channel'],
 | 
					 | 
				
			||||||
            {'close': True}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = self.tr.value()
 | 
					 | 
				
			||||||
        self.assertEqual(response, b"\x88\x02\x03\xe8")
 | 
					        self.assertEqual(response, b"\x88\x02\x03\xe8")
 | 
				
			||||||
        self.tr.clear()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_connection_with_file_origin_is_accepted(self):
 | 
					    def test_connection_with_file_origin_is_accepted(self):
 | 
				
			||||||
        # Send a simple request to the protocol
 | 
					        message = self.connection.receive(
 | 
				
			||||||
        self.proto.dataReceived(
 | 
					 | 
				
			||||||
            b"GET /chat HTTP/1.1\r\n"
 | 
					            b"GET /chat HTTP/1.1\r\n"
 | 
				
			||||||
            b"Host: somewhere.com\r\n"
 | 
					            b"Host: somewhere.com\r\n"
 | 
				
			||||||
            b"Upgrade: websocket\r\n"
 | 
					            b"Upgrade: websocket\r\n"
 | 
				
			||||||
| 
						 | 
					@ -103,26 +218,15 @@ class TestWebSocketProtocol(TestCase):
 | 
				
			||||||
            b"Origin: file://\r\n"
 | 
					            b"Origin: file://\r\n"
 | 
				
			||||||
            b"\r\n"
 | 
					            b"\r\n"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Get the resulting message off of the channel layer
 | 
					 | 
				
			||||||
        _, message = self.channel_layer.receive(["websocket.connect"])
 | 
					 | 
				
			||||||
        self.assertIn((b'origin', b'file://'), message['headers'])
 | 
					        self.assertIn((b'origin', b'file://'), message['headers'])
 | 
				
			||||||
        self.assertTrue(message['reply_channel'].startswith("test!"))
 | 
					        self.assert_valid_websocket_connect_message(message, '/chat')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Accept the connection
 | 
					        # Accept the connection
 | 
				
			||||||
        self.factory.dispatch_reply(
 | 
					        response = self.connection.send({'accept': True})
 | 
				
			||||||
            message['reply_channel'],
 | 
					        self.assert_websocket_upgrade(response)
 | 
				
			||||||
            {'accept': True}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Make sure that we get a 101 Switching Protocols back
 | 
					 | 
				
			||||||
        response = self.tr.value()
 | 
					 | 
				
			||||||
        self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
 | 
					 | 
				
			||||||
        self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_connection_with_no_origin_is_accepted(self):
 | 
					    def test_connection_with_no_origin_is_accepted(self):
 | 
				
			||||||
        # Send a simple request to the protocol
 | 
					        message = self.connection.receive(
 | 
				
			||||||
        self.proto.dataReceived(
 | 
					 | 
				
			||||||
            b"GET /chat HTTP/1.1\r\n"
 | 
					            b"GET /chat HTTP/1.1\r\n"
 | 
				
			||||||
            b"Host: somewhere.com\r\n"
 | 
					            b"Host: somewhere.com\r\n"
 | 
				
			||||||
            b"Upgrade: websocket\r\n"
 | 
					            b"Upgrade: websocket\r\n"
 | 
				
			||||||
| 
						 | 
					@ -133,18 +237,9 @@ class TestWebSocketProtocol(TestCase):
 | 
				
			||||||
            b"\r\n"
 | 
					            b"\r\n"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Get the resulting message off of the channel layer
 | 
					 | 
				
			||||||
        _, message = self.channel_layer.receive(["websocket.connect"])
 | 
					 | 
				
			||||||
        self.assertNotIn(b'origin', [header_tuple[0] for header_tuple in message['headers']])
 | 
					        self.assertNotIn(b'origin', [header_tuple[0] for header_tuple in message['headers']])
 | 
				
			||||||
        self.assertTrue(message['reply_channel'].startswith("test!"))
 | 
					        self.assert_valid_websocket_connect_message(message, '/chat')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Accept the connection
 | 
					        # Accept the connection
 | 
				
			||||||
        self.factory.dispatch_reply(
 | 
					        response = self.connection.send({'accept': True})
 | 
				
			||||||
            message['reply_channel'],
 | 
					        self.assert_websocket_upgrade(response)
 | 
				
			||||||
            {'accept': True}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Make sure that we get a 101 Switching Protocols back
 | 
					 | 
				
			||||||
        response = self.tr.value()
 | 
					 | 
				
			||||||
        self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response)
 | 
					 | 
				
			||||||
        self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,9 +12,9 @@ import unittest
 | 
				
			||||||
from . import factories
 | 
					from . import factories
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ASGITestCase(unittest.TestCase):
 | 
					class ASGITestCaseBase(unittest.TestCase):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Test case with helpers for ASGI message verification
 | 
					    Base class for our test classes which contains shared method.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def assert_is_ip_address(self, address):
 | 
					    def assert_is_ip_address(self, address):
 | 
				
			||||||
| 
						 | 
					@ -26,6 +26,35 @@ class ASGITestCase(unittest.TestCase):
 | 
				
			||||||
        except socket.error:
 | 
					        except socket.error:
 | 
				
			||||||
            self.fail("'%s' is not a valid IP address." % address)
 | 
					            self.fail("'%s' is not a valid IP address." % address)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_presence_of_message_keys(self, keys, required_keys, optional_keys):
 | 
				
			||||||
 | 
					        present_keys = set(keys)
 | 
				
			||||||
 | 
					        self.assertTrue(required_keys <= present_keys)
 | 
				
			||||||
 | 
					        # Assert that no other keys are present
 | 
				
			||||||
 | 
					        self.assertEqual(set(), present_keys - required_keys - optional_keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_reply_channel(self, reply_channel):
 | 
				
			||||||
 | 
					        self.assertIsInstance(reply_channel, six.text_type)
 | 
				
			||||||
 | 
					        # The reply channel is decided by the server.
 | 
				
			||||||
 | 
					        self.assertTrue(reply_channel.startswith('test!'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_path(self, path, request_path):
 | 
				
			||||||
 | 
					        self.assertIsInstance(path, six.text_type)
 | 
				
			||||||
 | 
					        self.assertEqual(path, request_path)
 | 
				
			||||||
 | 
					        # Assert that it's already url decoded
 | 
				
			||||||
 | 
					        self.assertEqual(path, parse.unquote(path))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_address_and_port(self, host):
 | 
				
			||||||
 | 
					        address, port = host
 | 
				
			||||||
 | 
					        self.assertIsInstance(address, six.text_type)
 | 
				
			||||||
 | 
					        self.assert_is_ip_address(address)
 | 
				
			||||||
 | 
					        self.assertIsInstance(port, int)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ASGIHTTPTestCase(ASGITestCaseBase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Test case with helpers for verifying HTTP channel messages
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def assert_valid_http_request_message(
 | 
					    def assert_valid_http_request_message(
 | 
				
			||||||
            self, channel_message, request_method, request_path,
 | 
					            self, channel_message, request_method, request_path,
 | 
				
			||||||
            request_params=None, request_headers=None, request_body=None):
 | 
					            request_params=None, request_headers=None, request_body=None):
 | 
				
			||||||
| 
						 | 
					@ -35,22 +64,14 @@ class ASGITestCase(unittest.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertTrue(channel_message)
 | 
					        self.assertTrue(channel_message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # == General assertions about expected dictionary keys being present ==
 | 
					        self.assert_presence_of_message_keys(
 | 
				
			||||||
        message_keys = set(channel_message.keys())
 | 
					            channel_message.keys(),
 | 
				
			||||||
        required_message_keys = {
 | 
					            {'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers'},
 | 
				
			||||||
            'reply_channel', 'http_version', 'method', 'path', 'query_string', 'headers',
 | 
					            {'scheme', 'root_path', 'body', 'body_channel', 'client', 'server'})
 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        optional_message_keys = {
 | 
					 | 
				
			||||||
            'scheme', 'root_path', 'body', 'body_channel', 'client', 'server'
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        self.assertTrue(required_message_keys <= message_keys)
 | 
					 | 
				
			||||||
        # Assert that no other keys are present
 | 
					 | 
				
			||||||
        self.assertEqual(set(), message_keys - required_message_keys - optional_message_keys)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # == Assertions about required channel_message fields ==
 | 
					        # == Assertions about required channel_message fields ==
 | 
				
			||||||
        reply_channel = channel_message['reply_channel']
 | 
					        self.assert_valid_reply_channel(channel_message['reply_channel'])
 | 
				
			||||||
        self.assertIsInstance(reply_channel, six.text_type)
 | 
					        self.assert_valid_path(channel_message['path'], request_path)
 | 
				
			||||||
        self.assertTrue(reply_channel.startswith('test!'))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        http_version = channel_message['http_version']
 | 
					        http_version = channel_message['http_version']
 | 
				
			||||||
        self.assertIsInstance(http_version, six.text_type)
 | 
					        self.assertIsInstance(http_version, six.text_type)
 | 
				
			||||||
| 
						 | 
					@ -61,12 +82,6 @@ class ASGITestCase(unittest.TestCase):
 | 
				
			||||||
        self.assertTrue(method.isupper())
 | 
					        self.assertTrue(method.isupper())
 | 
				
			||||||
        self.assertEqual(channel_message['method'], request_method)
 | 
					        self.assertEqual(channel_message['method'], request_method)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        path = channel_message['path']
 | 
					 | 
				
			||||||
        self.assertIsInstance(path, six.text_type)
 | 
					 | 
				
			||||||
        self.assertEqual(path, request_path)
 | 
					 | 
				
			||||||
        # Assert that it's already url decoded
 | 
					 | 
				
			||||||
        self.assertEqual(path, parse.unquote(path))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        query_string = channel_message['query_string']
 | 
					        query_string = channel_message['query_string']
 | 
				
			||||||
        # Assert that query_string is a byte string and still url encoded
 | 
					        # Assert that query_string is a byte string and still url encoded
 | 
				
			||||||
        self.assertIsInstance(query_string, six.binary_type)
 | 
					        self.assertIsInstance(query_string, six.binary_type)
 | 
				
			||||||
| 
						 | 
					@ -112,17 +127,11 @@ class ASGITestCase(unittest.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        client = channel_message.get('client')
 | 
					        client = channel_message.get('client')
 | 
				
			||||||
        if client is not None:
 | 
					        if client is not None:
 | 
				
			||||||
            client_host, client_port = client
 | 
					            self.assert_valid_address_and_port(channel_message['client'])
 | 
				
			||||||
            self.assertIsInstance(client_host, six.text_type)
 | 
					 | 
				
			||||||
            self.assert_is_ip_address(client_host)
 | 
					 | 
				
			||||||
            self.assertIsInstance(client_port, int)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        server = channel_message.get('server')
 | 
					        server = channel_message.get('server')
 | 
				
			||||||
        if server is not None:
 | 
					        if server is not None:
 | 
				
			||||||
            server_host, server_port = channel_message['server']
 | 
					            self.assert_valid_address_and_port(channel_message['server'])
 | 
				
			||||||
            self.assertIsInstance(server_host, six.text_type)
 | 
					 | 
				
			||||||
            self.assert_is_ip_address(server_host)
 | 
					 | 
				
			||||||
            self.assertIsInstance(server_port, int)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def assert_valid_http_response_message(self, message, response):
 | 
					    def assert_valid_http_response_message(self, message, response):
 | 
				
			||||||
        self.assertTrue(message)
 | 
					        self.assertTrue(message)
 | 
				
			||||||
| 
						 | 
					@ -147,3 +156,87 @@ class ASGITestCase(unittest.TestCase):
 | 
				
			||||||
                # altered casing. The approach below does this well enough.
 | 
					                # altered casing. The approach below does this well enough.
 | 
				
			||||||
                self.assertIn(expected_header.lower(), response.lower())
 | 
					                self.assertIn(expected_header.lower(), response.lower())
 | 
				
			||||||
                self.assertIn(value.encode('ascii'), response)
 | 
					                self.assertIn(value.encode('ascii'), response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ASGIWebSocketTestCase(ASGITestCaseBase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Test case with helpers for verifying WebSocket channel messages
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_websocket_upgrade(self, response, body=b'', expect_close=False):
 | 
				
			||||||
 | 
					        self.assertIn(b"HTTP/1.1 101 Switching Protocols", response)
 | 
				
			||||||
 | 
					        self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)
 | 
				
			||||||
 | 
					        self.assertIn(body, response)
 | 
				
			||||||
 | 
					        self.assertEqual(expect_close, response.endswith(b"\x88\x02\x03\xe8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_websocket_denied(self, response):
 | 
				
			||||||
 | 
					        self.assertIn(b'HTTP/1.1 403', response)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_valid_websocket_connect_message(
 | 
				
			||||||
 | 
					            self, channel_message, request_path='/', request_params=None, request_headers=None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Asserts that a given channel message conforms to the HTTP request section of the ASGI spec.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertTrue(channel_message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assert_presence_of_message_keys(
 | 
				
			||||||
 | 
					            channel_message.keys(),
 | 
				
			||||||
 | 
					            {'reply_channel', 'path', 'headers', 'order'},
 | 
				
			||||||
 | 
					            {'scheme', 'query_string', 'root_path', 'client', 'server'})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # == Assertions about required channel_message fields ==
 | 
				
			||||||
 | 
					        self.assert_valid_reply_channel(channel_message['reply_channel'])
 | 
				
			||||||
 | 
					        self.assert_valid_path(channel_message['path'], request_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        order = channel_message['order']
 | 
				
			||||||
 | 
					        self.assertIsInstance(order, int)
 | 
				
			||||||
 | 
					        self.assertEqual(order, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Ordering of header names is not important, but the order of values for a header
 | 
				
			||||||
 | 
					        # name is. To assert whether that order is kept, we transform the request
 | 
				
			||||||
 | 
					        # headers and the channel message headers into a set
 | 
				
			||||||
 | 
					        # {('name1': 'value1,value2'), ('name2': 'value3')} and check if they're equal.
 | 
				
			||||||
 | 
					        # Note that unlike for HTTP, Daphne never gives out individual header values; instead we
 | 
				
			||||||
 | 
					        # get one string per header field with values separated by comma.
 | 
				
			||||||
 | 
					        transformed_request_headers = defaultdict(list)
 | 
				
			||||||
 | 
					        for name, value in (request_headers or []):
 | 
				
			||||||
 | 
					            expected_name = name.lower().strip().encode('ascii')
 | 
				
			||||||
 | 
					            expected_value = value.strip().encode('ascii')
 | 
				
			||||||
 | 
					            transformed_request_headers[expected_name].append(expected_value)
 | 
				
			||||||
 | 
					        final_request_headers = {
 | 
				
			||||||
 | 
					            (name, b','.join(value)) for name, value in transformed_request_headers.items()
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Websockets carry a lot of additional header fields, so instead of verifying that
 | 
				
			||||||
 | 
					        # headers look exactly like expected, we just check that the expected header fields
 | 
				
			||||||
 | 
					        # and values are present - additional header fields (e.g. Sec-WebSocket-Key) are allowed
 | 
				
			||||||
 | 
					        # and not tested for.
 | 
				
			||||||
 | 
					        assert final_request_headers.issubset(set(channel_message['headers']))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # == Assertions about optional channel_message fields ==
 | 
				
			||||||
 | 
					        scheme = channel_message.get('scheme')
 | 
				
			||||||
 | 
					        if scheme:
 | 
				
			||||||
 | 
					            self.assertIsInstance(scheme, six.text_type)
 | 
				
			||||||
 | 
					            self.assertIn(scheme, ['ws', 'wss'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query_string = channel_message.get('query_string')
 | 
				
			||||||
 | 
					        if query_string:
 | 
				
			||||||
 | 
					            # Assert that query_string is a byte string and still url encoded
 | 
				
			||||||
 | 
					            # TODO: It's neither a byte string nor urlencoded
 | 
				
			||||||
 | 
					            # Will fail until https://github.com/django/daphne/issues/110 is resolved.
 | 
				
			||||||
 | 
					            #self.assertIsInstance(query_string, six.binary_type)
 | 
				
			||||||
 | 
					            #self.assertEqual(query_string, parse.urlencode(request_params or []).encode('ascii'))
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        root_path = channel_message.get('root_path')
 | 
				
			||||||
 | 
					        if root_path is not None:
 | 
				
			||||||
 | 
					            self.assertIsInstance(root_path, six.text_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        client = channel_message.get('client')
 | 
				
			||||||
 | 
					        if client is not None:
 | 
				
			||||||
 | 
					            self.assert_valid_address_and_port(channel_message['client'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        server = channel_message.get('server')
 | 
				
			||||||
 | 
					        if server is not None:
 | 
				
			||||||
 | 
					            self.assert_valid_address_and_port(channel_message['server'])
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,9 +50,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
 | 
				
			||||||
            # Tell main factory about it
 | 
					            # Tell main factory about it
 | 
				
			||||||
            self.main_factory.reply_protocols[self.reply_channel] = self
 | 
					            self.main_factory.reply_protocols[self.reply_channel] = self
 | 
				
			||||||
            # Get client address if possible
 | 
					            # Get client address if possible
 | 
				
			||||||
            if hasattr(self.transport.getPeer(), "host") and hasattr(self.transport.getPeer(), "port"):
 | 
					            peer = self.transport.getPeer()
 | 
				
			||||||
                self.client_addr = [self.transport.getPeer().host, self.transport.getPeer().port]
 | 
					            host = self.transport.getHost()
 | 
				
			||||||
                self.server_addr = [self.transport.getHost().host, self.transport.getHost().port]
 | 
					            if hasattr(peer, "host") and hasattr(peer, "port"):
 | 
				
			||||||
 | 
					                self.client_addr = [six.text_type(peer.host), peer.port]
 | 
				
			||||||
 | 
					                self.server_addr = [six.text_type(host.host), host.port]
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                self.client_addr = None
 | 
					                self.client_addr = None
 | 
				
			||||||
                self.server_addr = None
 | 
					                self.server_addr = None
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user