Fix WebSocket protocol to correctly provide more info.

Also removed keepalive code, since it's no longer how we do things.
This commit is contained in:
Andrew Godwin 2016-02-09 12:35:31 -08:00
parent 40ed9625c3
commit 1f46fdff12
3 changed files with 48 additions and 48 deletions

View File

@ -40,13 +40,6 @@ class WebRequest(http.Request):
self.query_string = ""
if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1]
# Sanitize headers
self.clean_headers = {}
for name, value in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
continue
self.clean_headers[name.lower().decode("latin1")] = value[0]
# Is it WebSocket? IS IT?!
if upgrade_header == "websocket":
# Make WebSocket protocol to hand off to
@ -54,6 +47,7 @@ class WebRequest(http.Request):
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
logger.warn("Could not make WebSocket protocol")
self.finish()
# Port across transport
protocol.set_main_factory(self.factory)
@ -72,12 +66,19 @@ class WebRequest(http.Request):
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
logging.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
self.factory.reply_protocols[self.reply_channel] = None
self.reply_channel = None
# Boring old HTTP.
else:
logging.debug("HTTP %s request for %s", self.method, self.reply_channel)
# Sanitize and decode headers
self.clean_headers = {}
for name, value in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
continue
self.clean_headers[name.lower().decode("latin1")] = value[0]
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
self.content.seek(0, 0)
# Send message
self.factory.channel_layer.send("http.request", {
@ -100,7 +101,7 @@ class WebRequest(http.Request):
"""
if self.reply_channel:
del self.channel.factory.reply_protocols[self.reply_channel]
logging.debug("HTTP disconnect for %s", self.reply_channel)
logger.debug("HTTP disconnect for %s", self.reply_channel)
http.Request.connectionLost(self, reason)
def serverResponse(self, message):
@ -122,9 +123,9 @@ class WebRequest(http.Request):
# End if there's no more content
if not message.get("more_content", False):
self.finish()
logging.debug("HTTP %s response for %s", message['status'], self.reply_channel)
logger.debug("HTTP %s response for %s", message['status'], self.reply_channel)
else:
logging.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel)
logger.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel)
class HTTPProtocol(http.HTTPChannel):

View File

@ -15,7 +15,6 @@ class Server(object):
self.factory = HTTPFactory(self.channel_layer)
reactor.listenTCP(self.port, self.factory, interface=self.host)
reactor.callInThread(self.backend_reader)
#reactor.callLater(1, self.keepalive_sender)
reactor.run()
def backend_reader(self):
@ -39,14 +38,3 @@ class Server(object):
continue
# Deal with the message
self.factory.dispatch_reply(channel, message)
def keepalive_sender(self):
"""
Sends keepalive messages for open WebSockets every
(channel_backend expiry / 2) seconds.
"""
expiry_window = int(self.channel_layer.group_expiry / 2)
for protocol in self.factory.reply_protocols.values():
if time.time() - protocol.last_keepalive > expiry_window:
protocol.sendKeepalive()
reactor.callLater(1, self.keepalive_sender)

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals
import time
import logging
import time
import traceback
import urllib
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
@ -19,21 +21,37 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.channel_layer = self.main_factory.channel_layer
def onConnect(self, request):
self.request_info = {
"path": request.path,
"headers": self.clean_headers,
"query_string": request.query_string,
"client": [request.client.host, request.client.port],
"server": [request.host.host, request.host.port],
}
def onOpen(self):
try:
# Sanitize and decode headers
clean_headers = {}
for name, value in request.headers.items():
# Prevent CVE-2015-0219
if "_" in name:
continue
clean_headers[name.lower()] = value[0].encode("latin1")
# Reconstruct query string
# TODO: get autobahn to provide it raw
query_string = urllib.urlencode(request.params)
# Make sending channel
self.reply_channel = self.channel_layer.new_channel("!websocket.send.?")
self.request_info["reply_channel"] = self.reply_channel
self.last_keepalive = time.time()
# Tell main factory about it
self.main_factory.reply_protocols[self.reply_channel] = self
# Make initial request info dict from request (we only have it here)
self.request_info = {
"path": request.path,
"headers": clean_headers,
"query_string": query_string,
"client": [self.transport.getPeer().host, self.transport.getPeer().port],
"server": [self.transport.getHost().host, self.transport.getHost().port],
"reply_channel": self.reply_channel,
}
except:
# Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log.
logger.error(traceback.format_exc())
raise
def onOpen(self):
# Send news that this channel is open
logger.debug("WebSocket open for %s", self.reply_channel)
self.channel_layer.send("websocket.connect", self.request_info)
@ -68,21 +86,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.sendClose()
def onClose(self, wasClean, code, reason):
logger.debug("WebSocket closed for %s", self.reply_channel)
if hasattr(self, "reply_channel"):
logger.debug("WebSocket closed for %s", self.reply_channel)
del self.factory.reply_protocols[self.reply_channel]
self.channel_layer.send("websocket.disconnect", {
"reply_channel": self.reply_channel,
})
def sendKeepalive(self):
"""
Sends a keepalive packet on the keepalive channel.
"""
self.channel_layer.send("websocket.keepalive", {
"reply_channel": self.reply_channel,
})
self.last_keepalive = time.time()
else:
logger.debug("WebSocket closed before handshake established")
class WebSocketFactory(WebSocketServerFactory):