Add WebSocket keepalive ping support with configurable interval

This commit is contained in:
Andrew Godwin 2016-03-29 03:28:05 -07:00
parent 5cb28d1e10
commit ca35d6c18f
4 changed files with 33 additions and 5 deletions

View File

@ -54,6 +54,12 @@ class CommandLineInterface(object):
help='How long to wait for worker server before timing out HTTP connections', help='How long to wait for worker server before timing out HTTP connections',
default=120, default=120,
) )
self.parser.add_argument(
'--ping-interval',
type=int,
help='The number of seconds a WebSocket must be idle before a keepalive ping is sent',
default=20,
)
self.parser.add_argument( self.parser.add_argument(
'channel_layer', 'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path', help='The ASGI channel layer instance to use as path.to.module:instance.path',
@ -100,4 +106,5 @@ class CommandLineInterface(object):
port=args.port, port=args.port,
unix_socket=args.unix_socket, unix_socket=args.unix_socket,
http_timeout=args.http_timeout, http_timeout=args.http_timeout,
ping_interval=args.ping_interval,
).run() ).run()

View File

@ -219,12 +219,13 @@ class HTTPFactory(http.HTTPFactory):
protocol = HTTPProtocol protocol = HTTPProtocol
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400): def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20):
http.HTTPFactory.__init__(self) http.HTTPFactory.__init__(self)
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.action_logger = action_logger self.action_logger = action_logger
self.timeout = timeout self.timeout = timeout
self.websocket_timeout = websocket_timeout self.websocket_timeout = websocket_timeout
self.ping_interval = ping_interval
# We track all sub-protocols for response channel mapping # We track all sub-protocols for response channel mapping
self.reply_protocols = {} self.reply_protocols = {}
# Make a factory for WebSocket protocols # Make a factory for WebSocket protocols
@ -264,6 +265,11 @@ class HTTPFactory(http.HTTPFactory):
# Web timeout checking # Web timeout checking
if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout: if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout:
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
# WebSocket timeout checking # WebSocket timeout checking and keepalive ping sending
elif isinstance(protocol, WebSocketProtocol) and protocol.duration() > self.websocket_timeout: elif isinstance(protocol, WebSocketProtocol):
# Timeout check
if protocol.duration() > self.websocket_timeout:
protocol.serverClose() protocol.serverClose()
# Ping check
else:
protocol.check_ping()

View File

@ -18,6 +18,7 @@ class Server(object):
action_logger=None, action_logger=None,
http_timeout=120, http_timeout=120,
websocket_timeout=None, websocket_timeout=None,
ping_interval=20,
): ):
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.host = host self.host = host
@ -26,6 +27,7 @@ class Server(object):
self.signal_handlers = signal_handlers self.signal_handlers = signal_handlers
self.action_logger = action_logger self.action_logger = action_logger
self.http_timeout = http_timeout self.http_timeout = http_timeout
self.ping_interval = ping_interval
# If they did not provide a websocket timeout, default it to the # If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not. # channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
@ -36,6 +38,7 @@ class Server(object):
self.action_logger, self.action_logger,
timeout=self.http_timeout, timeout=self.http_timeout,
websocket_timeout=self.websocket_timeout, websocket_timeout=self.websocket_timeout,
ping_interval=self.ping_interval,
) )
if self.unix_socket: if self.unix_socket:
reactor.listenUNIX(self.unix_socket, self.factory) reactor.listenUNIX(self.unix_socket, self.factory)
@ -68,7 +71,8 @@ class Server(object):
def timeout_checker(self): def timeout_checker(self):
""" """
Called periodically to enforce timeout rules on all connections Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time.
""" """
self.factory.check_timeouts() self.factory.check_timeouts()
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)

View File

@ -24,6 +24,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.request = request self.request = request
self.packets_received = 0 self.packets_received = 0
self.socket_opened = time.time() self.socket_opened = time.time()
self.last_data = time.time()
try: try:
# Sanitize and decode headers # Sanitize and decode headers
clean_headers = {} clean_headers = {}
@ -68,6 +69,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onMessage(self, payload, isBinary): def onMessage(self, payload, isBinary):
logger.debug("WebSocket incoming packet on %s", self.reply_channel) logger.debug("WebSocket incoming packet on %s", self.reply_channel)
self.packets_received += 1 self.packets_received += 1
self.last_data = time.time()
if isBinary: if isBinary:
self.channel_layer.send("websocket.receive", { self.channel_layer.send("websocket.receive", {
"reply_channel": self.reply_channel, "reply_channel": self.reply_channel,
@ -87,6 +89,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
""" """
Server-side channel message to send a message. Server-side channel message to send a message.
""" """
self.last_data = time.time()
logger.debug("Sent WebSocket packet to client for %s", self.reply_channel) logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
if binary: if binary:
self.sendMessage(content, binary) self.sendMessage(content, binary)
@ -121,6 +124,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
""" """
return time.time() - self.socket_opened return time.time() - self.socket_opened
def check_ping(self):
"""
Checks to see if we should send a keepalive ping.
"""
if (time.time() - self.last_data) > self.main_factory.ping_interval:
self.sendPing()
self.last_data = time.time()
class WebSocketFactory(WebSocketServerFactory): class WebSocketFactory(WebSocketServerFactory):
""" """