diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 0dd8fe9..3f7b967 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -219,11 +219,12 @@ class HTTPFactory(http.HTTPFactory): protocol = HTTPProtocol - def __init__(self, channel_layer, action_logger=None, timeout=120): + def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400): http.HTTPFactory.__init__(self) self.channel_layer = channel_layer self.action_logger = action_logger self.timeout = timeout + self.websocket_timeout = websocket_timeout # We track all sub-protocols for response channel mapping self.reply_protocols = {} # Make a factory for WebSocket protocols @@ -260,5 +261,9 @@ class HTTPFactory(http.HTTPFactory): taken too long (and so their message is probably expired) """ for protocol in list(self.reply_protocols.values()): + # Web timeout checking if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout: protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") + # WebSocket timeout checking + elif isinstance(protocol, WebSocketProtocol) and protocol.duration() > self.websocket_timeout: + protocol.serverClose() diff --git a/daphne/server.py b/daphne/server.py index baf0301..84f9d1d 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -15,7 +15,8 @@ class Server(object): port=8000, signal_handlers=True, action_logger=None, - http_timeout=120 + http_timeout=120, + websocket_timeout=None, ): self.channel_layer = channel_layer self.host = host @@ -23,9 +24,17 @@ class Server(object): self.signal_handlers = signal_handlers self.action_logger = action_logger self.http_timeout = http_timeout + # If they did not provide a websocket timeout, default it to the + # channel layer's group_expiry value if present, or one day if not. + self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) def run(self): - self.factory = HTTPFactory(self.channel_layer, self.action_logger, timeout=self.http_timeout) + self.factory = HTTPFactory( + self.channel_layer, + self.action_logger, + timeout=self.http_timeout, + websocket_timeout=self.websocket_timeout, + ) reactor.listenTCP(self.port, self.factory, interface=self.host) reactor.callLater(0, self.backend_reader) reactor.callLater(2, self.timeout_checker) @@ -54,8 +63,7 @@ class Server(object): def timeout_checker(self): """ - Called periodically to enforce timeout rules on HTTP connections - (but not WebSocket) + Called periodically to enforce timeout rules on all connections """ self.factory.check_timeouts() reactor.callLater(2, self.timeout_checker) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 6dc2d00..d1f2e8f 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -23,6 +23,7 @@ class WebSocketProtocol(WebSocketServerProtocol): def onConnect(self, request): self.request = request self.packets_received = 0 + self.socket_opened = time.time() try: # Sanitize and decode headers clean_headers = {} @@ -114,6 +115,12 @@ class WebSocketProtocol(WebSocketServerProtocol): else: logger.debug("WebSocket closed before handshake established") + def duration(self): + """ + Returns the time since the socket was opened + """ + return time.time() - self.socket_opened + class WebSocketFactory(WebSocketServerFactory): """