Fixed #152: Give ASGI apps a grace period after close before killing

Also adds a warning when they do not correctly exit before killing.
This commit is contained in:
Andrew Godwin 2018-02-07 11:44:13 -08:00
parent d46429247f
commit 678a97ec7f
4 changed files with 79 additions and 43 deletions

View File

@ -105,6 +105,12 @@ class CommandLineInterface(object):
help="The number of seconds before a WebSocket is closed if no response to a keepalive ping", help="The number of seconds before a WebSocket is closed if no response to a keepalive ping",
default=30, default=30,
) )
self.parser.add_argument(
"--application-close-timeout",
type=int,
help="The number of seconds an ASGI application has to exit after client disconnect before it is killed",
default=10,
)
self.parser.add_argument( self.parser.add_argument(
"--ws-protocol", "--ws-protocol",
nargs="*", nargs="*",
@ -201,6 +207,7 @@ class CommandLineInterface(object):
ping_timeout=args.ping_timeout, ping_timeout=args.ping_timeout,
websocket_timeout=args.websocket_timeout, websocket_timeout=args.websocket_timeout,
websocket_connect_timeout=args.websocket_connect_timeout, websocket_connect_timeout=args.websocket_connect_timeout,
application_close_timeout=args.application_close_timeout,
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
ws_protocols=args.ws_protocols, ws_protocols=args.ws_protocols,
root_path=args.root_path, root_path=args.root_path,

View File

@ -48,7 +48,7 @@ class WebRequest(http.Request):
self.server = self.channel.factory.server self.server = self.channel.factory.server
self.application_queue = None self.application_queue = None
self._response_started = False self._response_started = False
self.server.add_protocol(self) self.server.protocol_connected(self)
except Exception: except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
@ -118,7 +118,7 @@ class WebRequest(http.Request):
protocol.dataReceived(data) protocol.dataReceived(data)
# Remove our HTTP reply channel association # Remove our HTTP reply channel association
logger.debug("Upgraded connection %s to WebSocket", self.client_addr) logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
self.server.discard_protocol(self) self.server.protocol_disconnected(self)
# Resume the producer so we keep getting data, if it's available as a method # Resume the producer so we keep getting data, if it's available as a method
self.channel._networkProducer.resumeProducing() self.channel._networkProducer.resumeProducing()
@ -171,7 +171,7 @@ class WebRequest(http.Request):
self.send_disconnect() self.send_disconnect()
logger.debug("HTTP disconnect for %s", self.client_addr) logger.debug("HTTP disconnect for %s", self.client_addr)
http.Request.connectionLost(self, reason) http.Request.connectionLost(self, reason)
self.server.discard_protocol(self) self.server.protocol_disconnected(self)
def finish(self): def finish(self):
""" """
@ -181,7 +181,7 @@ class WebRequest(http.Request):
self.send_disconnect() self.send_disconnect()
logger.debug("HTTP close for %s", self.client_addr) logger.debug("HTTP close for %s", self.client_addr)
http.Request.finish(self) http.Request.finish(self)
self.server.discard_protocol(self) self.server.protocol_disconnected(self)
### Server reply callbacks ### Server reply callbacks

View File

@ -18,7 +18,9 @@ else:
import asyncio import asyncio
import logging import logging
import time
import traceback import traceback
from concurrent.futures import CancelledError
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString from twisted.internet.endpoints import serverFromString
@ -50,12 +52,11 @@ class Server(object):
proxy_forwarded_port_header=None, proxy_forwarded_port_header=None,
verbosity=1, verbosity=1,
websocket_handshake_timeout=5, websocket_handshake_timeout=5,
application_close_timeout=10,
ready_callable=None, ready_callable=None,
): ):
self.application = application self.application = application
self.endpoints = endpoints or [] self.endpoints = endpoints or []
if not self.endpoints:
raise UserWarning("No endpoints. This server will not listen on anything.")
self.listeners = [] self.listeners = []
self.listening_addresses = [] self.listening_addresses = []
self.signal_handlers = signal_handlers self.signal_handlers = signal_handlers
@ -68,16 +69,20 @@ class Server(object):
self.websocket_timeout = websocket_timeout self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout self.websocket_handshake_timeout = websocket_handshake_timeout
self.application_close_timeout = application_close_timeout
self.websocket_protocols = ws_protocols self.websocket_protocols = ws_protocols
self.root_path = root_path self.root_path = root_path
self.verbosity = verbosity self.verbosity = verbosity
self.abort_start = False self.abort_start = False
self.ready_callable = ready_callable self.ready_callable = ready_callable
# Check our construction is actually sensible
if not self.endpoints:
logger.error("No endpoints. This server will not listen on anything.")
sys.exit(1)
def run(self): def run(self):
# A set of current Twisted protocol instances to manage # A dict of protocol: {"application_instance":, "connected":, "disconnected":} dicts
self.protocols = set() self.connections = {}
self.application_instances = {}
# Make the factory # Make the factory
self.http_factory = HTTPFactory(self) self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server="Daphne") self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server="Daphne")
@ -150,18 +155,17 @@ class Server(object):
### Protocol handling ### Protocol handling
def add_protocol(self, protocol): def protocol_connected(self, protocol):
if protocol in self.protocols: """
Adds a protocol as a current connection.
"""
if protocol in self.connections:
raise RuntimeError("Protocol %r was added to main list twice!" % protocol) raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
self.protocols.add(protocol) self.connections[protocol] = {"connected": time.time()}
def discard_protocol(self, protocol): def protocol_disconnected(self, protocol):
# Ensure it's not in the protocol-tracking set # Set its disconnected time (the loops will come and clean it up)
self.protocols.discard(protocol) self.connections[protocol]["disconnected"] = time.time()
# Make sure any application future that's running is cancelled
if protocol in self.application_instances:
self.application_instances[protocol].cancel()
del self.application_instances[protocol]
### Internal event/message handling ### Internal event/message handling
@ -173,12 +177,12 @@ class Server(object):
return you the application's input queue return you the application's input queue
""" """
# Make sure the protocol has not had another application made for it # Make sure the protocol has not had another application made for it
assert protocol not in self.application_instances assert "application_instance" not in self.connections[protocol]
# Make an instance of the application # Make an instance of the application
input_queue = asyncio.Queue() input_queue = asyncio.Queue()
application_instance = self.application(scope=scope) application_instance = self.application(scope=scope)
# Run it, and stash the future for later checking # Run it, and stash the future for later checking
self.application_instances[protocol] = asyncio.ensure_future(application_instance( self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
receive=input_queue.get, receive=input_queue.get,
send=lambda message: self.handle_reply(protocol, message), send=lambda message: self.handle_reply(protocol, message),
), loop=asyncio.get_event_loop()) ), loop=asyncio.get_event_loop())
@ -197,29 +201,50 @@ class Server(object):
Goes through the set of current application Futures and cleans up Goes through the set of current application Futures and cleans up
any that are done/prints exceptions for any that errored. any that are done/prints exceptions for any that errored.
""" """
for protocol, application_instance in list(self.application_instances.items()): for protocol, details in list(self.connections.items()):
if application_instance.done(): disconnected = details.get("disconnected", None)
exception = application_instance.exception() application_instance = details.get("application_instance", None)
if exception: # First, see if the protocol disconnected and the app has taken
if isinstance(exception, KeyboardInterrupt): # too long to close up
# Protocol is asking the server to exit (likely during test) if disconnected and time.time() - disconnected > self.application_close_timeout:
self.stop() if not application_instance.done():
else: logger.warning(
logger.error( "Application instance %r for connection %s took too long to shut down and was killed.",
"Exception inside application: {}\n{}{}".format( application_instance,
repr(protocol),
)
application_instance.cancel()
# Then see if the app is done and we should reap it
if application_instance and application_instance.done():
try:
exception = application_instance.exception()
except CancelledError:
# Future cancellation. We can ignore this.
pass
else:
if exception:
if isinstance(exception, KeyboardInterrupt):
# Protocol is asking the server to exit (likely during test)
self.stop()
else:
exception_output = "{}\n{}{}".format(
exception, exception,
"".join(traceback.format_tb( "".join(traceback.format_tb(
exception.__traceback__, exception.__traceback__,
)), )),
" {}".format(exception), " {}".format(exception),
) )
) logger.error(
protocol.handle_exception(exception) "Exception inside application: %s",
try: exception_output,
del self.application_instances[protocol] )
except KeyError: if not disconnected:
# The protocol might have already got here before us. That's fine. protocol.handle_exception(exception)
pass del self.connections[protocol]["application_instance"]
application_instance = None
# Check to see if protocol is closed and app is closed so we can remove it
if not application_instance and disconnected:
del self.connections[protocol]
reactor.callLater(1, self.application_checker) reactor.callLater(1, self.application_checker)
def kill_all_applications(self): def kill_all_applications(self):
@ -228,7 +253,8 @@ class Server(object):
""" """
# Send cancel to all coroutines # Send cancel to all coroutines
wait_for = [] wait_for = []
for application_instance in self.application_instances.values(): for details in self.connections.values():
application_instance = details["application_instance"]
if not application_instance.done(): if not application_instance.done():
application_instance.cancel() application_instance.cancel()
wait_for.append(application_instance) wait_for.append(application_instance)
@ -243,7 +269,7 @@ class Server(object):
Called periodically to enforce timeout rules on all connections. Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time. Also checks pings at the same time.
""" """
for protocol in list(self.protocols): for protocol in list(self.connections.keys()):
protocol.check_timeouts() protocol.check_timeouts()
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)

View File

@ -24,7 +24,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onConnect(self, request): def onConnect(self, request):
self.server = self.factory.server_class self.server = self.factory.server_class
self.server.add_protocol(self) self.server.protocol_connected(self)
self.request = request self.request = request
self.protocol_to_accept = None self.protocol_to_accept = None
self.socket_opened = time.time() self.socket_opened = time.time()
@ -124,7 +124,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
""" """
Called when Twisted closes the socket. Called when Twisted closes the socket.
""" """
self.server.discard_protocol(self) self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr) logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted: if not self.muted:
self.application_queue.put_nowait({ self.application_queue.put_nowait({
@ -187,7 +187,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
""" """
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
del self.handshake_deferred del self.handshake_deferred
self.server.discard_protocol(self) self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr) logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action("websocket", "rejected", { self.server.log_action("websocket", "rejected", {
"path": self.request.path, "path": self.request.path,
@ -245,6 +245,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
def __eq__(self, other): def __eq__(self, other):
return id(self) == id(other) return id(self) == id(other)
def __repr__(self):
return "<WebSocketProtocol client=%r path=%r>" % (self.client_addr, self.path)
class WebSocketFactory(WebSocketServerFactory): class WebSocketFactory(WebSocketServerFactory):
""" """