From 1f46fdff12f3fc8e3d0839b7eab37739d768136e Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Tue, 9 Feb 2016 12:35:31 -0800 Subject: [PATCH] Fix WebSocket protocol to correctly provide more info. Also removed keepalive code, since it's no longer how we do things. --- daphne/http_protocol.py | 25 ++++++++--------- daphne/server.py | 12 --------- daphne/ws_protocol.py | 59 ++++++++++++++++++++++++----------------- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index d1358bf..621a60e 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -40,13 +40,6 @@ class WebRequest(http.Request): self.query_string = "" if b"?" in self.uri: self.query_string = self.uri.split(b"?", 1)[1] - # Sanitize headers - self.clean_headers = {} - for name, value in self.requestHeaders.getAllRawHeaders(): - # Prevent CVE-2015-0219 - if b"_" in name: - continue - self.clean_headers[name.lower().decode("latin1")] = value[0] # Is it WebSocket? IS IT?! if upgrade_header == "websocket": # Make WebSocket protocol to hand off to @@ -54,6 +47,7 @@ class WebRequest(http.Request): if not protocol: # If protocol creation fails, we signal "internal server error" self.setResponseCode(500) + logger.warn("Could not make WebSocket protocol") self.finish() # Port across transport protocol.set_main_factory(self.factory) @@ -72,12 +66,19 @@ class WebRequest(http.Request): data += self.content.read() protocol.dataReceived(data) # Remove our HTTP reply channel association - logging.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel) + logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel) self.factory.reply_protocols[self.reply_channel] = None self.reply_channel = None # Boring old HTTP. else: - logging.debug("HTTP %s request for %s", self.method, self.reply_channel) + # Sanitize and decode headers + self.clean_headers = {} + for name, value in self.requestHeaders.getAllRawHeaders(): + # Prevent CVE-2015-0219 + if b"_" in name: + continue + self.clean_headers[name.lower().decode("latin1")] = value[0] + logger.debug("HTTP %s request for %s", self.method, self.reply_channel) self.content.seek(0, 0) # Send message self.factory.channel_layer.send("http.request", { @@ -100,7 +101,7 @@ class WebRequest(http.Request): """ if self.reply_channel: del self.channel.factory.reply_protocols[self.reply_channel] - logging.debug("HTTP disconnect for %s", self.reply_channel) + logger.debug("HTTP disconnect for %s", self.reply_channel) http.Request.connectionLost(self, reason) def serverResponse(self, message): @@ -122,9 +123,9 @@ class WebRequest(http.Request): # End if there's no more content if not message.get("more_content", False): self.finish() - logging.debug("HTTP %s response for %s", message['status'], self.reply_channel) + logger.debug("HTTP %s response for %s", message['status'], self.reply_channel) else: - logging.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel) + logger.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel) class HTTPProtocol(http.HTTPChannel): diff --git a/daphne/server.py b/daphne/server.py index 6e656cc..0e2d432 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -15,7 +15,6 @@ class Server(object): self.factory = HTTPFactory(self.channel_layer) reactor.listenTCP(self.port, self.factory, interface=self.host) reactor.callInThread(self.backend_reader) - #reactor.callLater(1, self.keepalive_sender) reactor.run() def backend_reader(self): @@ -39,14 +38,3 @@ class Server(object): continue # Deal with the message self.factory.dispatch_reply(channel, message) - - def keepalive_sender(self): - """ - Sends keepalive messages for open WebSockets every - (channel_backend expiry / 2) seconds. - """ - expiry_window = int(self.channel_layer.group_expiry / 2) - for protocol in self.factory.reply_protocols.values(): - if time.time() - protocol.last_keepalive > expiry_window: - protocol.sendKeepalive() - reactor.callLater(1, self.keepalive_sender) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 99b8754..4343983 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -1,7 +1,9 @@ from __future__ import unicode_literals -import time import logging +import time +import traceback +import urllib from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory @@ -19,21 +21,37 @@ class WebSocketProtocol(WebSocketServerProtocol): self.channel_layer = self.main_factory.channel_layer def onConnect(self, request): - self.request_info = { - "path": request.path, - "headers": self.clean_headers, - "query_string": request.query_string, - "client": [request.client.host, request.client.port], - "server": [request.host.host, request.host.port], - } + try: + # Sanitize and decode headers + clean_headers = {} + for name, value in request.headers.items(): + # Prevent CVE-2015-0219 + if "_" in name: + continue + clean_headers[name.lower()] = value[0].encode("latin1") + # Reconstruct query string + # TODO: get autobahn to provide it raw + query_string = urllib.urlencode(request.params) + # Make sending channel + self.reply_channel = self.channel_layer.new_channel("!websocket.send.?") + # Tell main factory about it + self.main_factory.reply_protocols[self.reply_channel] = self + # Make initial request info dict from request (we only have it here) + self.request_info = { + "path": request.path, + "headers": clean_headers, + "query_string": query_string, + "client": [self.transport.getPeer().host, self.transport.getPeer().port], + "server": [self.transport.getHost().host, self.transport.getHost().port], + "reply_channel": self.reply_channel, + } + except: + # Exceptions here are not displayed right, just 500. + # Turn them into an ERROR log. + logger.error(traceback.format_exc()) + raise def onOpen(self): - # Make sending channel - self.reply_channel = self.channel_layer.new_channel("!websocket.send.?") - self.request_info["reply_channel"] = self.reply_channel - self.last_keepalive = time.time() - # Tell main factory about it - self.main_factory.reply_protocols[self.reply_channel] = self # Send news that this channel is open logger.debug("WebSocket open for %s", self.reply_channel) self.channel_layer.send("websocket.connect", self.request_info) @@ -68,21 +86,14 @@ class WebSocketProtocol(WebSocketServerProtocol): self.sendClose() def onClose(self, wasClean, code, reason): - logger.debug("WebSocket closed for %s", self.reply_channel) if hasattr(self, "reply_channel"): + logger.debug("WebSocket closed for %s", self.reply_channel) del self.factory.reply_protocols[self.reply_channel] self.channel_layer.send("websocket.disconnect", { "reply_channel": self.reply_channel, }) - - def sendKeepalive(self): - """ - Sends a keepalive packet on the keepalive channel. - """ - self.channel_layer.send("websocket.keepalive", { - "reply_channel": self.reply_channel, - }) - self.last_keepalive = time.time() + else: + logger.debug("WebSocket closed before handshake established") class WebSocketFactory(WebSocketServerFactory):