diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index e9eba96..b554adc 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -5,11 +5,10 @@ from urllib.parse import unquote from twisted.internet.defer import inlineCallbacks, maybeDeferred from twisted.internet.interfaces import IProtocolNegotiationFactory -from twisted.protocols.policies import ProtocolWrapper from twisted.web import http from zope.interface import implementer -from .utils import HEADER_NAME_RE, parse_x_forwarded_for +from .utils import parse_x_forwarded_for logger = logging.getLogger(__name__) @@ -24,132 +23,77 @@ class WebRequest(http.Request): """ error_template = ( - """ - - - %(title)s - - - -

%(title)s

-

%(body)s

- - - - """.replace( - "\n", "" - ) - .replace(" ", " ") - .replace(" ", " ") - .replace(" ", " ") - ) # Shorten it a bit, bytes wise + b"" + b"" + b"%(status)d %(status_text)s" + b"

%(status)d %(status_text)s

%(text)s" + b"" + ) def __init__(self, *args, **kwargs): + http.Request.__init__(self, *args, **kwargs) + # Easy server link + self.server = self.channel.factory.server + self.application_queue = None + self._response_started = False self.client_addr = None self.server_addr = None - try: - http.Request.__init__(self, *args, **kwargs) - # Easy server link - self.server = self.channel.factory.server - self.application_queue = None - self._response_started = False - self.server.protocol_connected(self) - except Exception: - logger.error(traceback.format_exc()) - raise - - ### Twisted progress callbacks + self.client_scheme = None + # Build the client address + if self.transport: + peer = self.transport.getPeer() + host = self.transport.getHost() + # Always set scheme if we have a transport + self.client_scheme = ( + "https" if hasattr(peer, "is_ssl") and peer.is_ssl else "http" + ) + if hasattr(peer, "host") and hasattr(peer, "port"): + self.client_addr = [str(peer.host), peer.port] + self.server_addr = [str(host.host), host.port] + # Get upgrade header + upgrade_header = None + if self.requestHeaders.hasHeader(b"Upgrade"): + upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0] + self.is_websocket = upgrade_header and upgrade_header.lower() == b"websocket" + # Hook up request parsing + self.socket_opened = time.time() + self.server.protocol_connected(self) @inlineCallbacks def process(self): + """ + Called when all headers have been received and we can start processing content. + """ + # Get upgrade header + upgrade_header = None + if self.requestHeaders.hasHeader(b"Upgrade"): + upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0] + # Get client address if forwarded + if self.server.proxy_forwarded_address_header: + self.client_addr, self.client_scheme = parse_x_forwarded_for( + {name: value for name, value in self.requestHeaders.getAllRawHeaders()}, + self.server.proxy_forwarded_address_header, + self.server.proxy_forwarded_port_header, + self.server.proxy_forwarded_proto_header, + self.client_addr, + self.client_scheme, + ) + # Check for maximum request body size + if self.server.request_max_size: + self.channel.maxData = self.server.request_max_size + # Get query string + self.query_string = self.uri.split(b"?", 1)[1] if b"?" in self.uri else b"" try: - self.request_start = time.time() - - # Validate header names. - for name, _ in self.requestHeaders.getAllRawHeaders(): - if not HEADER_NAME_RE.fullmatch(name): - self.basic_error(400, b"Bad Request", "Invalid header name") - return - - # Get upgrade header - upgrade_header = None - if self.requestHeaders.hasHeader(b"Upgrade"): - upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0] - # Get client address if possible - if hasattr(self.client, "host") and hasattr(self.client, "port"): - # client.host and host.host are byte strings in Python 2, but spec - # requires unicode string. - self.client_addr = [str(self.client.host), self.client.port] - self.server_addr = [str(self.host.host), self.host.port] - - self.client_scheme = "https" if self.isSecure() else "http" - - # See if we need to get the address from a proxy header instead - if self.server.proxy_forwarded_address_header: - self.client_addr, self.client_scheme = parse_x_forwarded_for( - self.requestHeaders, - self.server.proxy_forwarded_address_header, - self.server.proxy_forwarded_port_header, - self.server.proxy_forwarded_proto_header, - self.client_addr, - self.client_scheme, - ) - # Check for unicodeish path (or it'll crash when trying to parse) - try: - self.path.decode("ascii") - except UnicodeDecodeError: - self.path = b"/" - self.basic_error(400, b"Bad Request", "Invalid characters in path") - return - # Calculate query string - self.query_string = b"" - if b"?" in self.uri: - self.query_string = self.uri.split(b"?", 1)[1] - try: - self.query_string.decode("ascii") - except UnicodeDecodeError: - self.basic_error(400, b"Bad Request", "Invalid query string") - return - # Is it WebSocket? IS IT?! + # Process WebSocket requests via HTTP upgrade if upgrade_header and upgrade_header.lower() == b"websocket": - # Make WebSocket protocol to hand off to - protocol = self.server.ws_factory.buildProtocol( - self.transport.getPeer() - ) - if not protocol: - # If protocol creation fails, we signal "internal server error" - self.setResponseCode(500) - logger.warn("Could not make WebSocket protocol") - self.finish() - # Give it the raw query string - protocol._raw_query_string = self.query_string - # Port across transport - transport, self.transport = self.transport, None - if isinstance(transport, ProtocolWrapper): - # i.e. TLS is a wrapping protocol - transport.wrappedProtocol = protocol - else: - transport.protocol = protocol - protocol.makeConnection(transport) - # Re-inject request - data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a" - for h in self.requestHeaders.getAllRawHeaders(): - data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a" - data += b"\x0d\x0a" - data += self.content.read() - protocol.dataReceived(data) - # Remove our HTTP reply channel association + # Pass request to WebSocketResource for handling + self.server.ws_resource.render_GET(self) + # The WebSocketResource will handle the rest of the connection logger.debug("Upgraded connection %s to WebSocket", self.client_addr) - self.server.protocol_disconnected(self) - # Resume the producer so we keep getting data, if it's available as a method - self.channel._networkProducer.resumeProducing() - # Boring old HTTP. + # Don't continue with HTTP processing + return + # Handle normal HTTP requests else: # Sanitize and decode headers, potentially extracting root path self.clean_headers = [] @@ -339,9 +283,9 @@ class WebRequest(http.Request): """ Returns the time since the start of the request. """ - if not hasattr(self, "request_start"): + if not hasattr(self, "socket_opened"): return 0 - return time.time() - self.request_start + return time.time() - self.socket_opened def basic_error(self, status, status_text, body): """ @@ -357,13 +301,12 @@ class WebRequest(http.Request): self.handle_reply( { "type": "http.response.body", - "body": ( - self.error_template - % { - "title": str(status) + " " + status_text.decode("ascii"), - "body": body, - } - ).encode("utf8"), + "body": self.error_template + % { + "status": status, + "status_text": status_text, + "text": body, + }, } ) diff --git a/daphne/server.py b/daphne/server.py index a6d3819..6296097 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -37,6 +37,7 @@ from twisted.internet import defer, reactor from twisted.internet.endpoints import serverFromString from twisted.logger import STDLibLogObserver, globalLogBeginner from twisted.web import http +from twisted.web.websocket import WebSocketResource from .http_protocol import HTTPFactory from .ws_protocol import WebSocketFactory @@ -77,6 +78,7 @@ class Server: self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.request_buffer_size = request_buffer_size + self.request_max_size = None # No limit by default self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_port_header = proxy_forwarded_port_header self.proxy_forwarded_proto_header = proxy_forwarded_proto_header @@ -99,12 +101,14 @@ class Server: self.connections = {} # Make the factory self.http_factory = HTTPFactory(self) - self.ws_factory = WebSocketFactory(self, server=self.server_name) - self.ws_factory.setProtocolOptions( - autoPingTimeout=self.ping_timeout, - allowNullOrigin=True, - openHandshakeTimeout=self.websocket_handshake_timeout, - ) + + # Create WebSocket factory + self.ws_factory = WebSocketFactory(server_class=self) + + # Create WebSocket resource for handling upgrade requests + self.ws_resource = WebSocketResource(self.ws_factory) + + # Configure logging if self.verbosity <= 1: # Redirect the Twisted log to nowhere globalLogBeginner.beginLoggingTo( diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index b1e29c3..9b203c7 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -3,19 +3,15 @@ import time import traceback from urllib.parse import unquote -from autobahn.twisted.websocket import ( - ConnectionDeny, - WebSocketServerFactory, - WebSocketServerProtocol, -) from twisted.internet import defer +from twisted.web.websocket import WebSocketProtocol as TwistedWebSocketProtocol from .utils import parse_x_forwarded_for logger = logging.getLogger(__name__) -class WebSocketProtocol(WebSocketServerProtocol): +class WebSocketProtocol(TwistedWebSocketProtocol): """ Protocol which supports WebSockets and forwards incoming messages to the websocket channels. @@ -26,29 +22,47 @@ class WebSocketProtocol(WebSocketServerProtocol): # If we should send no more messages (e.g. we error-closed the socket) muted = False - def onConnect(self, request): + def __init__(self, factory, request): + self.factory = factory + self.transport = None self.server = self.factory.server_class - self.server.protocol_connected(self) - self.request = request - self.protocol_to_accept = None - self.root_path = self.server.root_path self.socket_opened = time.time() self.last_ping = time.time() + self.client_addr = None + self.server_addr = None + self.clean_headers = [] + self.handshake_deferred = None + self.path = None + self.root_path = None + self.application_queue = None + self.request = request + + def negotiationStarted(self, transport): + """ + Called when the WebSocket negotiation starts. + """ + self.transport = transport + self.server.protocol_connected(self) + self.protocol_to_accept = None + self.root_path = self.server.root_path try: # Sanitize and decode headers, potentially extracting root path self.clean_headers = [] - for name, value in request.headers.items(): - name = name.encode("ascii") + for name, value in self.request.requestHeaders.getAllRawHeaders(): + name = name.lower() # Prevent CVE-2015-0219 if b"_" in name: continue - if name.lower() == b"daphne-root-path": - self.root_path = unquote(value) + if name == b"daphne-root-path": + self.root_path = unquote(value[0].decode("ascii")) else: - self.clean_headers.append((name.lower(), value.encode("latin1"))) + self.clean_headers.append((name, value[0])) + # Get client address if possible - peer = self.transport.getPeer() - host = self.transport.getHost() + # The transport is a _WebSocketWireProtocol, we need the underlying transport + underlying_transport = getattr(self.transport, "transport", self.transport) + peer = underlying_transport.getPeer() + host = underlying_transport.getHost() if hasattr(peer, "host") and hasattr(peer, "port"): self.client_addr = [str(peer.host), peer.port] self.server_addr = [str(host.host), host.port] @@ -64,6 +78,7 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server.proxy_forwarded_proto_header, self.client_addr, ) + # Decode websocket subprotocol options subprotocols = [] for header, value in self.clean_headers: @@ -71,8 +86,16 @@ class WebSocketProtocol(WebSocketServerProtocol): subprotocols = [ x.strip() for x in unquote(value.decode("ascii")).split(",") ] + + # Extract query string + query_string = b"" + if b"?" in self.request.uri: + query_string = self.request.uri.split(b"?", 1)[1] + + # Get the path + self.path = self.request.path + # Make new application instance with scope - self.path = request.path.encode("ascii") self.application_deferred = defer.maybeDeferred( self.server.create_application, self, @@ -82,7 +105,7 @@ class WebSocketProtocol(WebSocketServerProtocol): "raw_path": self.path, "root_path": self.root_path, "headers": self.clean_headers, - "query_string": self._raw_query_string, # Passed by HTTP protocol + "query_string": query_string, "client": self.client_addr, "server": self.server_addr, "subprotocols": subprotocols, @@ -97,10 +120,6 @@ class WebSocketProtocol(WebSocketServerProtocol): logger.error(traceback.format_exc()) raise - # Make a deferred and return it - we'll either call it or err it later on - self.handshake_deferred = defer.Deferred() - return self.handshake_deferred - def applicationCreateWorked(self, application_queue): """ Called when the background thread has successfully made the application @@ -114,7 +133,7 @@ class WebSocketProtocol(WebSocketServerProtocol): "websocket", "connecting", { - "path": self.request.path, + "path": self.request.path.decode("ascii"), "client": ( "%s:%s" % tuple(self.client_addr) if self.client_addr else None ), @@ -128,144 +147,107 @@ class WebSocketProtocol(WebSocketServerProtocol): logger.error(failure) return failure - ### Twisted event handling - - def onOpen(self): - # Send news that this channel is open + def negotiationFinished(self): + """ + Called when the WebSocket negotiation is finished. + """ logger.debug("WebSocket %s open and established", self.client_addr) self.server.log_action( "websocket", "connected", { - "path": self.request.path, + "path": self.request.path.decode("ascii"), "client": ( "%s:%s" % tuple(self.client_addr) if self.client_addr else None ), }, ) - def onMessage(self, payload, isBinary): + def textMessageReceived(self, message): + """ + Called when a text message is received. + """ # If we're muted, do nothing. if self.muted: logger.debug("Muting incoming frame on %s", self.client_addr) return logger.debug("WebSocket incoming frame on %s", self.client_addr) self.last_ping = time.time() - if isBinary: - self.application_queue.put_nowait( - {"type": "websocket.receive", "bytes": payload} - ) - else: - self.application_queue.put_nowait( - {"type": "websocket.receive", "text": payload.decode("utf8")} - ) + self.application_queue.put_nowait( + {"type": "websocket.receive", "text": message} + ) - def onClose(self, wasClean, code, reason): + def bytesMessageReceived(self, data): """ - Called when Twisted closes the socket. + Called when a binary message is received. + """ + # If we're muted, do nothing. + if self.muted: + logger.debug("Muting incoming frame on %s", self.client_addr) + return + logger.debug("WebSocket incoming frame on %s", self.client_addr) + self.last_ping = time.time() + self.application_queue.put_nowait({"type": "websocket.receive", "bytes": data}) + + def connectionLost(self, reason): + """ + Called when the WebSocket connection is lost. """ self.server.protocol_disconnected(self) logger.debug("WebSocket closed for %s", self.client_addr) if not self.muted and hasattr(self, "application_queue"): self.application_queue.put_nowait( - {"type": "websocket.disconnect", "code": code} + {"type": "websocket.disconnect", "code": 1000} # Default close code ) self.server.log_action( "websocket", "disconnected", { - "path": self.request.path, + "path": self.request.path.decode("ascii"), "client": ( "%s:%s" % tuple(self.client_addr) if self.client_addr else None ), }, ) + def pongReceived(self, payload): + """ + Called when a pong frame is received in response to a ping. + """ + self.last_ping = time.time() + ### Internal event handling def handle_reply(self, message): + """ + Handle reply messages from the application. + """ if "type" not in message: raise ValueError("Message has no type defined") + if message["type"] == "websocket.accept": - self.serverAccept(message.get("subprotocol", None)) + # Accept is handled by WebSocketResource in Twisted 25 + # Our protocol is already established at this point + pass elif message["type"] == "websocket.close": - if self.state == self.STATE_CONNECTING: - self.serverReject() - else: - self.serverClose(code=message.get("code", None)) + self.transport.loseConnection(code=message.get("code", 1000)) elif message["type"] == "websocket.send": - if self.state == self.STATE_CONNECTING: - raise ValueError("Socket has not been accepted, so cannot send over it") if message.get("bytes", None) and message.get("text", None): raise ValueError( "Got invalid WebSocket reply message on %s - contains both bytes and text keys" - % (message,) + % (self.client_addr,) ) if message.get("bytes", None): - self.serverSend(message["bytes"], True) + self.transport.sendBytesMessage(message["bytes"]) if message.get("text", None): - self.serverSend(message["text"], False) + self.transport.sendTextMessage(message["text"]) def handle_exception(self, exception): """ Called by the server when our application tracebacks """ - if hasattr(self, "handshake_deferred"): - # If the handshake is still ongoing, we need to emit a HTTP error - # code rather than a WebSocket one. - self.handshake_deferred.errback( - ConnectionDeny(code=500, reason="Internal server error") - ) - else: - self.sendCloseFrame(code=1011) - - def serverAccept(self, subprotocol=None): - """ - Called when we get a message saying to accept the connection. - """ - self.handshake_deferred.callback(subprotocol) - del self.handshake_deferred - logger.debug("WebSocket %s accepted by application", self.client_addr) - - def serverReject(self): - """ - Called when we get a message saying to reject the connection. - """ - self.handshake_deferred.errback( - ConnectionDeny(code=403, reason="Access denied") - ) - del self.handshake_deferred - self.server.protocol_disconnected(self) - logger.debug("WebSocket %s rejected by application", self.client_addr) - self.server.log_action( - "websocket", - "rejected", - { - "path": self.request.path, - "client": ( - "%s:%s" % tuple(self.client_addr) if self.client_addr else None - ), - }, - ) - - def serverSend(self, content, binary=False): - """ - Server-side channel message to send a message. - """ - if self.state == self.STATE_CONNECTING: - self.serverAccept() - logger.debug("Sent WebSocket packet to client for %s", self.client_addr) - if binary: - self.sendMessage(content, binary) - else: - self.sendMessage(content.encode("utf8"), binary) - - def serverClose(self, code=None): - """ - Server-side channel message to close the socket - """ - code = 1000 if code is None else code - self.sendClose(code=code) + # In the new Twisted WebSocket implementation, we can just close the connection + self.transport.loseConnection(code=1011) # Internal server error ### Utils @@ -284,15 +266,12 @@ class WebSocketProtocol(WebSocketServerProtocol): self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0 ): - self.serverClose() + self.transport.loseConnection(code=1000) + # Ping check - # If we're still connecting, deny the connection - if self.state == self.STATE_CONNECTING: - if self.duration() > self.server.websocket_connect_timeout: - self.serverReject() - elif self.state == self.STATE_OPEN: + if hasattr(self, "transport") and self.transport: if (time.time() - self.last_ping) > self.server.ping_interval: - self._sendAutoPing() + self.transport.ping() self.last_ping = time.time() def __hash__(self): @@ -305,26 +284,20 @@ class WebSocketProtocol(WebSocketServerProtocol): return f"" -class WebSocketFactory(WebSocketServerFactory): +class WebSocketFactory: """ - Factory subclass that remembers what the "main" - factory is, so WebSocket protocols can access it - to get reply ID info. + Factory for WebSocket protocols. """ - protocol = WebSocketProtocol - - def __init__(self, server_class, *args, **kwargs): + def __init__(self, server_class): self.server_class = server_class - WebSocketServerFactory.__init__(self, *args, **kwargs) - def buildProtocol(self, addr): + def buildProtocol(self, request): """ - Builds protocol instances. We use this to inject the factory object into the protocol. + Builds a new WebSocket protocol. """ try: - protocol = super().buildProtocol(addr) - protocol.factory = self + protocol = WebSocketProtocol(self, request) return protocol except Exception: logger.error("Cannot build protocol: %s" % traceback.format_exc()) diff --git a/pyproject.toml b/pyproject.toml index 9b410aa..5e38150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Topic :: Internet :: WWW/HTTP", ] -dependencies = ["asgiref>=3.5.2,<4", "autobahn>=22.4.2", "twisted[tls]>=22.4"] +dependencies = ["asgiref>=3.5.2,<4", "twisted[tls,websocket]>=22.4"] [project.optional-dependencies] tests = [