diff --git a/daphne/cli.py b/daphne/cli.py index 1a4519e..543c274 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -54,6 +54,12 @@ class CommandLineInterface(object): help='How long to wait for worker server before timing out HTTP connections', default=120, ) + self.parser.add_argument( + '--ping-interval', + type=int, + help='The number of seconds a WebSocket must be idle before a keepalive ping is sent', + default=20, + ) self.parser.add_argument( 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', @@ -100,4 +106,5 @@ class CommandLineInterface(object): port=args.port, unix_socket=args.unix_socket, http_timeout=args.http_timeout, + ping_interval=args.ping_interval, ).run() diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 3f7b967..2c49887 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -219,12 +219,13 @@ class HTTPFactory(http.HTTPFactory): protocol = HTTPProtocol - def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400): + def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20): http.HTTPFactory.__init__(self) self.channel_layer = channel_layer self.action_logger = action_logger self.timeout = timeout self.websocket_timeout = websocket_timeout + self.ping_interval = ping_interval # We track all sub-protocols for response channel mapping self.reply_protocols = {} # Make a factory for WebSocket protocols @@ -264,6 +265,11 @@ class HTTPFactory(http.HTTPFactory): # Web timeout checking if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout: protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") - # WebSocket timeout checking - elif isinstance(protocol, WebSocketProtocol) and protocol.duration() > self.websocket_timeout: - protocol.serverClose() + # WebSocket timeout checking and keepalive ping sending + elif isinstance(protocol, WebSocketProtocol): + # Timeout check + if protocol.duration() > self.websocket_timeout: + protocol.serverClose() + # Ping check + else: + protocol.check_ping() diff --git a/daphne/server.py b/daphne/server.py index 6a06dc3..90bc6f2 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -18,6 +18,7 @@ class Server(object): action_logger=None, http_timeout=120, websocket_timeout=None, + ping_interval=20, ): self.channel_layer = channel_layer self.host = host @@ -26,6 +27,7 @@ class Server(object): self.signal_handlers = signal_handlers self.action_logger = action_logger self.http_timeout = http_timeout + self.ping_interval = ping_interval # If they did not provide a websocket timeout, default it to the # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) @@ -36,6 +38,7 @@ class Server(object): self.action_logger, timeout=self.http_timeout, websocket_timeout=self.websocket_timeout, + ping_interval=self.ping_interval, ) if self.unix_socket: reactor.listenUNIX(self.unix_socket, self.factory) @@ -68,7 +71,8 @@ class Server(object): def timeout_checker(self): """ - Called periodically to enforce timeout rules on all connections + Called periodically to enforce timeout rules on all connections. + Also checks pings at the same time. """ self.factory.check_timeouts() reactor.callLater(2, self.timeout_checker) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index d1f2e8f..fb7d095 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -24,6 +24,7 @@ class WebSocketProtocol(WebSocketServerProtocol): self.request = request self.packets_received = 0 self.socket_opened = time.time() + self.last_data = time.time() try: # Sanitize and decode headers clean_headers = {} @@ -68,6 +69,7 @@ class WebSocketProtocol(WebSocketServerProtocol): def onMessage(self, payload, isBinary): logger.debug("WebSocket incoming packet on %s", self.reply_channel) self.packets_received += 1 + self.last_data = time.time() if isBinary: self.channel_layer.send("websocket.receive", { "reply_channel": self.reply_channel, @@ -87,6 +89,7 @@ class WebSocketProtocol(WebSocketServerProtocol): """ Server-side channel message to send a message. """ + self.last_data = time.time() logger.debug("Sent WebSocket packet to client for %s", self.reply_channel) if binary: self.sendMessage(content, binary) @@ -121,6 +124,14 @@ class WebSocketProtocol(WebSocketServerProtocol): """ return time.time() - self.socket_opened + def check_ping(self): + """ + Checks to see if we should send a keepalive ping. + """ + if (time.time() - self.last_data) > self.main_factory.ping_interval: + self.sendPing() + self.last_data = time.time() + class WebSocketFactory(WebSocketServerFactory): """