mirror of
https://github.com/django/daphne.git
synced 2024-11-24 17:03:42 +03:00
Allow to accept websocket extensions
Also accept `permessage-deflate`, `permessage-bzip2` and `permessage-snappy` compression extensions by default if client requests for them. Compression/decompression of the messages is taken care of by `autobahn` package.
This commit is contained in:
parent
9838a173d7
commit
231bfb7c4e
|
@ -23,6 +23,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import CancelledError
|
from concurrent.futures import CancelledError
|
||||||
|
|
||||||
|
from autobahn.websocket.compress import PERMESSAGE_COMPRESSION_EXTENSION as EXTENSIONS
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.internet.endpoints import serverFromString
|
from twisted.internet.endpoints import serverFromString
|
||||||
from twisted.logger import STDLibLogObserver, globalLogBeginner
|
from twisted.logger import STDLibLogObserver, globalLogBeginner
|
||||||
|
@ -44,6 +45,11 @@ class Server(object):
|
||||||
http_timeout=None,
|
http_timeout=None,
|
||||||
websocket_timeout=86400,
|
websocket_timeout=86400,
|
||||||
websocket_connect_timeout=20,
|
websocket_connect_timeout=20,
|
||||||
|
websocket_permessage_compression_extensions=[
|
||||||
|
"permessage-deflate",
|
||||||
|
"permessage-bzip2",
|
||||||
|
"permessage-snappy",
|
||||||
|
],
|
||||||
ping_interval=20,
|
ping_interval=20,
|
||||||
ping_timeout=30,
|
ping_timeout=30,
|
||||||
root_path="",
|
root_path="",
|
||||||
|
@ -73,6 +79,9 @@ class Server(object):
|
||||||
self.websocket_timeout = websocket_timeout
|
self.websocket_timeout = websocket_timeout
|
||||||
self.websocket_connect_timeout = websocket_connect_timeout
|
self.websocket_connect_timeout = websocket_connect_timeout
|
||||||
self.websocket_handshake_timeout = websocket_handshake_timeout
|
self.websocket_handshake_timeout = websocket_handshake_timeout
|
||||||
|
self.websocket_permessage_compression_extensions = (
|
||||||
|
websocket_permessage_compression_extensions
|
||||||
|
)
|
||||||
self.application_close_timeout = application_close_timeout
|
self.application_close_timeout = application_close_timeout
|
||||||
self.root_path = root_path
|
self.root_path = root_path
|
||||||
self.verbosity = verbosity
|
self.verbosity = verbosity
|
||||||
|
@ -94,6 +103,7 @@ class Server(object):
|
||||||
autoPingTimeout=self.ping_timeout,
|
autoPingTimeout=self.ping_timeout,
|
||||||
allowNullOrigin=True,
|
allowNullOrigin=True,
|
||||||
openHandshakeTimeout=self.websocket_handshake_timeout,
|
openHandshakeTimeout=self.websocket_handshake_timeout,
|
||||||
|
perMessageCompressionAccept=self.accept_permessage_compression_extension,
|
||||||
)
|
)
|
||||||
if self.verbosity <= 1:
|
if self.verbosity <= 1:
|
||||||
# Redirect the Twisted log to nowhere
|
# Redirect the Twisted log to nowhere
|
||||||
|
@ -246,6 +256,21 @@ class Server(object):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def accept_permessage_compression_extension(self, offers):
|
||||||
|
"""
|
||||||
|
Accepts websocket compression extension as required by `autobahn` package.
|
||||||
|
"""
|
||||||
|
for offer in offers:
|
||||||
|
for ext in self.websocket_permessage_compression_extensions:
|
||||||
|
if ext in EXTENSIONS and isinstance(offer, EXTENSIONS[ext]["Offer"]):
|
||||||
|
return EXTENSIONS[ext]["OfferAccept"](offer)
|
||||||
|
elif ext not in EXTENSIONS:
|
||||||
|
logger.warning(
|
||||||
|
"Compression extension %s could not be accepted. "
|
||||||
|
"It is not supported or a dependency is missing.",
|
||||||
|
ext,
|
||||||
|
)
|
||||||
|
|
||||||
### Utility
|
### Utility
|
||||||
|
|
||||||
def application_checker(self):
|
def application_checker(self):
|
||||||
|
|
|
@ -182,7 +182,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
if "type" not in message:
|
if "type" not in message:
|
||||||
raise ValueError("Message has no type defined")
|
raise ValueError("Message has no type defined")
|
||||||
if message["type"] == "websocket.accept":
|
if message["type"] == "websocket.accept":
|
||||||
self.serverAccept(message.get("subprotocol", None))
|
self.serverAccept(
|
||||||
|
message.get("subprotocol", None), message.get("headers", None)
|
||||||
|
)
|
||||||
|
|
||||||
elif message["type"] == "websocket.close":
|
elif message["type"] == "websocket.close":
|
||||||
if self.state == self.STATE_CONNECTING:
|
if self.state == self.STATE_CONNECTING:
|
||||||
self.serverReject()
|
self.serverReject()
|
||||||
|
@ -214,11 +217,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
else:
|
else:
|
||||||
self.sendCloseFrame(code=1011)
|
self.sendCloseFrame(code=1011)
|
||||||
|
|
||||||
def serverAccept(self, subprotocol=None):
|
def serverAccept(self, subprotocol=None, headers=None):
|
||||||
"""
|
"""
|
||||||
Called when we get a message saying to accept the connection.
|
Called when we get a message saying to accept the connection.
|
||||||
"""
|
"""
|
||||||
|
if headers is None:
|
||||||
self.handshake_deferred.callback(subprotocol)
|
self.handshake_deferred.callback(subprotocol)
|
||||||
|
else:
|
||||||
|
headers_dict = {key.decode(): value.decode() for key, value in headers}
|
||||||
|
self.handshake_deferred.callback((subprotocol, headers_dict))
|
||||||
del self.handshake_deferred
|
del self.handshake_deferred
|
||||||
logger.debug("WebSocket %s accepted by application", self.client_addr)
|
logger.debug("WebSocket %s accepted by application", self.client_addr)
|
||||||
|
|
||||||
|
|
|
@ -132,6 +132,60 @@ class TestWebsocket(DaphneTestCase):
|
||||||
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
|
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
|
||||||
self.assert_valid_websocket_connect_message(messages[0])
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
|
||||||
|
def test_accept_permessage_deflate_extension(self):
|
||||||
|
"""
|
||||||
|
Tests that permessage-deflate extension is successfuly accepted
|
||||||
|
by underlying `autobahn` package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
headers = [
|
||||||
|
(
|
||||||
|
b"Sec-WebSocket-Extensions",
|
||||||
|
b"permessage-deflate; client_max_window_bits",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
test_app.add_send_messages(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "websocket.accept",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
sock, subprotocol = self.websocket_handshake(
|
||||||
|
test_app,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
# Validate the scope and messages we got
|
||||||
|
scope, messages = test_app.get_received()
|
||||||
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
|
||||||
|
def test_accept_custom_extension(self):
|
||||||
|
"""
|
||||||
|
Tests that custom headers can be accpeted during handshake.
|
||||||
|
"""
|
||||||
|
with DaphneTestingInstance() as test_app:
|
||||||
|
test_app.add_send_messages(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "websocket.accept",
|
||||||
|
"headers": [(b"Sec-WebSocket-Extensions", b"custom-extension")],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
sock, subprotocol = self.websocket_handshake(
|
||||||
|
test_app,
|
||||||
|
headers=[
|
||||||
|
(b"Sec-WebSocket-Extensions", b"custom-extension"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Validate the scope and messages we got
|
||||||
|
scope, messages = test_app.get_received()
|
||||||
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
|
||||||
def test_xff(self):
|
def test_xff(self):
|
||||||
"""
|
"""
|
||||||
Tests that X-Forwarded-For headers get parsed right
|
Tests that X-Forwarded-For headers get parsed right
|
||||||
|
|
Loading…
Reference in New Issue
Block a user