mirror of
https://github.com/django/daphne.git
synced 2024-11-22 07:56:34 +03:00
Fix WebSocket protocol to correctly provide more info.
Also removed keepalive code, since it's no longer how we do things.
This commit is contained in:
parent
40ed9625c3
commit
1f46fdff12
|
@ -40,13 +40,6 @@ class WebRequest(http.Request):
|
||||||
self.query_string = ""
|
self.query_string = ""
|
||||||
if b"?" in self.uri:
|
if b"?" in self.uri:
|
||||||
self.query_string = self.uri.split(b"?", 1)[1]
|
self.query_string = self.uri.split(b"?", 1)[1]
|
||||||
# Sanitize headers
|
|
||||||
self.clean_headers = {}
|
|
||||||
for name, value in self.requestHeaders.getAllRawHeaders():
|
|
||||||
# Prevent CVE-2015-0219
|
|
||||||
if b"_" in name:
|
|
||||||
continue
|
|
||||||
self.clean_headers[name.lower().decode("latin1")] = value[0]
|
|
||||||
# Is it WebSocket? IS IT?!
|
# Is it WebSocket? IS IT?!
|
||||||
if upgrade_header == "websocket":
|
if upgrade_header == "websocket":
|
||||||
# Make WebSocket protocol to hand off to
|
# Make WebSocket protocol to hand off to
|
||||||
|
@ -54,6 +47,7 @@ class WebRequest(http.Request):
|
||||||
if not protocol:
|
if not protocol:
|
||||||
# If protocol creation fails, we signal "internal server error"
|
# If protocol creation fails, we signal "internal server error"
|
||||||
self.setResponseCode(500)
|
self.setResponseCode(500)
|
||||||
|
logger.warn("Could not make WebSocket protocol")
|
||||||
self.finish()
|
self.finish()
|
||||||
# Port across transport
|
# Port across transport
|
||||||
protocol.set_main_factory(self.factory)
|
protocol.set_main_factory(self.factory)
|
||||||
|
@ -72,12 +66,19 @@ class WebRequest(http.Request):
|
||||||
data += self.content.read()
|
data += self.content.read()
|
||||||
protocol.dataReceived(data)
|
protocol.dataReceived(data)
|
||||||
# Remove our HTTP reply channel association
|
# Remove our HTTP reply channel association
|
||||||
logging.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
|
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
|
||||||
self.factory.reply_protocols[self.reply_channel] = None
|
self.factory.reply_protocols[self.reply_channel] = None
|
||||||
self.reply_channel = None
|
self.reply_channel = None
|
||||||
# Boring old HTTP.
|
# Boring old HTTP.
|
||||||
else:
|
else:
|
||||||
logging.debug("HTTP %s request for %s", self.method, self.reply_channel)
|
# Sanitize and decode headers
|
||||||
|
self.clean_headers = {}
|
||||||
|
for name, value in self.requestHeaders.getAllRawHeaders():
|
||||||
|
# Prevent CVE-2015-0219
|
||||||
|
if b"_" in name:
|
||||||
|
continue
|
||||||
|
self.clean_headers[name.lower().decode("latin1")] = value[0]
|
||||||
|
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
|
||||||
self.content.seek(0, 0)
|
self.content.seek(0, 0)
|
||||||
# Send message
|
# Send message
|
||||||
self.factory.channel_layer.send("http.request", {
|
self.factory.channel_layer.send("http.request", {
|
||||||
|
@ -100,7 +101,7 @@ class WebRequest(http.Request):
|
||||||
"""
|
"""
|
||||||
if self.reply_channel:
|
if self.reply_channel:
|
||||||
del self.channel.factory.reply_protocols[self.reply_channel]
|
del self.channel.factory.reply_protocols[self.reply_channel]
|
||||||
logging.debug("HTTP disconnect for %s", self.reply_channel)
|
logger.debug("HTTP disconnect for %s", self.reply_channel)
|
||||||
http.Request.connectionLost(self, reason)
|
http.Request.connectionLost(self, reason)
|
||||||
|
|
||||||
def serverResponse(self, message):
|
def serverResponse(self, message):
|
||||||
|
@ -122,9 +123,9 @@ class WebRequest(http.Request):
|
||||||
# End if there's no more content
|
# End if there's no more content
|
||||||
if not message.get("more_content", False):
|
if not message.get("more_content", False):
|
||||||
self.finish()
|
self.finish()
|
||||||
logging.debug("HTTP %s response for %s", message['status'], self.reply_channel)
|
logger.debug("HTTP %s response for %s", message['status'], self.reply_channel)
|
||||||
else:
|
else:
|
||||||
logging.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel)
|
logger.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel)
|
||||||
|
|
||||||
|
|
||||||
class HTTPProtocol(http.HTTPChannel):
|
class HTTPProtocol(http.HTTPChannel):
|
||||||
|
|
|
@ -15,7 +15,6 @@ class Server(object):
|
||||||
self.factory = HTTPFactory(self.channel_layer)
|
self.factory = HTTPFactory(self.channel_layer)
|
||||||
reactor.listenTCP(self.port, self.factory, interface=self.host)
|
reactor.listenTCP(self.port, self.factory, interface=self.host)
|
||||||
reactor.callInThread(self.backend_reader)
|
reactor.callInThread(self.backend_reader)
|
||||||
#reactor.callLater(1, self.keepalive_sender)
|
|
||||||
reactor.run()
|
reactor.run()
|
||||||
|
|
||||||
def backend_reader(self):
|
def backend_reader(self):
|
||||||
|
@ -39,14 +38,3 @@ class Server(object):
|
||||||
continue
|
continue
|
||||||
# Deal with the message
|
# Deal with the message
|
||||||
self.factory.dispatch_reply(channel, message)
|
self.factory.dispatch_reply(channel, message)
|
||||||
|
|
||||||
def keepalive_sender(self):
|
|
||||||
"""
|
|
||||||
Sends keepalive messages for open WebSockets every
|
|
||||||
(channel_backend expiry / 2) seconds.
|
|
||||||
"""
|
|
||||||
expiry_window = int(self.channel_layer.group_expiry / 2)
|
|
||||||
for protocol in self.factory.reply_protocols.values():
|
|
||||||
if time.time() - protocol.last_keepalive > expiry_window:
|
|
||||||
protocol.sendKeepalive()
|
|
||||||
reactor.callLater(1, self.keepalive_sender)
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import urllib
|
||||||
|
|
||||||
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
|
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
|
||||||
|
|
||||||
|
@ -19,21 +21,37 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
self.channel_layer = self.main_factory.channel_layer
|
self.channel_layer = self.main_factory.channel_layer
|
||||||
|
|
||||||
def onConnect(self, request):
|
def onConnect(self, request):
|
||||||
self.request_info = {
|
try:
|
||||||
"path": request.path,
|
# Sanitize and decode headers
|
||||||
"headers": self.clean_headers,
|
clean_headers = {}
|
||||||
"query_string": request.query_string,
|
for name, value in request.headers.items():
|
||||||
"client": [request.client.host, request.client.port],
|
# Prevent CVE-2015-0219
|
||||||
"server": [request.host.host, request.host.port],
|
if "_" in name:
|
||||||
}
|
continue
|
||||||
|
clean_headers[name.lower()] = value[0].encode("latin1")
|
||||||
def onOpen(self):
|
# Reconstruct query string
|
||||||
|
# TODO: get autobahn to provide it raw
|
||||||
|
query_string = urllib.urlencode(request.params)
|
||||||
# Make sending channel
|
# Make sending channel
|
||||||
self.reply_channel = self.channel_layer.new_channel("!websocket.send.?")
|
self.reply_channel = self.channel_layer.new_channel("!websocket.send.?")
|
||||||
self.request_info["reply_channel"] = self.reply_channel
|
|
||||||
self.last_keepalive = time.time()
|
|
||||||
# Tell main factory about it
|
# Tell main factory about it
|
||||||
self.main_factory.reply_protocols[self.reply_channel] = self
|
self.main_factory.reply_protocols[self.reply_channel] = self
|
||||||
|
# Make initial request info dict from request (we only have it here)
|
||||||
|
self.request_info = {
|
||||||
|
"path": request.path,
|
||||||
|
"headers": clean_headers,
|
||||||
|
"query_string": query_string,
|
||||||
|
"client": [self.transport.getPeer().host, self.transport.getPeer().port],
|
||||||
|
"server": [self.transport.getHost().host, self.transport.getHost().port],
|
||||||
|
"reply_channel": self.reply_channel,
|
||||||
|
}
|
||||||
|
except:
|
||||||
|
# Exceptions here are not displayed right, just 500.
|
||||||
|
# Turn them into an ERROR log.
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
def onOpen(self):
|
||||||
# Send news that this channel is open
|
# Send news that this channel is open
|
||||||
logger.debug("WebSocket open for %s", self.reply_channel)
|
logger.debug("WebSocket open for %s", self.reply_channel)
|
||||||
self.channel_layer.send("websocket.connect", self.request_info)
|
self.channel_layer.send("websocket.connect", self.request_info)
|
||||||
|
@ -68,21 +86,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
self.sendClose()
|
self.sendClose()
|
||||||
|
|
||||||
def onClose(self, wasClean, code, reason):
|
def onClose(self, wasClean, code, reason):
|
||||||
logger.debug("WebSocket closed for %s", self.reply_channel)
|
|
||||||
if hasattr(self, "reply_channel"):
|
if hasattr(self, "reply_channel"):
|
||||||
|
logger.debug("WebSocket closed for %s", self.reply_channel)
|
||||||
del self.factory.reply_protocols[self.reply_channel]
|
del self.factory.reply_protocols[self.reply_channel]
|
||||||
self.channel_layer.send("websocket.disconnect", {
|
self.channel_layer.send("websocket.disconnect", {
|
||||||
"reply_channel": self.reply_channel,
|
"reply_channel": self.reply_channel,
|
||||||
})
|
})
|
||||||
|
else:
|
||||||
def sendKeepalive(self):
|
logger.debug("WebSocket closed before handshake established")
|
||||||
"""
|
|
||||||
Sends a keepalive packet on the keepalive channel.
|
|
||||||
"""
|
|
||||||
self.channel_layer.send("websocket.keepalive", {
|
|
||||||
"reply_channel": self.reply_channel,
|
|
||||||
})
|
|
||||||
self.last_keepalive = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketFactory(WebSocketServerFactory):
|
class WebSocketFactory(WebSocketServerFactory):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user