Fixed #152: Give ASGI apps a grace period after close before killing

Also adds a warning when they do not correctly exit before killing.
This commit is contained in:
Andrew Godwin 2018-02-07 11:44:13 -08:00
parent d46429247f
commit 678a97ec7f
4 changed files with 79 additions and 43 deletions

View File

@ -105,6 +105,12 @@ class CommandLineInterface(object):
help="The number of seconds before a WebSocket is closed if no response to a keepalive ping",
default=30,
)
self.parser.add_argument(
"--application-close-timeout",
type=int,
help="The number of seconds an ASGI application has to exit after client disconnect before it is killed",
default=10,
)
self.parser.add_argument(
"--ws-protocol",
nargs="*",
@ -201,6 +207,7 @@ class CommandLineInterface(object):
ping_timeout=args.ping_timeout,
websocket_timeout=args.websocket_timeout,
websocket_connect_timeout=args.websocket_connect_timeout,
application_close_timeout=args.application_close_timeout,
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
ws_protocols=args.ws_protocols,
root_path=args.root_path,

View File

@ -48,7 +48,7 @@ class WebRequest(http.Request):
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.server.add_protocol(self)
self.server.protocol_connected(self)
except Exception:
logger.error(traceback.format_exc())
raise
@ -118,7 +118,7 @@ class WebRequest(http.Request):
protocol.dataReceived(data)
# Remove our HTTP reply channel association
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
self.server.discard_protocol(self)
self.server.protocol_disconnected(self)
# Resume the producer so we keep getting data, if it's available as a method
self.channel._networkProducer.resumeProducing()
@ -171,7 +171,7 @@ class WebRequest(http.Request):
self.send_disconnect()
logger.debug("HTTP disconnect for %s", self.client_addr)
http.Request.connectionLost(self, reason)
self.server.discard_protocol(self)
self.server.protocol_disconnected(self)
def finish(self):
"""
@ -181,7 +181,7 @@ class WebRequest(http.Request):
self.send_disconnect()
logger.debug("HTTP close for %s", self.client_addr)
http.Request.finish(self)
self.server.discard_protocol(self)
self.server.protocol_disconnected(self)
### Server reply callbacks

View File

@ -18,7 +18,9 @@ else:
import asyncio
import logging
import time
import traceback
from concurrent.futures import CancelledError
from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString
@ -50,12 +52,11 @@ class Server(object):
proxy_forwarded_port_header=None,
verbosity=1,
websocket_handshake_timeout=5,
application_close_timeout=10,
ready_callable=None,
):
self.application = application
self.endpoints = endpoints or []
if not self.endpoints:
raise UserWarning("No endpoints. This server will not listen on anything.")
self.listeners = []
self.listening_addresses = []
self.signal_handlers = signal_handlers
@ -68,16 +69,20 @@ class Server(object):
self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout
self.application_close_timeout = application_close_timeout
self.websocket_protocols = ws_protocols
self.root_path = root_path
self.verbosity = verbosity
self.abort_start = False
self.ready_callable = ready_callable
# Check our construction is actually sensible
if not self.endpoints:
logger.error("No endpoints. This server will not listen on anything.")
sys.exit(1)
def run(self):
# A set of current Twisted protocol instances to manage
self.protocols = set()
self.application_instances = {}
# A dict of protocol: {"application_instance":, "connected":, "disconnected":} dicts
self.connections = {}
# Make the factory
self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server="Daphne")
@ -150,18 +155,17 @@ class Server(object):
### Protocol handling
def add_protocol(self, protocol):
if protocol in self.protocols:
def protocol_connected(self, protocol):
"""
Adds a protocol as a current connection.
"""
if protocol in self.connections:
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
self.protocols.add(protocol)
self.connections[protocol] = {"connected": time.time()}
def discard_protocol(self, protocol):
# Ensure it's not in the protocol-tracking set
self.protocols.discard(protocol)
# Make sure any application future that's running is cancelled
if protocol in self.application_instances:
self.application_instances[protocol].cancel()
del self.application_instances[protocol]
def protocol_disconnected(self, protocol):
# Set its disconnected time (the loops will come and clean it up)
self.connections[protocol]["disconnected"] = time.time()
### Internal event/message handling
@ -173,12 +177,12 @@ class Server(object):
return you the application's input queue
"""
# Make sure the protocol has not had another application made for it
assert protocol not in self.application_instances
assert "application_instance" not in self.connections[protocol]
# Make an instance of the application
input_queue = asyncio.Queue()
application_instance = self.application(scope=scope)
# Run it, and stash the future for later checking
self.application_instances[protocol] = asyncio.ensure_future(application_instance(
self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
receive=input_queue.get,
send=lambda message: self.handle_reply(protocol, message),
), loop=asyncio.get_event_loop())
@ -197,29 +201,50 @@ class Server(object):
Goes through the set of current application Futures and cleans up
any that are done/prints exceptions for any that errored.
"""
for protocol, application_instance in list(self.application_instances.items()):
if application_instance.done():
for protocol, details in list(self.connections.items()):
disconnected = details.get("disconnected", None)
application_instance = details.get("application_instance", None)
# First, see if the protocol disconnected and the app has taken
# too long to close up
if disconnected and time.time() - disconnected > self.application_close_timeout:
if not application_instance.done():
logger.warning(
"Application instance %r for connection %s took too long to shut down and was killed.",
application_instance,
repr(protocol),
)
application_instance.cancel()
# Then see if the app is done and we should reap it
if application_instance and application_instance.done():
try:
exception = application_instance.exception()
except CancelledError:
# Future cancellation. We can ignore this.
pass
else:
if exception:
if isinstance(exception, KeyboardInterrupt):
# Protocol is asking the server to exit (likely during test)
self.stop()
else:
logger.error(
"Exception inside application: {}\n{}{}".format(
exception_output = "{}\n{}{}".format(
exception,
"".join(traceback.format_tb(
exception.__traceback__,
)),
" {}".format(exception),
)
logger.error(
"Exception inside application: %s",
exception_output,
)
if not disconnected:
protocol.handle_exception(exception)
try:
del self.application_instances[protocol]
except KeyError:
# The protocol might have already got here before us. That's fine.
pass
del self.connections[protocol]["application_instance"]
application_instance = None
# Check to see if protocol is closed and app is closed so we can remove it
if not application_instance and disconnected:
del self.connections[protocol]
reactor.callLater(1, self.application_checker)
def kill_all_applications(self):
@ -228,7 +253,8 @@ class Server(object):
"""
# Send cancel to all coroutines
wait_for = []
for application_instance in self.application_instances.values():
for details in self.connections.values():
application_instance = details["application_instance"]
if not application_instance.done():
application_instance.cancel()
wait_for.append(application_instance)
@ -243,7 +269,7 @@ class Server(object):
Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time.
"""
for protocol in list(self.protocols):
for protocol in list(self.connections.keys()):
protocol.check_timeouts()
reactor.callLater(2, self.timeout_checker)

View File

@ -24,7 +24,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onConnect(self, request):
self.server = self.factory.server_class
self.server.add_protocol(self)
self.server.protocol_connected(self)
self.request = request
self.protocol_to_accept = None
self.socket_opened = time.time()
@ -124,7 +124,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
Called when Twisted closes the socket.
"""
self.server.discard_protocol(self)
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted:
self.application_queue.put_nowait({
@ -187,7 +187,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
del self.handshake_deferred
self.server.discard_protocol(self)
self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action("websocket", "rejected", {
"path": self.request.path,
@ -245,6 +245,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
def __eq__(self, other):
return id(self) == id(other)
def __repr__(self):
return "<WebSocketProtocol client=%r path=%r>" % (self.client_addr, self.path)
class WebSocketFactory(WebSocketServerFactory):
"""