mirror of
https://github.com/django/daphne.git
synced 2025-07-10 16:02:18 +03:00
Switch to new explicit WebSocket acceptance
This commit is contained in:
parent
1bdf2a7518
commit
685f3aed1e
|
@ -24,6 +24,18 @@ class AccessLogGenerator(object):
|
||||||
length=details['size'],
|
length=details['size'],
|
||||||
)
|
)
|
||||||
# Websocket requests
|
# Websocket requests
|
||||||
|
elif protocol == "websocket" and action == "connecting":
|
||||||
|
self.write_entry(
|
||||||
|
host=details['client'],
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
request="WSCONNECTING %(path)s" % details,
|
||||||
|
)
|
||||||
|
elif protocol == "websocket" and action == "rejected":
|
||||||
|
self.write_entry(
|
||||||
|
host=details['client'],
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
request="WSREJECT %(path)s" % details,
|
||||||
|
)
|
||||||
elif protocol == "websocket" and action == "connected":
|
elif protocol == "websocket" and action == "connected":
|
||||||
self.write_entry(
|
self.write_entry(
|
||||||
host=details['client'],
|
host=details['client'],
|
||||||
|
|
|
@ -281,12 +281,13 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
|
|
||||||
protocol = HTTPProtocol
|
protocol = HTTPProtocol
|
||||||
|
|
||||||
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path=""):
|
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30):
|
||||||
http.HTTPFactory.__init__(self)
|
http.HTTPFactory.__init__(self)
|
||||||
self.channel_layer = channel_layer
|
self.channel_layer = channel_layer
|
||||||
self.action_logger = action_logger
|
self.action_logger = action_logger
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.websocket_timeout = websocket_timeout
|
self.websocket_timeout = websocket_timeout
|
||||||
|
self.websocket_connect_timeout = websocket_connect_timeout
|
||||||
self.ping_interval = ping_interval
|
self.ping_interval = ping_interval
|
||||||
# We track all sub-protocols for response channel mapping
|
# We track all sub-protocols for response channel mapping
|
||||||
self.reply_protocols = {}
|
self.reply_protocols = {}
|
||||||
|
@ -304,21 +305,37 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest):
|
if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest):
|
||||||
self.reply_protocols[channel].serverResponse(message)
|
self.reply_protocols[channel].serverResponse(message)
|
||||||
elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol):
|
elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol):
|
||||||
# Ensure the message is a valid WebSocket one
|
# Switch depending on current socket state
|
||||||
unknown_message_keys = set(message.keys()) - {"bytes", "text", "close"}
|
protocol = self.reply_protocols[channel]
|
||||||
if unknown_message_keys:
|
# See if the message is valid
|
||||||
|
non_accept_keys = set(message.keys()) - {"accept"}
|
||||||
|
non_send_keys = set(message.keys()) - {"bytes", "text", "close"}
|
||||||
|
if non_accept_keys and non_send_keys:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Got invalid WebSocket reply message on %s - contains unknown keys %s" % (
|
"Got invalid WebSocket reply message on %s - "
|
||||||
|
"contains unknown keys %s (looking for either {'accept'} or {'text', 'bytes', 'close'})" % (
|
||||||
channel,
|
channel,
|
||||||
unknown_message_keys,
|
unknown_message_keys,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if "accept" in message:
|
||||||
|
if protocol.state != protocol.STATE_CONNECTING:
|
||||||
|
raise ValueError(
|
||||||
|
"Got invalid WebSocket connection reply message on %s - websocket is not in handshake phase" % (
|
||||||
|
channel,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if message['accept']:
|
||||||
|
protocol.serverAccept()
|
||||||
|
else:
|
||||||
|
protocol.serverReject()
|
||||||
|
else:
|
||||||
if message.get("bytes", None):
|
if message.get("bytes", None):
|
||||||
self.reply_protocols[channel].serverSend(message["bytes"], True)
|
protocol.serverSend(message["bytes"], True)
|
||||||
if message.get("text", None):
|
if message.get("text", None):
|
||||||
self.reply_protocols[channel].serverSend(message["text"], False)
|
protocol.serverSend(message["text"], False)
|
||||||
if message.get("close", False):
|
if message.get("close", False):
|
||||||
self.reply_protocols[channel].serverClose()
|
protocol.serverClose()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot dispatch message on channel %r" % channel)
|
raise ValueError("Cannot dispatch message on channel %r" % channel)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import six
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from six.moves.urllib_parse import unquote, urlencode
|
from six.moves.urllib_parse import unquote, urlencode
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
|
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
|
||||||
|
|
||||||
|
@ -27,6 +28,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
def onConnect(self, request):
|
def onConnect(self, request):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.packets_received = 0
|
self.packets_received = 0
|
||||||
|
self.protocol_to_accept = None
|
||||||
self.socket_opened = time.time()
|
self.socket_opened = time.time()
|
||||||
self.last_data = time.time()
|
self.last_data = time.time()
|
||||||
try:
|
try:
|
||||||
|
@ -78,8 +80,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
ws_protocol = protocol
|
ws_protocol = protocol
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Work out what subprotocol we will accept, if any
|
||||||
if ws_protocol and ws_protocol in self.factory.protocols:
|
if ws_protocol and ws_protocol in self.factory.protocols:
|
||||||
return ws_protocol
|
self.protocol_to_accept = ws_protocol
|
||||||
|
else:
|
||||||
|
self.protocol_to_accept = None
|
||||||
|
|
||||||
|
# Send over the connect message
|
||||||
|
try:
|
||||||
|
self.channel_layer.send("websocket.connect", self.request_info)
|
||||||
|
except self.channel_layer.ChannelFull:
|
||||||
|
# You have to consume websocket.connect according to the spec,
|
||||||
|
# so drop the connection.
|
||||||
|
self.muted = True
|
||||||
|
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
|
||||||
|
# Send code 1013 "try again later" with close.
|
||||||
|
raise ConnectionDeny(code=503, reason="Connection queue at capacity")
|
||||||
|
else:
|
||||||
|
self.factory.log_action("websocket", "connecting", {
|
||||||
|
"path": self.request.path,
|
||||||
|
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Make a deferred and return it - we'll either call it or err it later on
|
||||||
|
self.handshake_deferred = defer.Deferred()
|
||||||
|
return self.handshake_deferred
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def unquote(cls, value):
|
def unquote(cls, value):
|
||||||
|
@ -93,17 +118,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
|
|
||||||
def onOpen(self):
|
def onOpen(self):
|
||||||
# Send news that this channel is open
|
# Send news that this channel is open
|
||||||
logger.debug("WebSocket open for %s", self.reply_channel)
|
logger.debug("WebSocket %s open and established", self.reply_channel)
|
||||||
try:
|
|
||||||
self.channel_layer.send("websocket.connect", self.request_info)
|
|
||||||
except self.channel_layer.ChannelFull:
|
|
||||||
# You have to consume websocket.connect according to the spec,
|
|
||||||
# so drop the connection.
|
|
||||||
self.muted = True
|
|
||||||
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
|
|
||||||
# Send code 1013 "try again later" with close.
|
|
||||||
self.sendCloseFrame(code=1013, isReply=False)
|
|
||||||
else:
|
|
||||||
self.factory.log_action("websocket", "connected", {
|
self.factory.log_action("websocket", "connected", {
|
||||||
"path": self.request.path,
|
"path": self.request.path,
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||||
|
@ -140,10 +155,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
# Send code 1013 "try again later" with close.
|
# Send code 1013 "try again later" with close.
|
||||||
self.sendCloseFrame(code=1013, isReply=False)
|
self.sendCloseFrame(code=1013, isReply=False)
|
||||||
|
|
||||||
|
def serverAccept(self):
|
||||||
|
"""
|
||||||
|
Called when we get a message saying to accept the connection.
|
||||||
|
"""
|
||||||
|
self.handshake_deferred.callback(self.protocol_to_accept)
|
||||||
|
logger.debug("WebSocket %s accepted by application", self.reply_channel)
|
||||||
|
|
||||||
|
def serverReject(self):
|
||||||
|
"""
|
||||||
|
Called when we get a message saying to accept the connection.
|
||||||
|
"""
|
||||||
|
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||||
|
self.cleanup()
|
||||||
|
logger.debug("WebSocket %s rejected by application", self.reply_channel)
|
||||||
|
self.factory.log_action("websocket", "rejected", {
|
||||||
|
"path": self.request.path,
|
||||||
|
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||||
|
})
|
||||||
|
|
||||||
def serverSend(self, content, binary=False):
|
def serverSend(self, content, binary=False):
|
||||||
"""
|
"""
|
||||||
Server-side channel message to send a message.
|
Server-side channel message to send a message.
|
||||||
"""
|
"""
|
||||||
|
if self.state == self.STATE_CONNECTING:
|
||||||
|
self.serverAccept()
|
||||||
self.last_data = time.time()
|
self.last_data = time.time()
|
||||||
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
|
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
|
||||||
if binary:
|
if binary:
|
||||||
|
@ -158,9 +194,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
self.sendClose()
|
self.sendClose()
|
||||||
|
|
||||||
def onClose(self, wasClean, code, reason):
|
def onClose(self, wasClean, code, reason):
|
||||||
|
self.cleanup()
|
||||||
if hasattr(self, "reply_channel"):
|
if hasattr(self, "reply_channel"):
|
||||||
logger.debug("WebSocket closed for %s", self.reply_channel)
|
logger.debug("WebSocket closed for %s", self.reply_channel)
|
||||||
del self.factory.reply_protocols[self.reply_channel]
|
|
||||||
try:
|
try:
|
||||||
if not self.muted:
|
if not self.muted:
|
||||||
self.channel_layer.send("websocket.disconnect", {
|
self.channel_layer.send("websocket.disconnect", {
|
||||||
|
@ -178,6 +214,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
else:
|
else:
|
||||||
logger.debug("WebSocket closed before handshake established")
|
logger.debug("WebSocket closed before handshake established")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""
|
||||||
|
Call to clean up this socket after it's closed.
|
||||||
|
"""
|
||||||
|
if hasattr(self, "reply_channel"):
|
||||||
|
del self.factory.reply_protocols[self.reply_channel]
|
||||||
|
|
||||||
def duration(self):
|
def duration(self):
|
||||||
"""
|
"""
|
||||||
Returns the time since the socket was opened
|
Returns the time since the socket was opened
|
||||||
|
@ -186,8 +229,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
|
|
||||||
def check_ping(self):
|
def check_ping(self):
|
||||||
"""
|
"""
|
||||||
Checks to see if we should send a keepalive ping.
|
Checks to see if we should send a keepalive ping/deny socket connection
|
||||||
"""
|
"""
|
||||||
|
# If we're still connecting, deny the connection
|
||||||
|
if self.state == self.STATE_CONNECTING:
|
||||||
|
if self.duration() > self.main_factory.websocket_connect_timeout:
|
||||||
|
self.serverReject()
|
||||||
|
elif self.state == self.STATE_OPEN:
|
||||||
if (time.time() - self.last_data) > self.main_factory.ping_interval:
|
if (time.time() - self.last_data) > self.main_factory.ping_interval:
|
||||||
self._sendAutoPing()
|
self._sendAutoPing()
|
||||||
self.last_data = time.time()
|
self.last_data = time.time()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user