Allow to accept websocket extensions

Also accept `permessage-deflate`, `permessage-bzip2` and `permessage-snappy`
compression extensions by default if client requests for them.
Compression/decompression of the messages is taken care of by `autobahn`
package.
This commit is contained in:
niekas 2020-09-14 13:33:13 +03:00
parent 9838a173d7
commit 231bfb7c4e
3 changed files with 89 additions and 3 deletions

View File

@ -23,6 +23,7 @@ import logging
import time
from concurrent.futures import CancelledError
from autobahn.websocket.compress import PERMESSAGE_COMPRESSION_EXTENSION as EXTENSIONS
from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner
@ -44,6 +45,11 @@ class Server(object):
http_timeout=None,
websocket_timeout=86400,
websocket_connect_timeout=20,
websocket_permessage_compression_extensions=[
"permessage-deflate",
"permessage-bzip2",
"permessage-snappy",
],
ping_interval=20,
ping_timeout=30,
root_path="",
@ -73,6 +79,9 @@ class Server(object):
self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout
self.websocket_permessage_compression_extensions = (
websocket_permessage_compression_extensions
)
self.application_close_timeout = application_close_timeout
self.root_path = root_path
self.verbosity = verbosity
@ -94,6 +103,7 @@ class Server(object):
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout,
perMessageCompressionAccept=self.accept_permessage_compression_extension,
)
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
@ -246,6 +256,21 @@ class Server(object):
)
)
def accept_permessage_compression_extension(self, offers):
"""
Accepts websocket compression extension as required by `autobahn` package.
"""
for offer in offers:
for ext in self.websocket_permessage_compression_extensions:
if ext in EXTENSIONS and isinstance(offer, EXTENSIONS[ext]["Offer"]):
return EXTENSIONS[ext]["OfferAccept"](offer)
elif ext not in EXTENSIONS:
logger.warning(
"Compression extension %s could not be accepted. "
"It is not supported or a dependency is missing.",
ext,
)
### Utility
def application_checker(self):

View File

@ -182,7 +182,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept(message.get("subprotocol", None))
self.serverAccept(
message.get("subprotocol", None), message.get("headers", None)
)
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
@ -214,11 +217,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
else:
self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None):
def serverAccept(self, subprotocol=None, headers=None):
"""
Called when we get a message saying to accept the connection.
"""
if headers is None:
self.handshake_deferred.callback(subprotocol)
else:
headers_dict = {key.decode(): value.decode() for key, value in headers}
self.handshake_deferred.callback((subprotocol, headers_dict))
del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr)

View File

@ -132,6 +132,60 @@ class TestWebsocket(DaphneTestCase):
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
self.assert_valid_websocket_connect_message(messages[0])
def test_accept_permessage_deflate_extension(self):
"""
Tests that permessage-deflate extension is successfuly accepted
by underlying `autobahn` package.
"""
headers = [
(
b"Sec-WebSocket-Extensions",
b"permessage-deflate; client_max_window_bits",
),
]
with DaphneTestingInstance() as test_app:
test_app.add_send_messages(
[
{
"type": "websocket.accept",
}
]
)
sock, subprotocol = self.websocket_handshake(
test_app,
headers=headers,
)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
def test_accept_custom_extension(self):
"""
Tests that custom headers can be accpeted during handshake.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages(
[
{
"type": "websocket.accept",
"headers": [(b"Sec-WebSocket-Extensions", b"custom-extension")],
}
]
)
sock, subprotocol = self.websocket_handshake(
test_app,
headers=[
(b"Sec-WebSocket-Extensions", b"custom-extension"),
],
)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
def test_xff(self):
"""
Tests that X-Forwarded-For headers get parsed right