Allow to accept websocket extensions

Also accept `permessage-deflate`, `permessage-bzip2` and `permessage-snappy`
compression extensions by default if client requests for them.
Compression/decompression of the messages is taken care of by `autobahn`
package.
This commit is contained in:
niekas 2020-09-14 13:33:13 +03:00
parent 9838a173d7
commit 231bfb7c4e
3 changed files with 89 additions and 3 deletions

View File

@ -23,6 +23,7 @@ import logging
import time import time
from concurrent.futures import CancelledError from concurrent.futures import CancelledError
from autobahn.websocket.compress import PERMESSAGE_COMPRESSION_EXTENSION as EXTENSIONS
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner from twisted.logger import STDLibLogObserver, globalLogBeginner
@ -44,6 +45,11 @@ class Server(object):
http_timeout=None, http_timeout=None,
websocket_timeout=86400, websocket_timeout=86400,
websocket_connect_timeout=20, websocket_connect_timeout=20,
websocket_permessage_compression_extensions=[
"permessage-deflate",
"permessage-bzip2",
"permessage-snappy",
],
ping_interval=20, ping_interval=20,
ping_timeout=30, ping_timeout=30,
root_path="", root_path="",
@ -73,6 +79,9 @@ class Server(object):
self.websocket_timeout = websocket_timeout self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout self.websocket_handshake_timeout = websocket_handshake_timeout
self.websocket_permessage_compression_extensions = (
websocket_permessage_compression_extensions
)
self.application_close_timeout = application_close_timeout self.application_close_timeout = application_close_timeout
self.root_path = root_path self.root_path = root_path
self.verbosity = verbosity self.verbosity = verbosity
@ -94,6 +103,7 @@ class Server(object):
autoPingTimeout=self.ping_timeout, autoPingTimeout=self.ping_timeout,
allowNullOrigin=True, allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout, openHandshakeTimeout=self.websocket_handshake_timeout,
perMessageCompressionAccept=self.accept_permessage_compression_extension,
) )
if self.verbosity <= 1: if self.verbosity <= 1:
# Redirect the Twisted log to nowhere # Redirect the Twisted log to nowhere
@ -246,6 +256,21 @@ class Server(object):
) )
) )
def accept_permessage_compression_extension(self, offers):
"""
Accepts websocket compression extension as required by `autobahn` package.
"""
for offer in offers:
for ext in self.websocket_permessage_compression_extensions:
if ext in EXTENSIONS and isinstance(offer, EXTENSIONS[ext]["Offer"]):
return EXTENSIONS[ext]["OfferAccept"](offer)
elif ext not in EXTENSIONS:
logger.warning(
"Compression extension %s could not be accepted. "
"It is not supported or a dependency is missing.",
ext,
)
### Utility ### Utility
def application_checker(self): def application_checker(self):

View File

@ -182,7 +182,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
if "type" not in message: if "type" not in message:
raise ValueError("Message has no type defined") raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept": if message["type"] == "websocket.accept":
self.serverAccept(message.get("subprotocol", None)) self.serverAccept(
message.get("subprotocol", None), message.get("headers", None)
)
elif message["type"] == "websocket.close": elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING: if self.state == self.STATE_CONNECTING:
self.serverReject() self.serverReject()
@ -214,11 +217,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
else: else:
self.sendCloseFrame(code=1011) self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None): def serverAccept(self, subprotocol=None, headers=None):
""" """
Called when we get a message saying to accept the connection. Called when we get a message saying to accept the connection.
""" """
self.handshake_deferred.callback(subprotocol) if headers is None:
self.handshake_deferred.callback(subprotocol)
else:
headers_dict = {key.decode(): value.decode() for key, value in headers}
self.handshake_deferred.callback((subprotocol, headers_dict))
del self.handshake_deferred del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr) logger.debug("WebSocket %s accepted by application", self.client_addr)

View File

@ -132,6 +132,60 @@ class TestWebsocket(DaphneTestCase):
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols) self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
self.assert_valid_websocket_connect_message(messages[0]) self.assert_valid_websocket_connect_message(messages[0])
def test_accept_permessage_deflate_extension(self):
"""
Tests that permessage-deflate extension is successfuly accepted
by underlying `autobahn` package.
"""
headers = [
(
b"Sec-WebSocket-Extensions",
b"permessage-deflate; client_max_window_bits",
),
]
with DaphneTestingInstance() as test_app:
test_app.add_send_messages(
[
{
"type": "websocket.accept",
}
]
)
sock, subprotocol = self.websocket_handshake(
test_app,
headers=headers,
)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
def test_accept_custom_extension(self):
"""
Tests that custom headers can be accpeted during handshake.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages(
[
{
"type": "websocket.accept",
"headers": [(b"Sec-WebSocket-Extensions", b"custom-extension")],
}
]
)
sock, subprotocol = self.websocket_handshake(
test_app,
headers=[
(b"Sec-WebSocket-Extensions", b"custom-extension"),
],
)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
def test_xff(self): def test_xff(self):
""" """
Tests that X-Forwarded-For headers get parsed right Tests that X-Forwarded-For headers get parsed right