Fix WebSocket protocol to correctly provide more info.

Also removed keepalive code, since it's no longer how we do things.
This commit is contained in:
Andrew Godwin 2016-02-09 12:35:31 -08:00
parent 40ed9625c3
commit 1f46fdff12
3 changed files with 48 additions and 48 deletions

View File

@ -40,13 +40,6 @@ class WebRequest(http.Request):
self.query_string = "" self.query_string = ""
if b"?" in self.uri: if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1] self.query_string = self.uri.split(b"?", 1)[1]
# Sanitize headers
self.clean_headers = {}
for name, value in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
continue
self.clean_headers[name.lower().decode("latin1")] = value[0]
# Is it WebSocket? IS IT?! # Is it WebSocket? IS IT?!
if upgrade_header == "websocket": if upgrade_header == "websocket":
# Make WebSocket protocol to hand off to # Make WebSocket protocol to hand off to
@ -54,6 +47,7 @@ class WebRequest(http.Request):
if not protocol: if not protocol:
# If protocol creation fails, we signal "internal server error" # If protocol creation fails, we signal "internal server error"
self.setResponseCode(500) self.setResponseCode(500)
logger.warn("Could not make WebSocket protocol")
self.finish() self.finish()
# Port across transport # Port across transport
protocol.set_main_factory(self.factory) protocol.set_main_factory(self.factory)
@ -72,12 +66,19 @@ class WebRequest(http.Request):
data += self.content.read() data += self.content.read()
protocol.dataReceived(data) protocol.dataReceived(data)
# Remove our HTTP reply channel association # Remove our HTTP reply channel association
logging.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel) logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
self.factory.reply_protocols[self.reply_channel] = None self.factory.reply_protocols[self.reply_channel] = None
self.reply_channel = None self.reply_channel = None
# Boring old HTTP. # Boring old HTTP.
else: else:
logging.debug("HTTP %s request for %s", self.method, self.reply_channel) # Sanitize and decode headers
self.clean_headers = {}
for name, value in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
continue
self.clean_headers[name.lower().decode("latin1")] = value[0]
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
self.content.seek(0, 0) self.content.seek(0, 0)
# Send message # Send message
self.factory.channel_layer.send("http.request", { self.factory.channel_layer.send("http.request", {
@ -100,7 +101,7 @@ class WebRequest(http.Request):
""" """
if self.reply_channel: if self.reply_channel:
del self.channel.factory.reply_protocols[self.reply_channel] del self.channel.factory.reply_protocols[self.reply_channel]
logging.debug("HTTP disconnect for %s", self.reply_channel) logger.debug("HTTP disconnect for %s", self.reply_channel)
http.Request.connectionLost(self, reason) http.Request.connectionLost(self, reason)
def serverResponse(self, message): def serverResponse(self, message):
@ -122,9 +123,9 @@ class WebRequest(http.Request):
# End if there's no more content # End if there's no more content
if not message.get("more_content", False): if not message.get("more_content", False):
self.finish() self.finish()
logging.debug("HTTP %s response for %s", message['status'], self.reply_channel) logger.debug("HTTP %s response for %s", message['status'], self.reply_channel)
else: else:
logging.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel) logger.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel)
class HTTPProtocol(http.HTTPChannel): class HTTPProtocol(http.HTTPChannel):

View File

@ -15,7 +15,6 @@ class Server(object):
self.factory = HTTPFactory(self.channel_layer) self.factory = HTTPFactory(self.channel_layer)
reactor.listenTCP(self.port, self.factory, interface=self.host) reactor.listenTCP(self.port, self.factory, interface=self.host)
reactor.callInThread(self.backend_reader) reactor.callInThread(self.backend_reader)
#reactor.callLater(1, self.keepalive_sender)
reactor.run() reactor.run()
def backend_reader(self): def backend_reader(self):
@ -39,14 +38,3 @@ class Server(object):
continue continue
# Deal with the message # Deal with the message
self.factory.dispatch_reply(channel, message) self.factory.dispatch_reply(channel, message)
def keepalive_sender(self):
"""
Sends keepalive messages for open WebSockets every
(channel_backend expiry / 2) seconds.
"""
expiry_window = int(self.channel_layer.group_expiry / 2)
for protocol in self.factory.reply_protocols.values():
if time.time() - protocol.last_keepalive > expiry_window:
protocol.sendKeepalive()
reactor.callLater(1, self.keepalive_sender)

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import time
import logging import logging
import time
import traceback
import urllib
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
@ -19,21 +21,37 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.channel_layer = self.main_factory.channel_layer self.channel_layer = self.main_factory.channel_layer
def onConnect(self, request): def onConnect(self, request):
self.request_info = { try:
"path": request.path, # Sanitize and decode headers
"headers": self.clean_headers, clean_headers = {}
"query_string": request.query_string, for name, value in request.headers.items():
"client": [request.client.host, request.client.port], # Prevent CVE-2015-0219
"server": [request.host.host, request.host.port], if "_" in name:
} continue
clean_headers[name.lower()] = value[0].encode("latin1")
def onOpen(self): # Reconstruct query string
# TODO: get autobahn to provide it raw
query_string = urllib.urlencode(request.params)
# Make sending channel # Make sending channel
self.reply_channel = self.channel_layer.new_channel("!websocket.send.?") self.reply_channel = self.channel_layer.new_channel("!websocket.send.?")
self.request_info["reply_channel"] = self.reply_channel
self.last_keepalive = time.time()
# Tell main factory about it # Tell main factory about it
self.main_factory.reply_protocols[self.reply_channel] = self self.main_factory.reply_protocols[self.reply_channel] = self
# Make initial request info dict from request (we only have it here)
self.request_info = {
"path": request.path,
"headers": clean_headers,
"query_string": query_string,
"client": [self.transport.getPeer().host, self.transport.getPeer().port],
"server": [self.transport.getHost().host, self.transport.getHost().port],
"reply_channel": self.reply_channel,
}
except:
# Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log.
logger.error(traceback.format_exc())
raise
def onOpen(self):
# Send news that this channel is open # Send news that this channel is open
logger.debug("WebSocket open for %s", self.reply_channel) logger.debug("WebSocket open for %s", self.reply_channel)
self.channel_layer.send("websocket.connect", self.request_info) self.channel_layer.send("websocket.connect", self.request_info)
@ -68,21 +86,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.sendClose() self.sendClose()
def onClose(self, wasClean, code, reason): def onClose(self, wasClean, code, reason):
logger.debug("WebSocket closed for %s", self.reply_channel)
if hasattr(self, "reply_channel"): if hasattr(self, "reply_channel"):
logger.debug("WebSocket closed for %s", self.reply_channel)
del self.factory.reply_protocols[self.reply_channel] del self.factory.reply_protocols[self.reply_channel]
self.channel_layer.send("websocket.disconnect", { self.channel_layer.send("websocket.disconnect", {
"reply_channel": self.reply_channel, "reply_channel": self.reply_channel,
}) })
else:
def sendKeepalive(self): logger.debug("WebSocket closed before handshake established")
"""
Sends a keepalive packet on the keepalive channel.
"""
self.channel_layer.send("websocket.keepalive", {
"reply_channel": self.reply_channel,
})
self.last_keepalive = time.time()
class WebSocketFactory(WebSocketServerFactory): class WebSocketFactory(WebSocketServerFactory):