Switch to new explicit WebSocket acceptance

This commit is contained in:
Andrew Godwin 2016-10-05 12:08:03 -07:00
parent 1bdf2a7518
commit 685f3aed1e
3 changed files with 109 additions and 32 deletions

View File

@ -24,6 +24,18 @@ class AccessLogGenerator(object):
length=details['size'],
)
# Websocket requests
elif protocol == "websocket" and action == "connecting":
self.write_entry(
host=details['client'],
date=datetime.datetime.now(),
request="WSCONNECTING %(path)s" % details,
)
elif protocol == "websocket" and action == "rejected":
self.write_entry(
host=details['client'],
date=datetime.datetime.now(),
request="WSREJECT %(path)s" % details,
)
elif protocol == "websocket" and action == "connected":
self.write_entry(
host=details['client'],

View File

@ -281,12 +281,13 @@ class HTTPFactory(http.HTTPFactory):
protocol = HTTPProtocol
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path=""):
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30):
http.HTTPFactory.__init__(self)
self.channel_layer = channel_layer
self.action_logger = action_logger
self.timeout = timeout
self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout
self.ping_interval = ping_interval
# We track all sub-protocols for response channel mapping
self.reply_protocols = {}
@ -304,21 +305,37 @@ class HTTPFactory(http.HTTPFactory):
if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest):
self.reply_protocols[channel].serverResponse(message)
elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol):
# Ensure the message is a valid WebSocket one
unknown_message_keys = set(message.keys()) - {"bytes", "text", "close"}
if unknown_message_keys:
# Switch depending on current socket state
protocol = self.reply_protocols[channel]
# See if the message is valid
non_accept_keys = set(message.keys()) - {"accept"}
non_send_keys = set(message.keys()) - {"bytes", "text", "close"}
if non_accept_keys and non_send_keys:
raise ValueError(
"Got invalid WebSocket reply message on %s - contains unknown keys %s" % (
"Got invalid WebSocket reply message on %s - "
"contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % (
channel,
unknown_message_keys,
)
)
if message.get("bytes", None):
self.reply_protocols[channel].serverSend(message["bytes"], True)
if message.get("text", None):
self.reply_protocols[channel].serverSend(message["text"], False)
if message.get("close", False):
self.reply_protocols[channel].serverClose()
if "accept" in message:
if protocol.state != protocol.STATE_CONNECTING:
raise ValueError(
"Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % (
channel,
)
)
if message['accept']:
protocol.serverAccept()
else:
protocol.serverReject()
else:
if message.get("bytes", None):
protocol.serverSend(message["bytes"], True)
if message.get("text", None):
protocol.serverSend(message["text"], False)
if message.get("close", False):
protocol.serverClose()
else:
raise ValueError("Cannot dispatch message on channel %r" % channel)

View File

@ -5,6 +5,7 @@ import six
import time
import traceback
from six.moves.urllib_parse import unquote, urlencode
from twisted.internet import defer
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
@ -27,6 +28,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onConnect(self, request):
self.request = request
self.packets_received = 0
self.protocol_to_accept = None
self.socket_opened = time.time()
self.last_data = time.time()
try:
@ -78,8 +80,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
ws_protocol = protocol
break
# Work out what subprotocol we will accept, if any
if ws_protocol and ws_protocol in self.factory.protocols:
return ws_protocol
self.protocol_to_accept = ws_protocol
else:
self.protocol_to_accept = None
# Send over the connect message
try:
self.channel_layer.send("websocket.connect", self.request_info)
except self.channel_layer.ChannelFull:
# You have to consume websocket.connect according to the spec,
# so drop the connection.
self.muted = True
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
# Send code 1013 "try again later" with close.
raise ConnectionDeny(code=503, reason="Connection queue at capacity")
else:
self.factory.log_action("websocket", "connecting", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
# Make a deferred and return it - we'll either call it or err it later on
self.handshake_deferred = defer.Deferred()
return self.handshake_deferred
@classmethod
def unquote(cls, value):
@ -93,21 +118,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onOpen(self):
# Send news that this channel is open
logger.debug("WebSocket open for %s", self.reply_channel)
try:
self.channel_layer.send("websocket.connect", self.request_info)
except self.channel_layer.ChannelFull:
# You have to consume websocket.connect according to the spec,
# so drop the connection.
self.muted = True
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
# Send code 1013 "try again later" with close.
self.sendCloseFrame(code=1013, isReply=False)
else:
self.factory.log_action("websocket", "connected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
logger.debug("WebSocket %s open and established", self.reply_channel)
self.factory.log_action("websocket", "connected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
def onMessage(self, payload, isBinary):
# If we're muted, do nothing.
@ -140,10 +155,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
# Send code 1013 "try again later" with close.
self.sendCloseFrame(code=1013, isReply=False)
def serverAccept(self):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.callback(self.protocol_to_accept)
logger.debug("WebSocket %s accepted by application", self.reply_channel)
def serverReject(self):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
self.cleanup()
logger.debug("WebSocket %s rejected by application", self.reply_channel)
self.factory.log_action("websocket", "rejected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
def serverSend(self, content, binary=False):
"""
Server-side channel message to send a message.
"""
if self.state == self.STATE_CONNECTING:
self.serverAccept()
self.last_data = time.time()
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
if binary:
@ -158,9 +194,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.sendClose()
def onClose(self, wasClean, code, reason):
self.cleanup()
if hasattr(self, "reply_channel"):
logger.debug("WebSocket closed for %s", self.reply_channel)
del self.factory.reply_protocols[self.reply_channel]
try:
if not self.muted:
self.channel_layer.send("websocket.disconnect", {
@ -178,6 +214,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
else:
logger.debug("WebSocket closed before handshake established")
def cleanup(self):
"""
Call to clean up this socket after it's closed.
"""
if hasattr(self, "reply_channel"):
del self.factory.reply_protocols[self.reply_channel]
def duration(self):
"""
Returns the time since the socket was opened
@ -186,11 +229,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
def check_ping(self):
"""
Checks to see if we should send a keepalive ping.
Checks to see if we should send a keepalive ping/deny socket connection
"""
if (time.time() - self.last_data) > self.main_factory.ping_interval:
self._sendAutoPing()
self.last_data = time.time()
# If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING:
if self.duration() > self.main_factory.websocket_connect_timeout:
self.serverReject()
elif self.state == self.STATE_OPEN:
if (time.time() - self.last_data) > self.main_factory.ping_interval:
self._sendAutoPing()
self.last_data = time.time()
class WebSocketFactory(WebSocketServerFactory):