From 678a97ec7f216c3b69fc9ad950861ae37b19d85c Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 7 Feb 2018 11:44:13 -0800 Subject: [PATCH] Fixed #152: Give ASGI apps a grace period after close before killing Also adds a warning when they do not correctly exit before killing. --- daphne/cli.py | 7 +++ daphne/http_protocol.py | 8 ++-- daphne/server.py | 98 ++++++++++++++++++++++++++--------------- daphne/ws_protocol.py | 9 ++-- 4 files changed, 79 insertions(+), 43 deletions(-) diff --git a/daphne/cli.py b/daphne/cli.py index 53e9053..8d3c531 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -105,6 +105,12 @@ class CommandLineInterface(object): help="The number of seconds before a WebSocket is closed if no response to a keepalive ping", default=30, ) + self.parser.add_argument( + "--application-close-timeout", + type=int, + help="The number of seconds an ASGI application has to exit after client disconnect before it is killed", + default=10, + ) self.parser.add_argument( "--ws-protocol", nargs="*", @@ -201,6 +207,7 @@ class CommandLineInterface(object): ping_timeout=args.ping_timeout, websocket_timeout=args.websocket_timeout, websocket_connect_timeout=args.websocket_connect_timeout, + application_close_timeout=args.application_close_timeout, action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, ws_protocols=args.ws_protocols, root_path=args.root_path, diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 37e43eb..dcf8dec 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -48,7 +48,7 @@ class WebRequest(http.Request): self.server = self.channel.factory.server self.application_queue = None self._response_started = False - self.server.add_protocol(self) + self.server.protocol_connected(self) except Exception: logger.error(traceback.format_exc()) raise @@ -118,7 +118,7 @@ class WebRequest(http.Request): protocol.dataReceived(data) # Remove our HTTP reply channel association logger.debug("Upgraded 
connection %s to WebSocket", self.client_addr) - self.server.discard_protocol(self) + self.server.protocol_disconnected(self) # Resume the producer so we keep getting data, if it's available as a method self.channel._networkProducer.resumeProducing() @@ -171,7 +171,7 @@ class WebRequest(http.Request): self.send_disconnect() logger.debug("HTTP disconnect for %s", self.client_addr) http.Request.connectionLost(self, reason) - self.server.discard_protocol(self) + self.server.protocol_disconnected(self) def finish(self): """ @@ -181,7 +181,7 @@ class WebRequest(http.Request): self.send_disconnect() logger.debug("HTTP close for %s", self.client_addr) http.Request.finish(self) - self.server.discard_protocol(self) + self.server.protocol_disconnected(self) ### Server reply callbacks diff --git a/daphne/server.py b/daphne/server.py index ba61bfa..fa76bd7 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -18,7 +18,9 @@ else: import asyncio import logging +import time import traceback +from concurrent.futures import CancelledError from twisted.internet import defer, reactor from twisted.internet.endpoints import serverFromString @@ -50,12 +52,11 @@ class Server(object): proxy_forwarded_port_header=None, verbosity=1, websocket_handshake_timeout=5, + application_close_timeout=10, ready_callable=None, ): self.application = application self.endpoints = endpoints or [] - if not self.endpoints: - raise UserWarning("No endpoints. 
This server will not listen on anything.") self.listeners = [] self.listening_addresses = [] self.signal_handlers = signal_handlers @@ -68,16 +69,20 @@ class Server(object): self.websocket_timeout = websocket_timeout self.websocket_connect_timeout = websocket_connect_timeout self.websocket_handshake_timeout = websocket_handshake_timeout + self.application_close_timeout = application_close_timeout self.websocket_protocols = ws_protocols self.root_path = root_path self.verbosity = verbosity self.abort_start = False self.ready_callable = ready_callable + # Check our construction is actually sensible + if not self.endpoints: + logger.error("No endpoints. This server will not listen on anything.") + sys.exit(1) def run(self): - # A set of current Twisted protocol instances to manage - self.protocols = set() - self.application_instances = {} + # A dict of protocol: {"application_instance":, "connected":, "disconnected":} dicts + self.connections = {} # Make the factory self.http_factory = HTTPFactory(self) self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server="Daphne") @@ -150,18 +155,17 @@ class Server(object): ### Protocol handling - def add_protocol(self, protocol): - if protocol in self.protocols: + def protocol_connected(self, protocol): + """ + Adds a protocol as a current connection. + """ + if protocol in self.connections: raise RuntimeError("Protocol %r was added to main list twice!" 
% protocol) - self.protocols.add(protocol) + self.connections[protocol] = {"connected": time.time()} - def discard_protocol(self, protocol): - # Ensure it's not in the protocol-tracking set - self.protocols.discard(protocol) - # Make sure any application future that's running is cancelled - if protocol in self.application_instances: - self.application_instances[protocol].cancel() - del self.application_instances[protocol] + def protocol_disconnected(self, protocol): + # Set its disconnected time (the loops will come and clean it up) + self.connections[protocol]["disconnected"] = time.time() ### Internal event/message handling @@ -173,12 +177,12 @@ class Server(object): return you the application's input queue """ # Make sure the protocol has not had another application made for it - assert protocol not in self.application_instances + assert "application_instance" not in self.connections[protocol] # Make an instance of the application input_queue = asyncio.Queue() application_instance = self.application(scope=scope) # Run it, and stash the future for later checking - self.application_instances[protocol] = asyncio.ensure_future(application_instance( + self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance( receive=input_queue.get, send=lambda message: self.handle_reply(protocol, message), ), loop=asyncio.get_event_loop()) @@ -197,29 +201,50 @@ class Server(object): Goes through the set of current application Futures and cleans up any that are done/prints exceptions for any that errored. 
""" - for protocol, application_instance in list(self.application_instances.items()): - if application_instance.done(): - exception = application_instance.exception() - if exception: - if isinstance(exception, KeyboardInterrupt): - # Protocol is asking the server to exit (likely during test) - self.stop() - else: - logger.error( - "Exception inside application: {}\n{}{}".format( + for protocol, details in list(self.connections.items()): + disconnected = details.get("disconnected", None) + application_instance = details.get("application_instance", None) + # First, see if the protocol disconnected and the app has taken + # too long to close up + if disconnected and time.time() - disconnected > self.application_close_timeout: + if not application_instance.done(): + logger.warning( + "Application instance %r for connection %s took too long to shut down and was killed.", + application_instance, + repr(protocol), + ) + application_instance.cancel() + # Then see if the app is done and we should reap it + if application_instance and application_instance.done(): + try: + exception = application_instance.exception() + except CancelledError: + # Future cancellation. We can ignore this. + pass + else: + if exception: + if isinstance(exception, KeyboardInterrupt): + # Protocol is asking the server to exit (likely during test) + self.stop() + else: + exception_output = "{}\n{}{}".format( exception, "".join(traceback.format_tb( exception.__traceback__, )), " {}".format(exception), ) - ) - protocol.handle_exception(exception) - try: - del self.application_instances[protocol] - except KeyError: - # The protocol might have already got here before us. That's fine. 
- pass + logger.error( + "Exception inside application: %s", + exception_output, + ) + if not disconnected: + protocol.handle_exception(exception) + del self.connections[protocol]["application_instance"] + application_instance = None + # Check to see if protocol is closed and app is closed so we can remove it + if not application_instance and disconnected: + del self.connections[protocol] reactor.callLater(1, self.application_checker) def kill_all_applications(self): @@ -228,7 +253,8 @@ class Server(object): """ # Send cancel to all coroutines wait_for = [] - for application_instance in self.application_instances.values(): + for details in self.connections.values(): + application_instance = details["application_instance"] if not application_instance.done(): application_instance.cancel() wait_for.append(application_instance) @@ -243,7 +269,7 @@ class Server(object): Called periodically to enforce timeout rules on all connections. Also checks pings at the same time. """ - for protocol in list(self.protocols): + for protocol in list(self.connections.keys()): protocol.check_timeouts() reactor.callLater(2, self.timeout_checker) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index ffb91eb..2792662 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -24,7 +24,7 @@ class WebSocketProtocol(WebSocketServerProtocol): def onConnect(self, request): self.server = self.factory.server_class - self.server.add_protocol(self) + self.server.protocol_connected(self) self.request = request self.protocol_to_accept = None self.socket_opened = time.time() @@ -124,7 +124,7 @@ class WebSocketProtocol(WebSocketServerProtocol): """ Called when Twisted closes the socket. 
""" - self.server.discard_protocol(self) + self.server.protocol_disconnected(self) logger.debug("WebSocket closed for %s", self.client_addr) if not self.muted: self.application_queue.put_nowait({ @@ -187,7 +187,7 @@ class WebSocketProtocol(WebSocketServerProtocol): """ self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) del self.handshake_deferred - self.server.discard_protocol(self) + self.server.protocol_disconnected(self) logger.debug("WebSocket %s rejected by application", self.client_addr) self.server.log_action("websocket", "rejected", { "path": self.request.path, @@ -245,6 +245,9 @@ class WebSocketProtocol(WebSocketServerProtocol): def __eq__(self, other): return id(self) == id(other) + def __repr__(self): + return "<WebSocketProtocol client=%r path=%r>" % (self.client_addr, self.path) + class WebSocketFactory(WebSocketServerFactory): """