mirror of
https://github.com/django/daphne.git
synced 2024-11-21 23:46:33 +03:00
Switch to new explicit WebSocket acceptance
This commit is contained in:
parent
1bdf2a7518
commit
685f3aed1e
|
@ -24,6 +24,18 @@ class AccessLogGenerator(object):
|
|||
length=details['size'],
|
||||
)
|
||||
# Websocket requests
|
||||
elif protocol == "websocket" and action == "connecting":
|
||||
self.write_entry(
|
||||
host=details['client'],
|
||||
date=datetime.datetime.now(),
|
||||
request="WSCONNECTING %(path)s" % details,
|
||||
)
|
||||
elif protocol == "websocket" and action == "rejected":
|
||||
self.write_entry(
|
||||
host=details['client'],
|
||||
date=datetime.datetime.now(),
|
||||
request="WSREJECT %(path)s" % details,
|
||||
)
|
||||
elif protocol == "websocket" and action == "connected":
|
||||
self.write_entry(
|
||||
host=details['client'],
|
||||
|
|
|
@ -281,12 +281,13 @@ class HTTPFactory(http.HTTPFactory):
|
|||
|
||||
protocol = HTTPProtocol
|
||||
|
||||
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path=""):
|
||||
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30):
|
||||
http.HTTPFactory.__init__(self)
|
||||
self.channel_layer = channel_layer
|
||||
self.action_logger = action_logger
|
||||
self.timeout = timeout
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_connect_timeout = websocket_connect_timeout
|
||||
self.ping_interval = ping_interval
|
||||
# We track all sub-protocols for response channel mapping
|
||||
self.reply_protocols = {}
|
||||
|
@ -304,21 +305,37 @@ class HTTPFactory(http.HTTPFactory):
|
|||
if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest):
|
||||
self.reply_protocols[channel].serverResponse(message)
|
||||
elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol):
|
||||
# Ensure the message is a valid WebSocket one
|
||||
unknown_message_keys = set(message.keys()) - {"bytes", "text", "close"}
|
||||
if unknown_message_keys:
|
||||
# Switch depending on current socket state
|
||||
protocol = self.reply_protocols[channel]
|
||||
# See if the message is valid
|
||||
non_accept_keys = set(message.keys()) - {"accept"}
|
||||
non_send_keys = set(message.keys()) - {"bytes", "text", "close"}
|
||||
if non_accept_keys and non_send_keys:
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket reply message on %s - contains unknown keys %s" % (
|
||||
"Got invalid WebSocket reply message on %s - "
|
||||
"contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % (
|
||||
channel,
|
||||
unknown_message_keys,
|
||||
)
|
||||
)
|
||||
if message.get("bytes", None):
|
||||
self.reply_protocols[channel].serverSend(message["bytes"], True)
|
||||
if message.get("text", None):
|
||||
self.reply_protocols[channel].serverSend(message["text"], False)
|
||||
if message.get("close", False):
|
||||
self.reply_protocols[channel].serverClose()
|
||||
if "accept" in message:
|
||||
if protocol.state != protocol.STATE_CONNECTING:
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % (
|
||||
channel,
|
||||
)
|
||||
)
|
||||
if message['accept']:
|
||||
protocol.serverAccept()
|
||||
else:
|
||||
protocol.serverReject()
|
||||
else:
|
||||
if message.get("bytes", None):
|
||||
protocol.serverSend(message["bytes"], True)
|
||||
if message.get("text", None):
|
||||
protocol.serverSend(message["text"], False)
|
||||
if message.get("close", False):
|
||||
protocol.serverClose()
|
||||
else:
|
||||
raise ValueError("Cannot dispatch message on channel %r" % channel)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import six
|
|||
import time
|
||||
import traceback
|
||||
from six.moves.urllib_parse import unquote, urlencode
|
||||
from twisted.internet import defer
|
||||
|
||||
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
|
||||
|
||||
|
@ -27,6 +28,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
def onConnect(self, request):
|
||||
self.request = request
|
||||
self.packets_received = 0
|
||||
self.protocol_to_accept = None
|
||||
self.socket_opened = time.time()
|
||||
self.last_data = time.time()
|
||||
try:
|
||||
|
@ -78,8 +80,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
ws_protocol = protocol
|
||||
break
|
||||
|
||||
# Work out what subprotocol we will accept, if any
|
||||
if ws_protocol and ws_protocol in self.factory.protocols:
|
||||
return ws_protocol
|
||||
self.protocol_to_accept = ws_protocol
|
||||
else:
|
||||
self.protocol_to_accept = None
|
||||
|
||||
# Send over the connect message
|
||||
try:
|
||||
self.channel_layer.send("websocket.connect", self.request_info)
|
||||
except self.channel_layer.ChannelFull:
|
||||
# You have to consume websocket.connect according to the spec,
|
||||
# so drop the connection.
|
||||
self.muted = True
|
||||
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
|
||||
# Send code 1013 "try again later" with close.
|
||||
raise ConnectionDeny(code=503, reason="Connection queue at capacity")
|
||||
else:
|
||||
self.factory.log_action("websocket", "connecting", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
|
||||
# Make a deferred and return it - we'll either call it or err it later on
|
||||
self.handshake_deferred = defer.Deferred()
|
||||
return self.handshake_deferred
|
||||
|
||||
@classmethod
|
||||
def unquote(cls, value):
|
||||
|
@ -93,21 +118,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
|
||||
def onOpen(self):
|
||||
# Send news that this channel is open
|
||||
logger.debug("WebSocket open for %s", self.reply_channel)
|
||||
try:
|
||||
self.channel_layer.send("websocket.connect", self.request_info)
|
||||
except self.channel_layer.ChannelFull:
|
||||
# You have to consume websocket.connect according to the spec,
|
||||
# so drop the connection.
|
||||
self.muted = True
|
||||
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
|
||||
# Send code 1013 "try again later" with close.
|
||||
self.sendCloseFrame(code=1013, isReply=False)
|
||||
else:
|
||||
self.factory.log_action("websocket", "connected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
logger.debug("WebSocket %s open and established", self.reply_channel)
|
||||
self.factory.log_action("websocket", "connected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
# If we're muted, do nothing.
|
||||
|
@ -140,10 +155,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
# Send code 1013 "try again later" with close.
|
||||
self.sendCloseFrame(code=1013, isReply=False)
|
||||
|
||||
def serverAccept(self):
|
||||
"""
|
||||
Called when we get a message saying to accept the connection.
|
||||
"""
|
||||
self.handshake_deferred.callback(self.protocol_to_accept)
|
||||
logger.debug("WebSocket %s accepted by application", self.reply_channel)
|
||||
|
||||
def serverReject(self):
|
||||
"""
|
||||
Called when we get a message saying to accept the connection.
|
||||
"""
|
||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||
self.cleanup()
|
||||
logger.debug("WebSocket %s rejected by application", self.reply_channel)
|
||||
self.factory.log_action("websocket", "rejected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
|
||||
def serverSend(self, content, binary=False):
|
||||
"""
|
||||
Server-side channel message to send a message.
|
||||
"""
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverAccept()
|
||||
self.last_data = time.time()
|
||||
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
|
||||
if binary:
|
||||
|
@ -158,9 +194,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.sendClose()
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
self.cleanup()
|
||||
if hasattr(self, "reply_channel"):
|
||||
logger.debug("WebSocket closed for %s", self.reply_channel)
|
||||
del self.factory.reply_protocols[self.reply_channel]
|
||||
try:
|
||||
if not self.muted:
|
||||
self.channel_layer.send("websocket.disconnect", {
|
||||
|
@ -178,6 +214,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
else:
|
||||
logger.debug("WebSocket closed before handshake established")
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Call to clean up this socket after it's closed.
|
||||
"""
|
||||
if hasattr(self, "reply_channel"):
|
||||
del self.factory.reply_protocols[self.reply_channel]
|
||||
|
||||
def duration(self):
|
||||
"""
|
||||
Returns the time since the socket was opened
|
||||
|
@ -186,11 +229,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
|
||||
def check_ping(self):
|
||||
"""
|
||||
Checks to see if we should send a keepalive ping.
|
||||
Checks to see if we should send a keepalive ping/deny socket connection
|
||||
"""
|
||||
if (time.time() - self.last_data) > self.main_factory.ping_interval:
|
||||
self._sendAutoPing()
|
||||
self.last_data = time.time()
|
||||
# If we're still connecting, deny the connection
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
if self.duration() > self.main_factory.websocket_connect_timeout:
|
||||
self.serverReject()
|
||||
elif self.state == self.STATE_OPEN:
|
||||
if (time.time() - self.last_data) > self.main_factory.ping_interval:
|
||||
self._sendAutoPing()
|
||||
self.last_data = time.time()
|
||||
|
||||
|
||||
class WebSocketFactory(WebSocketServerFactory):
|
||||
|
|
Loading…
Reference in New Issue
Block a user