This commit is contained in:
Mathieu Dupuy 2025-12-09 11:23:04 +01:00 committed by GitHub
commit 616b17bb7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 181 additions and 261 deletions

View File

@ -5,11 +5,10 @@ from urllib.parse import unquote
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http
from zope.interface import implementer
from .utils import HEADER_NAME_RE, parse_x_forwarded_for
from .utils import parse_x_forwarded_for
logger = logging.getLogger(__name__)
@ -24,132 +23,77 @@ class WebRequest(http.Request):
"""
error_template = (
"""
<html>
<head>
<title>%(title)s</title>
<style>
body { font-family: sans-serif; margin: 0; padding: 0; }
h1 { padding: 0.6em 0 0.2em 20px; color: #896868; margin: 0; }
p { padding: 0 0 0.3em 20px; margin: 0; }
footer { padding: 1em 0 0.3em 20px; color: #999; font-size: 80%%; font-style: italic; }
</style>
</head>
<body>
<h1>%(title)s</h1>
<p>%(body)s</p>
<footer>Daphne</footer>
</body>
</html>
""".replace(
"\n", ""
)
.replace(" ", " ")
.replace(" ", " ")
.replace(" ", " ")
) # Shorten it a bit, bytes wise
b"<!DOCTYPE html>"
b"<html>"
b"<head><title>%(status)d %(status_text)s</title></head>"
b"<body><h1>%(status)d %(status_text)s</h1>%(text)s</body>"
b"</html>"
)
def __init__(self, *args, **kwargs):
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.client_addr = None
self.server_addr = None
try:
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.server.protocol_connected(self)
except Exception:
logger.error(traceback.format_exc())
raise
### Twisted progress callbacks
self.client_scheme = None
# Build the client address
if self.transport:
peer = self.transport.getPeer()
host = self.transport.getHost()
# Always set scheme if we have a transport
self.client_scheme = (
"https" if hasattr(peer, "is_ssl") and peer.is_ssl else "http"
)
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [str(peer.host), peer.port]
self.server_addr = [str(host.host), host.port]
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
self.is_websocket = upgrade_header and upgrade_header.lower() == b"websocket"
# Hook up request parsing
self.socket_opened = time.time()
self.server.protocol_connected(self)
@inlineCallbacks
def process(self):
"""
Called when all headers have been received and we can start processing content.
"""
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if forwarded
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
{name: value for name, value in self.requestHeaders.getAllRawHeaders()},
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme,
)
# Check for maximum request body size
if self.server.request_max_size:
self.channel.maxData = self.server.request_max_size
# Get query string
self.query_string = self.uri.split(b"?", 1)[1] if b"?" in self.uri else b""
try:
self.request_start = time.time()
# Validate header names.
for name, _ in self.requestHeaders.getAllRawHeaders():
if not HEADER_NAME_RE.fullmatch(name):
self.basic_error(400, b"Bad Request", "Invalid header name")
return
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if possible
if hasattr(self.client, "host") and hasattr(self.client, "port"):
# client.host and host.host are byte strings in Python 2, but spec
# requires unicode string.
self.client_addr = [str(self.client.host), self.client.port]
self.server_addr = [str(self.host.host), self.host.port]
self.client_scheme = "https" if self.isSecure() else "http"
# See if we need to get the address from a proxy header instead
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
self.requestHeaders,
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme,
)
# Check for unicodeish path (or it'll crash when trying to parse)
try:
self.path.decode("ascii")
except UnicodeDecodeError:
self.path = b"/"
self.basic_error(400, b"Bad Request", "Invalid characters in path")
return
# Calculate query string
self.query_string = b""
if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1]
try:
self.query_string.decode("ascii")
except UnicodeDecodeError:
self.basic_error(400, b"Bad Request", "Invalid query string")
return
# Is it WebSocket? IS IT?!
# Process WebSocket requests via HTTP upgrade
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.server.ws_factory.buildProtocol(
self.transport.getPeer()
)
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
logger.warn("Could not make WebSocket protocol")
self.finish()
# Give it the raw query string
protocol._raw_query_string = self.query_string
# Port across transport
transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
transport.wrappedProtocol = protocol
else:
transport.protocol = protocol
protocol.makeConnection(transport)
# Re-inject request
data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
for h in self.requestHeaders.getAllRawHeaders():
data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
data += b"\x0d\x0a"
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
# Pass request to WebSocketResource for handling
self.server.ws_resource.render_GET(self)
# The WebSocketResource will handle the rest of the connection
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
self.server.protocol_disconnected(self)
# Resume the producer so we keep getting data, if it's available as a method
self.channel._networkProducer.resumeProducing()
# Boring old HTTP.
# Don't continue with HTTP processing
return
# Handle normal HTTP requests
else:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
@ -339,9 +283,9 @@ class WebRequest(http.Request):
"""
Returns the time since the start of the request.
"""
if not hasattr(self, "request_start"):
if not hasattr(self, "socket_opened"):
return 0
return time.time() - self.request_start
return time.time() - self.socket_opened
def basic_error(self, status, status_text, body):
"""
@ -357,13 +301,12 @@ class WebRequest(http.Request):
self.handle_reply(
{
"type": "http.response.body",
"body": (
self.error_template
% {
"title": str(status) + " " + status_text.decode("ascii"),
"body": body,
}
).encode("utf8"),
"body": self.error_template
% {
"status": status,
"status_text": status_text,
"text": body,
},
}
)

View File

@ -37,6 +37,7 @@ from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner
from twisted.web import http
from twisted.web.websocket import WebSocketResource
from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
@ -77,6 +78,7 @@ class Server:
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.request_buffer_size = request_buffer_size
self.request_max_size = None # No limit by default
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
@ -99,12 +101,14 @@ class Server:
self.connections = {}
# Make the factory
self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, server=self.server_name)
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout,
)
# Create WebSocket factory
self.ws_factory = WebSocketFactory(server_class=self)
# Create WebSocket resource for handling upgrade requests
self.ws_resource = WebSocketResource(self.ws_factory)
# Configure logging
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo(

View File

@ -3,19 +3,15 @@ import time
import traceback
from urllib.parse import unquote
from autobahn.twisted.websocket import (
ConnectionDeny,
WebSocketServerFactory,
WebSocketServerProtocol,
)
from twisted.internet import defer
from twisted.web.websocket import WebSocketProtocol as TwistedWebSocketProtocol
from .utils import parse_x_forwarded_for
logger = logging.getLogger(__name__)
class WebSocketProtocol(WebSocketServerProtocol):
class WebSocketProtocol(TwistedWebSocketProtocol):
"""
Protocol which supports WebSockets and forwards incoming messages to
the websocket channels.
@ -26,29 +22,47 @@ class WebSocketProtocol(WebSocketServerProtocol):
# If we should send no more messages (e.g. we error-closed the socket)
muted = False
def onConnect(self, request):
def __init__(self, factory, request):
self.factory = factory
self.transport = None
self.server = self.factory.server_class
self.server.protocol_connected(self)
self.request = request
self.protocol_to_accept = None
self.root_path = self.server.root_path
self.socket_opened = time.time()
self.last_ping = time.time()
self.client_addr = None
self.server_addr = None
self.clean_headers = []
self.handshake_deferred = None
self.path = None
self.root_path = None
self.application_queue = None
self.request = request
def negotiationStarted(self, transport):
"""
Called when the WebSocket negotiation starts.
"""
self.transport = transport
self.server.protocol_connected(self)
self.protocol_to_accept = None
self.root_path = self.server.root_path
try:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
for name, value in request.headers.items():
name = name.encode("ascii")
for name, value in self.request.requestHeaders.getAllRawHeaders():
name = name.lower()
# Prevent CVE-2015-0219
if b"_" in name:
continue
if name.lower() == b"daphne-root-path":
self.root_path = unquote(value)
if name == b"daphne-root-path":
self.root_path = unquote(value[0].decode("ascii"))
else:
self.clean_headers.append((name.lower(), value.encode("latin1")))
self.clean_headers.append((name, value[0]))
# Get client address if possible
peer = self.transport.getPeer()
host = self.transport.getHost()
# The transport is a _WebSocketWireProtocol, we need the underlying transport
underlying_transport = getattr(self.transport, "transport", self.transport)
peer = underlying_transport.getPeer()
host = underlying_transport.getHost()
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [str(peer.host), peer.port]
self.server_addr = [str(host.host), host.port]
@ -64,6 +78,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_proto_header,
self.client_addr,
)
# Decode websocket subprotocol options
subprotocols = []
for header, value in self.clean_headers:
@ -71,8 +86,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
subprotocols = [
x.strip() for x in unquote(value.decode("ascii")).split(",")
]
# Extract query string
query_string = b""
if b"?" in self.request.uri:
query_string = self.request.uri.split(b"?", 1)[1]
# Get the path
self.path = self.request.path
# Make new application instance with scope
self.path = request.path.encode("ascii")
self.application_deferred = defer.maybeDeferred(
self.server.create_application,
self,
@ -82,7 +105,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"raw_path": self.path,
"root_path": self.root_path,
"headers": self.clean_headers,
"query_string": self._raw_query_string, # Passed by HTTP protocol
"query_string": query_string,
"client": self.client_addr,
"server": self.server_addr,
"subprotocols": subprotocols,
@ -97,10 +120,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.error(traceback.format_exc())
raise
# Make a deferred and return it - we'll either call it or err it later on
self.handshake_deferred = defer.Deferred()
return self.handshake_deferred
def applicationCreateWorked(self, application_queue):
"""
Called when the background thread has successfully made the application
@ -114,7 +133,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"websocket",
"connecting",
{
"path": self.request.path,
"path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
@ -128,144 +147,107 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.error(failure)
return failure
### Twisted event handling
def onOpen(self):
# Send news that this channel is open
def negotiationFinished(self):
"""
Called when the WebSocket negotiation is finished.
"""
logger.debug("WebSocket %s open and established", self.client_addr)
self.server.log_action(
"websocket",
"connected",
{
"path": self.request.path,
"path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def onMessage(self, payload, isBinary):
def textMessageReceived(self, message):
"""
Called when a text message is received.
"""
# If we're muted, do nothing.
if self.muted:
logger.debug("Muting incoming frame on %s", self.client_addr)
return
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
if isBinary:
self.application_queue.put_nowait(
{"type": "websocket.receive", "bytes": payload}
)
else:
self.application_queue.put_nowait(
{"type": "websocket.receive", "text": payload.decode("utf8")}
)
self.application_queue.put_nowait(
{"type": "websocket.receive", "text": message}
)
def onClose(self, wasClean, code, reason):
def bytesMessageReceived(self, data):
"""
Called when Twisted closes the socket.
Called when a binary message is received.
"""
# If we're muted, do nothing.
if self.muted:
logger.debug("Muting incoming frame on %s", self.client_addr)
return
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
self.application_queue.put_nowait({"type": "websocket.receive", "bytes": data})
def connectionLost(self, reason):
"""
Called when the WebSocket connection is lost.
"""
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"):
self.application_queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
{"type": "websocket.disconnect", "code": 1000} # Default close code
)
self.server.log_action(
"websocket",
"disconnected",
{
"path": self.request.path,
"path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def pongReceived(self, payload):
"""
Called when a pong frame is received in response to a ping.
"""
self.last_ping = time.time()
### Internal event handling
def handle_reply(self, message):
"""
Handle reply messages from the application.
"""
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept(message.get("subprotocol", None))
# Accept is handled by WebSocketResource in Twisted 25
# Our protocol is already established at this point
pass
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
else:
self.serverClose(code=message.get("code", None))
self.transport.loseConnection(code=message.get("code", 1000))
elif message["type"] == "websocket.send":
if self.state == self.STATE_CONNECTING:
raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
% (message,)
% (self.client_addr,)
)
if message.get("bytes", None):
self.serverSend(message["bytes"], True)
self.transport.sendBytesMessage(message["bytes"])
if message.get("text", None):
self.serverSend(message["text"], False)
self.transport.sendTextMessage(message["text"])
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
self.handshake_deferred.errback(
ConnectionDeny(code=500, reason="Internal server error")
)
else:
self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.callback(subprotocol)
del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self):
"""
Called when we get a message saying to reject the connection.
"""
self.handshake_deferred.errback(
ConnectionDeny(code=403, reason="Access denied")
)
del self.handshake_deferred
self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action(
"websocket",
"rejected",
{
"path": self.request.path,
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def serverSend(self, content, binary=False):
"""
Server-side channel message to send a message.
"""
if self.state == self.STATE_CONNECTING:
self.serverAccept()
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
if binary:
self.sendMessage(content, binary)
else:
self.sendMessage(content.encode("utf8"), binary)
def serverClose(self, code=None):
"""
Server-side channel message to close the socket
"""
code = 1000 if code is None else code
self.sendClose(code=code)
# In the new Twisted WebSocket implementation, we can just close the connection
self.transport.loseConnection(code=1011) # Internal server error
### Utils
@ -284,15 +266,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.duration() > self.server.websocket_timeout
and self.server.websocket_timeout >= 0
):
self.serverClose()
self.transport.loseConnection(code=1000)
# Ping check
# If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING:
if self.duration() > self.server.websocket_connect_timeout:
self.serverReject()
elif self.state == self.STATE_OPEN:
if hasattr(self, "transport") and self.transport:
if (time.time() - self.last_ping) > self.server.ping_interval:
self._sendAutoPing()
self.transport.ping()
self.last_ping = time.time()
def __hash__(self):
@ -305,26 +284,20 @@ class WebSocketProtocol(WebSocketServerProtocol):
return f"<WebSocketProtocol client={self.client_addr!r} path={self.path!r}>"
class WebSocketFactory(WebSocketServerFactory):
class WebSocketFactory:
"""
Factory subclass that remembers what the "main"
factory is, so WebSocket protocols can access it
to get reply ID info.
Factory for WebSocket protocols.
"""
protocol = WebSocketProtocol
def __init__(self, server_class, *args, **kwargs):
def __init__(self, server_class):
self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs)
def buildProtocol(self, addr):
def buildProtocol(self, request):
"""
Builds protocol instances. We use this to inject the factory object into the protocol.
Builds a new WebSocket protocol.
"""
try:
protocol = super().buildProtocol(addr)
protocol.factory = self
protocol = WebSocketProtocol(self, request)
return protocol
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())

View File

@ -24,7 +24,7 @@ classifiers = [
"Topic :: Internet :: WWW/HTTP",
]
dependencies = ["asgiref>=3.5.2,<4", "autobahn>=22.4.2", "twisted[tls]>=22.4"]
dependencies = ["asgiref>=3.5.2,<4", "twisted[tls,websocket]>=22.4"]
[project.optional-dependencies]
tests = [