From 231bfb7c4ebf4e68f3ea08f9fba27e9d47dc1ea6 Mon Sep 17 00:00:00 2001 From: niekas Date: Mon, 14 Sep 2020 13:33:13 +0300 Subject: [PATCH] Allow to accept websocket extensions Also accept `permessage-deflate`, `permessage-bzip2` and `permessage-snappy` compression extensions by default if client requests for them. Compression/decompression of the messages is taken care of by `autobahn` package. --- daphne/server.py | 25 +++++++++++++++++++ daphne/ws_protocol.py | 13 +++++++--- tests/test_websocket.py | 54 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 3 deletions(-) diff --git a/daphne/server.py b/daphne/server.py index 5ede808..7b65079 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -23,6 +23,7 @@ import logging import time from concurrent.futures import CancelledError +from autobahn.websocket.compress import PERMESSAGE_COMPRESSION_EXTENSION as EXTENSIONS from twisted.internet import defer, reactor from twisted.internet.endpoints import serverFromString from twisted.logger import STDLibLogObserver, globalLogBeginner @@ -44,6 +45,11 @@ class Server(object): http_timeout=None, websocket_timeout=86400, websocket_connect_timeout=20, + websocket_permessage_compression_extensions=[ + "permessage-deflate", + "permessage-bzip2", + "permessage-snappy", + ], ping_interval=20, ping_timeout=30, root_path="", @@ -73,6 +79,9 @@ class Server(object): self.websocket_timeout = websocket_timeout self.websocket_connect_timeout = websocket_connect_timeout self.websocket_handshake_timeout = websocket_handshake_timeout + self.websocket_permessage_compression_extensions = ( + websocket_permessage_compression_extensions + ) self.application_close_timeout = application_close_timeout self.root_path = root_path self.verbosity = verbosity @@ -94,6 +103,7 @@ class Server(object): autoPingTimeout=self.ping_timeout, allowNullOrigin=True, openHandshakeTimeout=self.websocket_handshake_timeout, + perMessageCompressionAccept=self.accept_permessage_compression_extension, ) if self.verbosity <= 1: # Redirect the Twisted log to nowhere @@ -246,6 +256,21 @@ class Server(object): ) ) + def accept_permessage_compression_extension(self, offers): + """ + Accepts websocket compression extension as required by `autobahn` package. + """ + for offer in offers: + for ext in self.websocket_permessage_compression_extensions: + if ext in EXTENSIONS and isinstance(offer, EXTENSIONS[ext]["Offer"]): + return EXTENSIONS[ext]["OfferAccept"](offer) + elif ext not in EXTENSIONS: + logger.warning( + "Compression extension %s could not be accepted. " + "It is not supported or a dependency is missing.", + ext, + ) + ### Utility def application_checker(self): diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 1962450..47c19e6 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -182,7 +182,10 @@ class WebSocketProtocol(WebSocketServerProtocol): if "type" not in message: raise ValueError("Message has no type defined") if message["type"] == "websocket.accept": - self.serverAccept(message.get("subprotocol", None)) + self.serverAccept( + message.get("subprotocol", None), message.get("headers", None) + ) + elif message["type"] == "websocket.close": if self.state == self.STATE_CONNECTING: self.serverReject() @@ -214,11 +217,15 @@ class WebSocketProtocol(WebSocketServerProtocol): else: self.sendCloseFrame(code=1011) - def serverAccept(self, subprotocol=None): + def serverAccept(self, subprotocol=None, headers=None): """ Called when we get a message saying to accept the connection. """ - self.handshake_deferred.callback(subprotocol) + if headers is None: + self.handshake_deferred.callback(subprotocol) + else: + headers_dict = {key.decode(): value.decode() for key, value in headers} + self.handshake_deferred.callback((subprotocol, headers_dict)) del self.handshake_deferred logger.debug("WebSocket %s accepted by application", self.client_addr) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 9ec2c0d..637296c 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -132,6 +132,60 @@ class TestWebsocket(DaphneTestCase): self.assert_valid_websocket_scope(scope, subprotocols=subprotocols) self.assert_valid_websocket_connect_message(messages[0]) + def test_accept_permessage_deflate_extension(self): + """ + Tests that permessage-deflate extension is successfuly accepted + by underlying `autobahn` package. + """ + + headers = [ + ( + b"Sec-WebSocket-Extensions", + b"permessage-deflate; client_max_window_bits", + ), + ] + + with DaphneTestingInstance() as test_app: + test_app.add_send_messages( + [ + { + "type": "websocket.accept", + } + ] + ) + + sock, subprotocol = self.websocket_handshake( + test_app, + headers=headers, + ) + # Validate the scope and messages we got + scope, messages = test_app.get_received() + self.assert_valid_websocket_connect_message(messages[0]) + + def test_accept_custom_extension(self): + """ + Tests that custom headers can be accpeted during handshake. + """ + with DaphneTestingInstance() as test_app: + test_app.add_send_messages( + [ + { + "type": "websocket.accept", + "headers": [(b"Sec-WebSocket-Extensions", b"custom-extension")], + } + ] + ) + + sock, subprotocol = self.websocket_handshake( + test_app, + headers=[ + (b"Sec-WebSocket-Extensions", b"custom-extension"), + ], + ) + # Validate the scope and messages we got + scope, messages = test_app.get_received() + self.assert_valid_websocket_connect_message(messages[0]) + def test_xff(self): """ Tests that X-Forwarded-For headers get parsed right