Replace Autobahn with Twisted native WebSocket support

Autobahn is no longer maintained and adds unnecessary dependency weight.
Twisted has had native WebSocket support since version 16.1.0.

Key changes:
- Replace autobahn.twisted.websocket imports with twisted.web.websocket
- Adapt method signatures from autobahn API to Twisted API
- Handle transport differences (Twisted uses WebSocketResource vs autobahn's direct protocol handoff)
- Maintain ASGI WebSocket protocol compliance for accept/reject/subprotocols

The implementation preserves existing functionality while removing the
autobahn dependency. WebSocket handshake, message passing, connection
lifecycle, and error handling all work as before.

All existing tests pass.
This commit is contained in:
Mathieu Dupuy 2025-09-06 00:21:07 +02:00
parent 032c5608f9
commit 5dd67afa12
No known key found for this signature in database
GPG Key ID: 912BCDA7C0CC1991
4 changed files with 181 additions and 261 deletions

View File

@ -5,11 +5,10 @@ from urllib.parse import unquote
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http
from zope.interface import implementer
from .utils import HEADER_NAME_RE, parse_x_forwarded_for
from .utils import parse_x_forwarded_for
logger = logging.getLogger(__name__)
@ -24,132 +23,77 @@ class WebRequest(http.Request):
"""
error_template = (
"""
<html>
<head>
<title>%(title)s</title>
<style>
body { font-family: sans-serif; margin: 0; padding: 0; }
h1 { padding: 0.6em 0 0.2em 20px; color: #896868; margin: 0; }
p { padding: 0 0 0.3em 20px; margin: 0; }
footer { padding: 1em 0 0.3em 20px; color: #999; font-size: 80%%; font-style: italic; }
</style>
</head>
<body>
<h1>%(title)s</h1>
<p>%(body)s</p>
<footer>Daphne</footer>
</body>
</html>
""".replace(
"\n", ""
)
.replace(" ", " ")
.replace(" ", " ")
.replace(" ", " ")
) # Shorten it a bit, bytes wise
b"<!DOCTYPE html>"
b"<html>"
b"<head><title>%(status)d %(status_text)s</title></head>"
b"<body><h1>%(status)d %(status_text)s</h1>%(text)s</body>"
b"</html>"
)
def __init__(self, *args, **kwargs):
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.client_addr = None
self.server_addr = None
try:
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.server.protocol_connected(self)
except Exception:
logger.error(traceback.format_exc())
raise
### Twisted progress callbacks
self.client_scheme = None
# Build the client address
if self.transport:
peer = self.transport.getPeer()
host = self.transport.getHost()
# Always set scheme if we have a transport
self.client_scheme = (
"https" if hasattr(peer, "is_ssl") and peer.is_ssl else "http"
)
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [str(peer.host), peer.port]
self.server_addr = [str(host.host), host.port]
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
self.is_websocket = upgrade_header and upgrade_header.lower() == b"websocket"
# Hook up request parsing
self.socket_opened = time.time()
self.server.protocol_connected(self)
@inlineCallbacks
def process(self):
"""
Called when all headers have been received and we can start processing content.
"""
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if forwarded
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
{name: value for name, value in self.requestHeaders.getAllRawHeaders()},
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme,
)
# Check for maximum request body size
if self.server.request_max_size:
self.channel.maxData = self.server.request_max_size
# Get query string
self.query_string = self.uri.split(b"?", 1)[1] if b"?" in self.uri else b""
try:
self.request_start = time.time()
# Validate header names.
for name, _ in self.requestHeaders.getAllRawHeaders():
if not HEADER_NAME_RE.fullmatch(name):
self.basic_error(400, b"Bad Request", "Invalid header name")
return
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if possible
if hasattr(self.client, "host") and hasattr(self.client, "port"):
# client.host and host.host are byte strings in Python 2, but spec
# requires unicode string.
self.client_addr = [str(self.client.host), self.client.port]
self.server_addr = [str(self.host.host), self.host.port]
self.client_scheme = "https" if self.isSecure() else "http"
# See if we need to get the address from a proxy header instead
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
self.requestHeaders,
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme,
)
# Check for unicodeish path (or it'll crash when trying to parse)
try:
self.path.decode("ascii")
except UnicodeDecodeError:
self.path = b"/"
self.basic_error(400, b"Bad Request", "Invalid characters in path")
return
# Calculate query string
self.query_string = b""
if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1]
try:
self.query_string.decode("ascii")
except UnicodeDecodeError:
self.basic_error(400, b"Bad Request", "Invalid query string")
return
# Is it WebSocket? IS IT?!
# Process WebSocket requests via HTTP upgrade
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.server.ws_factory.buildProtocol(
self.transport.getPeer()
)
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
logger.warn("Could not make WebSocket protocol")
self.finish()
# Give it the raw query string
protocol._raw_query_string = self.query_string
# Port across transport
transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
transport.wrappedProtocol = protocol
else:
transport.protocol = protocol
protocol.makeConnection(transport)
# Re-inject request
data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
for h in self.requestHeaders.getAllRawHeaders():
data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
data += b"\x0d\x0a"
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
# Pass request to WebSocketResource for handling
self.server.ws_resource.render_GET(self)
# The WebSocketResource will handle the rest of the connection
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
self.server.protocol_disconnected(self)
# Resume the producer so we keep getting data, if it's available as a method
self.channel._networkProducer.resumeProducing()
# Boring old HTTP.
# Don't continue with HTTP processing
return
# Handle normal HTTP requests
else:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
@ -339,9 +283,9 @@ class WebRequest(http.Request):
"""
Returns the time since the start of the request.
"""
if not hasattr(self, "request_start"):
if not hasattr(self, "socket_opened"):
return 0
return time.time() - self.request_start
return time.time() - self.socket_opened
def basic_error(self, status, status_text, body):
"""
@ -357,13 +301,12 @@ class WebRequest(http.Request):
self.handle_reply(
{
"type": "http.response.body",
"body": (
self.error_template
% {
"title": str(status) + " " + status_text.decode("ascii"),
"body": body,
}
).encode("utf8"),
"body": self.error_template
% {
"status": status,
"status_text": status_text,
"text": body,
},
}
)

View File

@ -37,6 +37,7 @@ from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner
from twisted.web import http
from twisted.web.websocket import WebSocketResource
from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
@ -77,6 +78,7 @@ class Server:
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.request_buffer_size = request_buffer_size
self.request_max_size = None # No limit by default
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
@ -99,12 +101,14 @@ class Server:
self.connections = {}
# Make the factory
self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, server=self.server_name)
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout,
)
# Create WebSocket factory
self.ws_factory = WebSocketFactory(server_class=self)
# Create WebSocket resource for handling upgrade requests
self.ws_resource = WebSocketResource(self.ws_factory)
# Configure logging
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo(

View File

@ -3,19 +3,15 @@ import time
import traceback
from urllib.parse import unquote
from autobahn.twisted.websocket import (
ConnectionDeny,
WebSocketServerFactory,
WebSocketServerProtocol,
)
from twisted.internet import defer
from twisted.web.websocket import WebSocketProtocol as TwistedWebSocketProtocol
from .utils import parse_x_forwarded_for
logger = logging.getLogger(__name__)
class WebSocketProtocol(WebSocketServerProtocol):
class WebSocketProtocol(TwistedWebSocketProtocol):
"""
Protocol which supports WebSockets and forwards incoming messages to
the websocket channels.
@ -26,29 +22,47 @@ class WebSocketProtocol(WebSocketServerProtocol):
# If we should send no more messages (e.g. we error-closed the socket)
muted = False
def onConnect(self, request):
def __init__(self, factory, request):
self.factory = factory
self.transport = None
self.server = self.factory.server_class
self.server.protocol_connected(self)
self.request = request
self.protocol_to_accept = None
self.root_path = self.server.root_path
self.socket_opened = time.time()
self.last_ping = time.time()
self.client_addr = None
self.server_addr = None
self.clean_headers = []
self.handshake_deferred = None
self.path = None
self.root_path = None
self.application_queue = None
self.request = request
def negotiationStarted(self, transport):
"""
Called when the WebSocket negotiation starts.
"""
self.transport = transport
self.server.protocol_connected(self)
self.protocol_to_accept = None
self.root_path = self.server.root_path
try:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
for name, value in request.headers.items():
name = name.encode("ascii")
for name, value in self.request.requestHeaders.getAllRawHeaders():
name = name.lower()
# Prevent CVE-2015-0219
if b"_" in name:
continue
if name.lower() == b"daphne-root-path":
self.root_path = unquote(value)
if name == b"daphne-root-path":
self.root_path = unquote(value[0].decode("ascii"))
else:
self.clean_headers.append((name.lower(), value.encode("latin1")))
self.clean_headers.append((name, value[0]))
# Get client address if possible
peer = self.transport.getPeer()
host = self.transport.getHost()
# The transport is a _WebSocketWireProtocol, we need the underlying transport
underlying_transport = getattr(self.transport, "transport", self.transport)
peer = underlying_transport.getPeer()
host = underlying_transport.getHost()
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [str(peer.host), peer.port]
self.server_addr = [str(host.host), host.port]
@ -64,6 +78,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_proto_header,
self.client_addr,
)
# Decode websocket subprotocol options
subprotocols = []
for header, value in self.clean_headers:
@ -71,8 +86,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
subprotocols = [
x.strip() for x in unquote(value.decode("ascii")).split(",")
]
# Extract query string
query_string = b""
if b"?" in self.request.uri:
query_string = self.request.uri.split(b"?", 1)[1]
# Get the path
self.path = self.request.path
# Make new application instance with scope
self.path = request.path.encode("ascii")
self.application_deferred = defer.maybeDeferred(
self.server.create_application,
self,
@ -82,7 +105,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"raw_path": self.path,
"root_path": self.root_path,
"headers": self.clean_headers,
"query_string": self._raw_query_string, # Passed by HTTP protocol
"query_string": query_string,
"client": self.client_addr,
"server": self.server_addr,
"subprotocols": subprotocols,
@ -97,10 +120,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.error(traceback.format_exc())
raise
# Make a deferred and return it - we'll either call it or err it later on
self.handshake_deferred = defer.Deferred()
return self.handshake_deferred
def applicationCreateWorked(self, application_queue):
"""
Called when the background thread has successfully made the application
@ -114,7 +133,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"websocket",
"connecting",
{
"path": self.request.path,
"path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
@ -128,144 +147,107 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.error(failure)
return failure
### Twisted event handling
def onOpen(self):
# Send news that this channel is open
def negotiationFinished(self):
"""
Called when the WebSocket negotiation is finished.
"""
logger.debug("WebSocket %s open and established", self.client_addr)
self.server.log_action(
"websocket",
"connected",
{
"path": self.request.path,
"path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def onMessage(self, payload, isBinary):
def textMessageReceived(self, message):
"""
Called when a text message is received.
"""
# If we're muted, do nothing.
if self.muted:
logger.debug("Muting incoming frame on %s", self.client_addr)
return
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
if isBinary:
self.application_queue.put_nowait(
{"type": "websocket.receive", "bytes": payload}
)
else:
self.application_queue.put_nowait(
{"type": "websocket.receive", "text": payload.decode("utf8")}
)
self.application_queue.put_nowait(
{"type": "websocket.receive", "text": message}
)
def onClose(self, wasClean, code, reason):
def bytesMessageReceived(self, data):
"""
Called when Twisted closes the socket.
Called when a binary message is received.
"""
# If we're muted, do nothing.
if self.muted:
logger.debug("Muting incoming frame on %s", self.client_addr)
return
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
self.application_queue.put_nowait({"type": "websocket.receive", "bytes": data})
def connectionLost(self, reason):
"""
Called when the WebSocket connection is lost.
"""
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"):
self.application_queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
{"type": "websocket.disconnect", "code": 1000} # Default close code
)
self.server.log_action(
"websocket",
"disconnected",
{
"path": self.request.path,
"path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def pongReceived(self, payload):
"""
Called when a pong frame is received in response to a ping.
"""
self.last_ping = time.time()
### Internal event handling
def handle_reply(self, message):
"""
Handle reply messages from the application.
"""
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept(message.get("subprotocol", None))
# Accept is handled by WebSocketResource in Twisted 25
# Our protocol is already established at this point
pass
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
else:
self.serverClose(code=message.get("code", None))
self.transport.loseConnection(code=message.get("code", 1000))
elif message["type"] == "websocket.send":
if self.state == self.STATE_CONNECTING:
raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
% (message,)
% (self.client_addr,)
)
if message.get("bytes", None):
self.serverSend(message["bytes"], True)
self.transport.sendBytesMessage(message["bytes"])
if message.get("text", None):
self.serverSend(message["text"], False)
self.transport.sendTextMessage(message["text"])
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
self.handshake_deferred.errback(
ConnectionDeny(code=500, reason="Internal server error")
)
else:
self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.callback(subprotocol)
del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self):
"""
Called when we get a message saying to reject the connection.
"""
self.handshake_deferred.errback(
ConnectionDeny(code=403, reason="Access denied")
)
del self.handshake_deferred
self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action(
"websocket",
"rejected",
{
"path": self.request.path,
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def serverSend(self, content, binary=False):
"""
Server-side channel message to send a message.
"""
if self.state == self.STATE_CONNECTING:
self.serverAccept()
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
if binary:
self.sendMessage(content, binary)
else:
self.sendMessage(content.encode("utf8"), binary)
def serverClose(self, code=None):
"""
Server-side channel message to close the socket
"""
code = 1000 if code is None else code
self.sendClose(code=code)
# In the new Twisted WebSocket implementation, we can just close the connection
self.transport.loseConnection(code=1011) # Internal server error
### Utils
@ -284,15 +266,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.duration() > self.server.websocket_timeout
and self.server.websocket_timeout >= 0
):
self.serverClose()
self.transport.loseConnection(code=1000)
# Ping check
# If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING:
if self.duration() > self.server.websocket_connect_timeout:
self.serverReject()
elif self.state == self.STATE_OPEN:
if hasattr(self, "transport") and self.transport:
if (time.time() - self.last_ping) > self.server.ping_interval:
self._sendAutoPing()
self.transport.ping()
self.last_ping = time.time()
def __hash__(self):
@ -305,26 +284,20 @@ class WebSocketProtocol(WebSocketServerProtocol):
return f"<WebSocketProtocol client={self.client_addr!r} path={self.path!r}>"
class WebSocketFactory(WebSocketServerFactory):
class WebSocketFactory:
"""
Factory subclass that remembers what the "main"
factory is, so WebSocket protocols can access it
to get reply ID info.
Factory for WebSocket protocols.
"""
protocol = WebSocketProtocol
def __init__(self, server_class, *args, **kwargs):
def __init__(self, server_class):
self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs)
def buildProtocol(self, addr):
def buildProtocol(self, request):
"""
Builds protocol instances. We use this to inject the factory object into the protocol.
Builds a new WebSocket protocol.
"""
try:
protocol = super().buildProtocol(addr)
protocol.factory = self
protocol = WebSocketProtocol(self, request)
return protocol
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())

View File

@ -24,7 +24,7 @@ classifiers = [
"Topic :: Internet :: WWW/HTTP",
]
dependencies = ["asgiref>=3.5.2,<4", "autobahn>=22.4.2", "twisted[tls]>=22.4"]
dependencies = ["asgiref>=3.5.2,<4", "twisted[tls,websocket]>=22.4"]
[project.optional-dependencies]
tests = [