diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index a30cbed..2bc1903 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -304,7 +304,10 @@ class HTTPFactory(http.HTTPFactory): self.reply_protocols = {} # Make a factory for WebSocket protocols self.ws_factory = WebSocketFactory(self, protocols=ws_protocols) - self.ws_factory.setProtocolOptions(autoPingTimeout=ping_timeout) + self.ws_factory.setProtocolOptions( + autoPingTimeout=ping_timeout, + allowNullOrigin=True, + ) self.ws_factory.protocol = WebSocketProtocol self.ws_factory.reply_protocols = self.reply_protocols self.root_path = root_path diff --git a/daphne/tests/test_ws.py b/daphne/tests/test_ws.py index aea2155..01e64d1 100644 --- a/daphne/tests/test_ws.py +++ b/daphne/tests/test_ws.py @@ -89,3 +89,62 @@ class TestWebSocketProtocol(TestCase): response = self.tr.value() self.assertEqual(response, b"\x88\x02\x03\xe8") self.tr.clear() + + def test_connection_with_file_origin_is_accepted(self): + # Send a simple request to the protocol + self.proto.dataReceived( + b"GET /chat HTTP/1.1\r\n" + b"Host: somewhere.com\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n" + b"Sec-WebSocket-Protocol: chat, superchat\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"Origin: file://\r\n" + b"\r\n" + ) + + # Get the resulting message off of the channel layer + _, message = self.channel_layer.receive_many(["websocket.connect"]) + self.assertIn((b'origin', b'file://'), message['headers']) + self.assertTrue(message['reply_channel'].startswith("websocket.send!")) + + # Accept the connection + self.factory.dispatch_reply( + message['reply_channel'], + {'accept': True} + ) + + # Make sure that we get a 101 Switching Protocols back + response = self.tr.value() + self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response) + self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response) + + def test_connection_with_no_origin_is_accepted(self): + # Send a simple request to the protocol + self.proto.dataReceived( + b"GET /chat HTTP/1.1\r\n" + b"Host: somewhere.com\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n" + b"Sec-WebSocket-Protocol: chat, superchat\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"\r\n" + ) + + # Get the resulting message off of the channel layer + _, message = self.channel_layer.receive_many(["websocket.connect"]) + self.assertNotIn(b'origin', [header_tuple[0] for header_tuple in message['headers']]) + self.assertTrue(message['reply_channel'].startswith("websocket.send!")) + + # Accept the connection + self.factory.dispatch_reply( + message['reply_channel'], + {'accept': True} + ) + + # Make sure that we get a 101 Switching Protocols back + response = self.tr.value() + self.assertIn(b"HTTP/1.1 101 Switching Protocols\r\n", response) + self.assertIn(b"Sec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n", response)