diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py
index e9eba96..b554adc 100755
--- a/daphne/http_protocol.py
+++ b/daphne/http_protocol.py
@@ -5,11 +5,10 @@ from urllib.parse import unquote
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory
-from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http
from zope.interface import implementer
-from .utils import HEADER_NAME_RE, parse_x_forwarded_for
+from .utils import parse_x_forwarded_for
logger = logging.getLogger(__name__)
@@ -24,132 +23,77 @@ class WebRequest(http.Request):
"""
error_template = (
- """
-
-
- %(title)s
-
-
-
- %(title)s
- %(body)s
-
-
-
- """.replace(
- "\n", ""
- )
- .replace(" ", " ")
- .replace(" ", " ")
- .replace(" ", " ")
- ) # Shorten it a bit, bytes wise
+ b""
+ b""
+ b"%(status)d %(status_text)s"
+ b"%(status)d %(status_text)s
%(text)s"
+ b""
+ )
def __init__(self, *args, **kwargs):
+ http.Request.__init__(self, *args, **kwargs)
+ # Easy server link
+ self.server = self.channel.factory.server
+ self.application_queue = None
+ self._response_started = False
self.client_addr = None
self.server_addr = None
- try:
- http.Request.__init__(self, *args, **kwargs)
- # Easy server link
- self.server = self.channel.factory.server
- self.application_queue = None
- self._response_started = False
- self.server.protocol_connected(self)
- except Exception:
- logger.error(traceback.format_exc())
- raise
-
- ### Twisted progress callbacks
+ self.client_scheme = None
+ # Build the client address
+ if self.transport:
+ peer = self.transport.getPeer()
+ host = self.transport.getHost()
+ # Always set scheme if we have a transport
+ self.client_scheme = (
+ "https" if hasattr(peer, "is_ssl") and peer.is_ssl else "http"
+ )
+ if hasattr(peer, "host") and hasattr(peer, "port"):
+ self.client_addr = [str(peer.host), peer.port]
+ self.server_addr = [str(host.host), host.port]
+ # Get upgrade header
+ upgrade_header = None
+ if self.requestHeaders.hasHeader(b"Upgrade"):
+ upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
+ self.is_websocket = upgrade_header and upgrade_header.lower() == b"websocket"
+ # Hook up request parsing
+ self.socket_opened = time.time()
+ self.server.protocol_connected(self)
@inlineCallbacks
def process(self):
+ """
+ Called when all headers have been received and we can start processing content.
+ """
+ # Get upgrade header
+ upgrade_header = None
+ if self.requestHeaders.hasHeader(b"Upgrade"):
+ upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
+ # Get client address if forwarded
+ if self.server.proxy_forwarded_address_header:
+ self.client_addr, self.client_scheme = parse_x_forwarded_for(
+ {name: value for name, value in self.requestHeaders.getAllRawHeaders()},
+ self.server.proxy_forwarded_address_header,
+ self.server.proxy_forwarded_port_header,
+ self.server.proxy_forwarded_proto_header,
+ self.client_addr,
+ self.client_scheme,
+ )
+ # Check for maximum request body size
+ if self.server.request_max_size:
+ self.channel.maxData = self.server.request_max_size
+ # Get query string
+ self.query_string = self.uri.split(b"?", 1)[1] if b"?" in self.uri else b""
try:
- self.request_start = time.time()
-
- # Validate header names.
- for name, _ in self.requestHeaders.getAllRawHeaders():
- if not HEADER_NAME_RE.fullmatch(name):
- self.basic_error(400, b"Bad Request", "Invalid header name")
- return
-
- # Get upgrade header
- upgrade_header = None
- if self.requestHeaders.hasHeader(b"Upgrade"):
- upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
- # Get client address if possible
- if hasattr(self.client, "host") and hasattr(self.client, "port"):
- # client.host and host.host are byte strings in Python 2, but spec
- # requires unicode string.
- self.client_addr = [str(self.client.host), self.client.port]
- self.server_addr = [str(self.host.host), self.host.port]
-
- self.client_scheme = "https" if self.isSecure() else "http"
-
- # See if we need to get the address from a proxy header instead
- if self.server.proxy_forwarded_address_header:
- self.client_addr, self.client_scheme = parse_x_forwarded_for(
- self.requestHeaders,
- self.server.proxy_forwarded_address_header,
- self.server.proxy_forwarded_port_header,
- self.server.proxy_forwarded_proto_header,
- self.client_addr,
- self.client_scheme,
- )
- # Check for unicodeish path (or it'll crash when trying to parse)
- try:
- self.path.decode("ascii")
- except UnicodeDecodeError:
- self.path = b"/"
- self.basic_error(400, b"Bad Request", "Invalid characters in path")
- return
- # Calculate query string
- self.query_string = b""
- if b"?" in self.uri:
- self.query_string = self.uri.split(b"?", 1)[1]
- try:
- self.query_string.decode("ascii")
- except UnicodeDecodeError:
- self.basic_error(400, b"Bad Request", "Invalid query string")
- return
- # Is it WebSocket? IS IT?!
+ # Process WebSocket requests via HTTP upgrade
if upgrade_header and upgrade_header.lower() == b"websocket":
- # Make WebSocket protocol to hand off to
- protocol = self.server.ws_factory.buildProtocol(
- self.transport.getPeer()
- )
- if not protocol:
- # If protocol creation fails, we signal "internal server error"
- self.setResponseCode(500)
- logger.warn("Could not make WebSocket protocol")
- self.finish()
- # Give it the raw query string
- protocol._raw_query_string = self.query_string
- # Port across transport
- transport, self.transport = self.transport, None
- if isinstance(transport, ProtocolWrapper):
- # i.e. TLS is a wrapping protocol
- transport.wrappedProtocol = protocol
- else:
- transport.protocol = protocol
- protocol.makeConnection(transport)
- # Re-inject request
- data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
- for h in self.requestHeaders.getAllRawHeaders():
- data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
- data += b"\x0d\x0a"
- data += self.content.read()
- protocol.dataReceived(data)
- # Remove our HTTP reply channel association
+ # Pass request to WebSocketResource for handling
+ self.server.ws_resource.render_GET(self)
+ # The WebSocketResource will handle the rest of the connection
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
- self.server.protocol_disconnected(self)
- # Resume the producer so we keep getting data, if it's available as a method
- self.channel._networkProducer.resumeProducing()
- # Boring old HTTP.
+ # Don't continue with HTTP processing
+ return
+ # Handle normal HTTP requests
else:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
@@ -339,9 +283,9 @@ class WebRequest(http.Request):
"""
Returns the time since the start of the request.
"""
- if not hasattr(self, "request_start"):
+ if not hasattr(self, "socket_opened"):
return 0
- return time.time() - self.request_start
+ return time.time() - self.socket_opened
def basic_error(self, status, status_text, body):
"""
@@ -357,13 +301,12 @@ class WebRequest(http.Request):
self.handle_reply(
{
"type": "http.response.body",
- "body": (
- self.error_template
- % {
- "title": str(status) + " " + status_text.decode("ascii"),
- "body": body,
- }
- ).encode("utf8"),
+ "body": self.error_template
+ % {
+ "status": status,
+ "status_text": status_text,
+ "text": body,
+ },
}
)
diff --git a/daphne/server.py b/daphne/server.py
index a6d3819..6296097 100755
--- a/daphne/server.py
+++ b/daphne/server.py
@@ -37,6 +37,7 @@ from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner
from twisted.web import http
+from twisted.web.websocket import WebSocketResource
from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
@@ -77,6 +78,7 @@ class Server:
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.request_buffer_size = request_buffer_size
+ self.request_max_size = None # No limit by default
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
@@ -99,12 +101,14 @@ class Server:
self.connections = {}
# Make the factory
self.http_factory = HTTPFactory(self)
- self.ws_factory = WebSocketFactory(self, server=self.server_name)
- self.ws_factory.setProtocolOptions(
- autoPingTimeout=self.ping_timeout,
- allowNullOrigin=True,
- openHandshakeTimeout=self.websocket_handshake_timeout,
- )
+
+ # Create WebSocket factory
+ self.ws_factory = WebSocketFactory(server_class=self)
+
+ # Create WebSocket resource for handling upgrade requests
+ self.ws_resource = WebSocketResource(self.ws_factory)
+
+ # Configure logging
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo(
diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py
index b1e29c3..9b203c7 100755
--- a/daphne/ws_protocol.py
+++ b/daphne/ws_protocol.py
@@ -3,19 +3,15 @@ import time
import traceback
from urllib.parse import unquote
-from autobahn.twisted.websocket import (
- ConnectionDeny,
- WebSocketServerFactory,
- WebSocketServerProtocol,
-)
from twisted.internet import defer
+from twisted.web.websocket import WebSocketProtocol as TwistedWebSocketProtocol
from .utils import parse_x_forwarded_for
logger = logging.getLogger(__name__)
-class WebSocketProtocol(WebSocketServerProtocol):
+class WebSocketProtocol(TwistedWebSocketProtocol):
"""
Protocol which supports WebSockets and forwards incoming messages to
the websocket channels.
@@ -26,29 +22,47 @@ class WebSocketProtocol(WebSocketServerProtocol):
# If we should send no more messages (e.g. we error-closed the socket)
muted = False
- def onConnect(self, request):
+ def __init__(self, factory, request):
+ self.factory = factory
+ self.transport = None
self.server = self.factory.server_class
- self.server.protocol_connected(self)
- self.request = request
- self.protocol_to_accept = None
- self.root_path = self.server.root_path
self.socket_opened = time.time()
self.last_ping = time.time()
+ self.client_addr = None
+ self.server_addr = None
+ self.clean_headers = []
+ self.handshake_deferred = None
+ self.path = None
+ self.root_path = None
+ self.application_queue = None
+ self.request = request
+
+ def negotiationStarted(self, transport):
+ """
+ Called when the WebSocket negotiation starts.
+ """
+ self.transport = transport
+ self.server.protocol_connected(self)
+ self.protocol_to_accept = None
+ self.root_path = self.server.root_path
try:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
- for name, value in request.headers.items():
- name = name.encode("ascii")
+ for name, value in self.request.requestHeaders.getAllRawHeaders():
+ name = name.lower()
# Prevent CVE-2015-0219
if b"_" in name:
continue
- if name.lower() == b"daphne-root-path":
- self.root_path = unquote(value)
+ if name == b"daphne-root-path":
+ self.root_path = unquote(value[0].decode("ascii"))
else:
- self.clean_headers.append((name.lower(), value.encode("latin1")))
+ self.clean_headers.append((name, value[0]))
+
# Get client address if possible
- peer = self.transport.getPeer()
- host = self.transport.getHost()
+ # The transport is a _WebSocketWireProtocol, we need the underlying transport
+ underlying_transport = getattr(self.transport, "transport", self.transport)
+ peer = underlying_transport.getPeer()
+ host = underlying_transport.getHost()
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [str(peer.host), peer.port]
self.server_addr = [str(host.host), host.port]
@@ -64,6 +78,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_proto_header,
self.client_addr,
)
+
# Decode websocket subprotocol options
subprotocols = []
for header, value in self.clean_headers:
@@ -71,8 +86,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
subprotocols = [
x.strip() for x in unquote(value.decode("ascii")).split(",")
]
+
+ # Extract query string
+ query_string = b""
+ if b"?" in self.request.uri:
+ query_string = self.request.uri.split(b"?", 1)[1]
+
+ # Get the path
+ self.path = self.request.path
+
# Make new application instance with scope
- self.path = request.path.encode("ascii")
self.application_deferred = defer.maybeDeferred(
self.server.create_application,
self,
@@ -82,7 +105,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"raw_path": self.path,
"root_path": self.root_path,
"headers": self.clean_headers,
- "query_string": self._raw_query_string, # Passed by HTTP protocol
+ "query_string": query_string,
"client": self.client_addr,
"server": self.server_addr,
"subprotocols": subprotocols,
@@ -97,10 +120,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.error(traceback.format_exc())
raise
- # Make a deferred and return it - we'll either call it or err it later on
- self.handshake_deferred = defer.Deferred()
- return self.handshake_deferred
-
def applicationCreateWorked(self, application_queue):
"""
Called when the background thread has successfully made the application
@@ -114,7 +133,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"websocket",
"connecting",
{
- "path": self.request.path,
+ "path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
@@ -128,144 +147,107 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.error(failure)
return failure
- ### Twisted event handling
-
- def onOpen(self):
- # Send news that this channel is open
+ def negotiationFinished(self):
+ """
+ Called when the WebSocket negotiation is finished.
+ """
logger.debug("WebSocket %s open and established", self.client_addr)
self.server.log_action(
"websocket",
"connected",
{
- "path": self.request.path,
+ "path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
- def onMessage(self, payload, isBinary):
+ def textMessageReceived(self, message):
+ """
+ Called when a text message is received.
+ """
# If we're muted, do nothing.
if self.muted:
logger.debug("Muting incoming frame on %s", self.client_addr)
return
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
- if isBinary:
- self.application_queue.put_nowait(
- {"type": "websocket.receive", "bytes": payload}
- )
- else:
- self.application_queue.put_nowait(
- {"type": "websocket.receive", "text": payload.decode("utf8")}
- )
+ self.application_queue.put_nowait(
+ {"type": "websocket.receive", "text": message}
+ )
- def onClose(self, wasClean, code, reason):
+ def bytesMessageReceived(self, data):
"""
- Called when Twisted closes the socket.
+ Called when a binary message is received.
+ """
+ # If we're muted, do nothing.
+ if self.muted:
+ logger.debug("Muting incoming frame on %s", self.client_addr)
+ return
+ logger.debug("WebSocket incoming frame on %s", self.client_addr)
+ self.last_ping = time.time()
+ self.application_queue.put_nowait({"type": "websocket.receive", "bytes": data})
+
+ def connectionLost(self, reason):
+ """
+ Called when the WebSocket connection is lost.
"""
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"):
self.application_queue.put_nowait(
- {"type": "websocket.disconnect", "code": code}
+ {"type": "websocket.disconnect", "code": 1000} # Default close code
)
self.server.log_action(
"websocket",
"disconnected",
{
- "path": self.request.path,
+ "path": self.request.path.decode("ascii"),
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
+ def pongReceived(self, payload):
+ """
+ Called when a pong frame is received in response to a ping.
+ """
+ self.last_ping = time.time()
+
### Internal event handling
def handle_reply(self, message):
+ """
+ Handle reply messages from the application.
+ """
if "type" not in message:
raise ValueError("Message has no type defined")
+
if message["type"] == "websocket.accept":
- self.serverAccept(message.get("subprotocol", None))
+ # Accept is handled by WebSocketResource in Twisted 25
+ # Our protocol is already established at this point
+ pass
elif message["type"] == "websocket.close":
- if self.state == self.STATE_CONNECTING:
- self.serverReject()
- else:
- self.serverClose(code=message.get("code", None))
+ self.transport.loseConnection(code=message.get("code", 1000))
elif message["type"] == "websocket.send":
- if self.state == self.STATE_CONNECTING:
- raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
- % (message,)
+ % (self.client_addr,)
)
if message.get("bytes", None):
- self.serverSend(message["bytes"], True)
+ self.transport.sendBytesMessage(message["bytes"])
if message.get("text", None):
- self.serverSend(message["text"], False)
+ self.transport.sendTextMessage(message["text"])
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
- if hasattr(self, "handshake_deferred"):
- # If the handshake is still ongoing, we need to emit a HTTP error
- # code rather than a WebSocket one.
- self.handshake_deferred.errback(
- ConnectionDeny(code=500, reason="Internal server error")
- )
- else:
- self.sendCloseFrame(code=1011)
-
- def serverAccept(self, subprotocol=None):
- """
- Called when we get a message saying to accept the connection.
- """
- self.handshake_deferred.callback(subprotocol)
- del self.handshake_deferred
- logger.debug("WebSocket %s accepted by application", self.client_addr)
-
- def serverReject(self):
- """
- Called when we get a message saying to reject the connection.
- """
- self.handshake_deferred.errback(
- ConnectionDeny(code=403, reason="Access denied")
- )
- del self.handshake_deferred
- self.server.protocol_disconnected(self)
- logger.debug("WebSocket %s rejected by application", self.client_addr)
- self.server.log_action(
- "websocket",
- "rejected",
- {
- "path": self.request.path,
- "client": (
- "%s:%s" % tuple(self.client_addr) if self.client_addr else None
- ),
- },
- )
-
- def serverSend(self, content, binary=False):
- """
- Server-side channel message to send a message.
- """
- if self.state == self.STATE_CONNECTING:
- self.serverAccept()
- logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
- if binary:
- self.sendMessage(content, binary)
- else:
- self.sendMessage(content.encode("utf8"), binary)
-
- def serverClose(self, code=None):
- """
- Server-side channel message to close the socket
- """
- code = 1000 if code is None else code
- self.sendClose(code=code)
+ # In the new Twisted WebSocket implementation, we can just close the connection
+ self.transport.loseConnection(code=1011) # Internal server error
### Utils
@@ -284,15 +266,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.duration() > self.server.websocket_timeout
and self.server.websocket_timeout >= 0
):
- self.serverClose()
+ self.transport.loseConnection(code=1000)
+
# Ping check
- # If we're still connecting, deny the connection
- if self.state == self.STATE_CONNECTING:
- if self.duration() > self.server.websocket_connect_timeout:
- self.serverReject()
- elif self.state == self.STATE_OPEN:
+ if hasattr(self, "transport") and self.transport:
if (time.time() - self.last_ping) > self.server.ping_interval:
- self._sendAutoPing()
+ self.transport.ping()
self.last_ping = time.time()
def __hash__(self):
@@ -305,26 +284,20 @@ class WebSocketProtocol(WebSocketServerProtocol):
return f""
-class WebSocketFactory(WebSocketServerFactory):
+class WebSocketFactory:
"""
- Factory subclass that remembers what the "main"
- factory is, so WebSocket protocols can access it
- to get reply ID info.
+ Factory for WebSocket protocols.
"""
- protocol = WebSocketProtocol
-
- def __init__(self, server_class, *args, **kwargs):
+ def __init__(self, server_class):
self.server_class = server_class
- WebSocketServerFactory.__init__(self, *args, **kwargs)
- def buildProtocol(self, addr):
+ def buildProtocol(self, request):
"""
- Builds protocol instances. We use this to inject the factory object into the protocol.
+ Builds a new WebSocket protocol.
"""
try:
- protocol = super().buildProtocol(addr)
- protocol.factory = self
+ protocol = WebSocketProtocol(self, request)
return protocol
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
diff --git a/pyproject.toml b/pyproject.toml
index 9b410aa..5e38150 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
"Topic :: Internet :: WWW/HTTP",
]
-dependencies = ["asgiref>=3.5.2,<4", "autobahn>=22.4.2", "twisted[tls]>=22.4"]
+dependencies = ["asgiref>=3.5.2,<4", "twisted[tls,websocket]>=22.4"]
[project.optional-dependencies]
tests = [