mirror of
https://github.com/django/daphne.git
synced 2024-11-21 23:46:33 +03:00
Fix WebSocket protocol to correctly provide more info.
Also removed keepalive code, since it's no longer how we do things.
This commit is contained in:
parent
40ed9625c3
commit
1f46fdff12
|
@ -40,13 +40,6 @@ class WebRequest(http.Request):
|
|||
self.query_string = ""
|
||||
if b"?" in self.uri:
|
||||
self.query_string = self.uri.split(b"?", 1)[1]
|
||||
# Sanitize headers
|
||||
self.clean_headers = {}
|
||||
for name, value in self.requestHeaders.getAllRawHeaders():
|
||||
# Prevent CVE-2015-0219
|
||||
if b"_" in name:
|
||||
continue
|
||||
self.clean_headers[name.lower().decode("latin1")] = value[0]
|
||||
# Is it WebSocket? IS IT?!
|
||||
if upgrade_header == "websocket":
|
||||
# Make WebSocket protocol to hand off to
|
||||
|
@ -54,6 +47,7 @@ class WebRequest(http.Request):
|
|||
if not protocol:
|
||||
# If protocol creation fails, we signal "internal server error"
|
||||
self.setResponseCode(500)
|
||||
logger.warn("Could not make WebSocket protocol")
|
||||
self.finish()
|
||||
# Port across transport
|
||||
protocol.set_main_factory(self.factory)
|
||||
|
@ -72,12 +66,19 @@ class WebRequest(http.Request):
|
|||
data += self.content.read()
|
||||
protocol.dataReceived(data)
|
||||
# Remove our HTTP reply channel association
|
||||
logging.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
|
||||
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
|
||||
self.factory.reply_protocols[self.reply_channel] = None
|
||||
self.reply_channel = None
|
||||
# Boring old HTTP.
|
||||
else:
|
||||
logging.debug("HTTP %s request for %s", self.method, self.reply_channel)
|
||||
# Sanitize and decode headers
|
||||
self.clean_headers = {}
|
||||
for name, value in self.requestHeaders.getAllRawHeaders():
|
||||
# Prevent CVE-2015-0219
|
||||
if b"_" in name:
|
||||
continue
|
||||
self.clean_headers[name.lower().decode("latin1")] = value[0]
|
||||
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
|
||||
self.content.seek(0, 0)
|
||||
# Send message
|
||||
self.factory.channel_layer.send("http.request", {
|
||||
|
@ -100,7 +101,7 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
if self.reply_channel:
|
||||
del self.channel.factory.reply_protocols[self.reply_channel]
|
||||
logging.debug("HTTP disconnect for %s", self.reply_channel)
|
||||
logger.debug("HTTP disconnect for %s", self.reply_channel)
|
||||
http.Request.connectionLost(self, reason)
|
||||
|
||||
def serverResponse(self, message):
|
||||
|
@ -122,9 +123,9 @@ class WebRequest(http.Request):
|
|||
# End if there's no more content
|
||||
if not message.get("more_content", False):
|
||||
self.finish()
|
||||
logging.debug("HTTP %s response for %s", message['status'], self.reply_channel)
|
||||
logger.debug("HTTP %s response for %s", message['status'], self.reply_channel)
|
||||
else:
|
||||
logging.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel)
|
||||
logger.debug("HTTP %s response chunk for %s", message['status'], self.reply_channel)
|
||||
|
||||
|
||||
class HTTPProtocol(http.HTTPChannel):
|
||||
|
|
|
@ -15,7 +15,6 @@ class Server(object):
|
|||
self.factory = HTTPFactory(self.channel_layer)
|
||||
reactor.listenTCP(self.port, self.factory, interface=self.host)
|
||||
reactor.callInThread(self.backend_reader)
|
||||
#reactor.callLater(1, self.keepalive_sender)
|
||||
reactor.run()
|
||||
|
||||
def backend_reader(self):
|
||||
|
@ -39,14 +38,3 @@ class Server(object):
|
|||
continue
|
||||
# Deal with the message
|
||||
self.factory.dispatch_reply(channel, message)
|
||||
|
||||
def keepalive_sender(self):
|
||||
"""
|
||||
Sends keepalive messages for open WebSockets every
|
||||
(channel_backend expiry / 2) seconds.
|
||||
"""
|
||||
expiry_window = int(self.channel_layer.group_expiry / 2)
|
||||
for protocol in self.factory.reply_protocols.values():
|
||||
if time.time() - protocol.last_keepalive > expiry_window:
|
||||
protocol.sendKeepalive()
|
||||
reactor.callLater(1, self.keepalive_sender)
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
import urllib
|
||||
|
||||
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
|
||||
|
||||
|
@ -19,21 +21,37 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.channel_layer = self.main_factory.channel_layer
|
||||
|
||||
def onConnect(self, request):
|
||||
self.request_info = {
|
||||
"path": request.path,
|
||||
"headers": self.clean_headers,
|
||||
"query_string": request.query_string,
|
||||
"client": [request.client.host, request.client.port],
|
||||
"server": [request.host.host, request.host.port],
|
||||
}
|
||||
|
||||
def onOpen(self):
|
||||
try:
|
||||
# Sanitize and decode headers
|
||||
clean_headers = {}
|
||||
for name, value in request.headers.items():
|
||||
# Prevent CVE-2015-0219
|
||||
if "_" in name:
|
||||
continue
|
||||
clean_headers[name.lower()] = value[0].encode("latin1")
|
||||
# Reconstruct query string
|
||||
# TODO: get autobahn to provide it raw
|
||||
query_string = urllib.urlencode(request.params)
|
||||
# Make sending channel
|
||||
self.reply_channel = self.channel_layer.new_channel("!websocket.send.?")
|
||||
self.request_info["reply_channel"] = self.reply_channel
|
||||
self.last_keepalive = time.time()
|
||||
# Tell main factory about it
|
||||
self.main_factory.reply_protocols[self.reply_channel] = self
|
||||
# Make initial request info dict from request (we only have it here)
|
||||
self.request_info = {
|
||||
"path": request.path,
|
||||
"headers": clean_headers,
|
||||
"query_string": query_string,
|
||||
"client": [self.transport.getPeer().host, self.transport.getPeer().port],
|
||||
"server": [self.transport.getHost().host, self.transport.getHost().port],
|
||||
"reply_channel": self.reply_channel,
|
||||
}
|
||||
except:
|
||||
# Exceptions here are not displayed right, just 500.
|
||||
# Turn them into an ERROR log.
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def onOpen(self):
|
||||
# Send news that this channel is open
|
||||
logger.debug("WebSocket open for %s", self.reply_channel)
|
||||
self.channel_layer.send("websocket.connect", self.request_info)
|
||||
|
@ -68,21 +86,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.sendClose()
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
logger.debug("WebSocket closed for %s", self.reply_channel)
|
||||
if hasattr(self, "reply_channel"):
|
||||
logger.debug("WebSocket closed for %s", self.reply_channel)
|
||||
del self.factory.reply_protocols[self.reply_channel]
|
||||
self.channel_layer.send("websocket.disconnect", {
|
||||
"reply_channel": self.reply_channel,
|
||||
})
|
||||
|
||||
def sendKeepalive(self):
|
||||
"""
|
||||
Sends a keepalive packet on the keepalive channel.
|
||||
"""
|
||||
self.channel_layer.send("websocket.keepalive", {
|
||||
"reply_channel": self.reply_channel,
|
||||
})
|
||||
self.last_keepalive = time.time()
|
||||
else:
|
||||
logger.debug("WebSocket closed before handshake established")
|
||||
|
||||
|
||||
class WebSocketFactory(WebSocketServerFactory):
|
||||
|
|
Loading…
Reference in New Issue
Block a user