mirror of
https://github.com/django/daphne.git
synced 2024-11-24 17:03:42 +03:00
Fixed #152: Give ASGI apps a grace period after close before killing
Also adds a warning when they do not correctly exit before killing.
This commit is contained in:
parent
d46429247f
commit
678a97ec7f
|
@ -105,6 +105,12 @@ class CommandLineInterface(object):
|
|||
help="The number of seconds before a WebSocket is closed if no response to a keepalive ping",
|
||||
default=30,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--application-close-timeout",
|
||||
type=int,
|
||||
help="The number of seconds an ASGI application has to exit after client disconnect before it is killed",
|
||||
default=10,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--ws-protocol",
|
||||
nargs="*",
|
||||
|
@ -201,6 +207,7 @@ class CommandLineInterface(object):
|
|||
ping_timeout=args.ping_timeout,
|
||||
websocket_timeout=args.websocket_timeout,
|
||||
websocket_connect_timeout=args.websocket_connect_timeout,
|
||||
application_close_timeout=args.application_close_timeout,
|
||||
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
|
||||
ws_protocols=args.ws_protocols,
|
||||
root_path=args.root_path,
|
||||
|
|
|
@ -48,7 +48,7 @@ class WebRequest(http.Request):
|
|||
self.server = self.channel.factory.server
|
||||
self.application_queue = None
|
||||
self._response_started = False
|
||||
self.server.add_protocol(self)
|
||||
self.server.protocol_connected(self)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
@ -118,7 +118,7 @@ class WebRequest(http.Request):
|
|||
protocol.dataReceived(data)
|
||||
# Remove our HTTP reply channel association
|
||||
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
|
||||
self.server.discard_protocol(self)
|
||||
self.server.protocol_disconnected(self)
|
||||
# Resume the producer so we keep getting data, if it's available as a method
|
||||
self.channel._networkProducer.resumeProducing()
|
||||
|
||||
|
@ -171,7 +171,7 @@ class WebRequest(http.Request):
|
|||
self.send_disconnect()
|
||||
logger.debug("HTTP disconnect for %s", self.client_addr)
|
||||
http.Request.connectionLost(self, reason)
|
||||
self.server.discard_protocol(self)
|
||||
self.server.protocol_disconnected(self)
|
||||
|
||||
def finish(self):
|
||||
"""
|
||||
|
@ -181,7 +181,7 @@ class WebRequest(http.Request):
|
|||
self.send_disconnect()
|
||||
logger.debug("HTTP close for %s", self.client_addr)
|
||||
http.Request.finish(self)
|
||||
self.server.discard_protocol(self)
|
||||
self.server.protocol_disconnected(self)
|
||||
|
||||
### Server reply callbacks
|
||||
|
||||
|
|
|
@ -18,7 +18,9 @@ else:
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import CancelledError
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.endpoints import serverFromString
|
||||
|
@ -50,12 +52,11 @@ class Server(object):
|
|||
proxy_forwarded_port_header=None,
|
||||
verbosity=1,
|
||||
websocket_handshake_timeout=5,
|
||||
application_close_timeout=10,
|
||||
ready_callable=None,
|
||||
):
|
||||
self.application = application
|
||||
self.endpoints = endpoints or []
|
||||
if not self.endpoints:
|
||||
raise UserWarning("No endpoints. This server will not listen on anything.")
|
||||
self.listeners = []
|
||||
self.listening_addresses = []
|
||||
self.signal_handlers = signal_handlers
|
||||
|
@ -68,16 +69,20 @@ class Server(object):
|
|||
self.websocket_timeout = websocket_timeout
|
||||
self.websocket_connect_timeout = websocket_connect_timeout
|
||||
self.websocket_handshake_timeout = websocket_handshake_timeout
|
||||
self.application_close_timeout = application_close_timeout
|
||||
self.websocket_protocols = ws_protocols
|
||||
self.root_path = root_path
|
||||
self.verbosity = verbosity
|
||||
self.abort_start = False
|
||||
self.ready_callable = ready_callable
|
||||
# Check our construction is actually sensible
|
||||
if not self.endpoints:
|
||||
logger.error("No endpoints. This server will not listen on anything.")
|
||||
sys.exit(1)
|
||||
|
||||
def run(self):
|
||||
# A set of current Twisted protocol instances to manage
|
||||
self.protocols = set()
|
||||
self.application_instances = {}
|
||||
# A dict of protocol: {"application_instance":, "connected":, "disconnected":} dicts
|
||||
self.connections = {}
|
||||
# Make the factory
|
||||
self.http_factory = HTTPFactory(self)
|
||||
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server="Daphne")
|
||||
|
@ -150,18 +155,17 @@ class Server(object):
|
|||
|
||||
### Protocol handling
|
||||
|
||||
def add_protocol(self, protocol):
|
||||
if protocol in self.protocols:
|
||||
def protocol_connected(self, protocol):
|
||||
"""
|
||||
Adds a protocol as a current connection.
|
||||
"""
|
||||
if protocol in self.connections:
|
||||
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
|
||||
self.protocols.add(protocol)
|
||||
self.connections[protocol] = {"connected": time.time()}
|
||||
|
||||
def discard_protocol(self, protocol):
|
||||
# Ensure it's not in the protocol-tracking set
|
||||
self.protocols.discard(protocol)
|
||||
# Make sure any application future that's running is cancelled
|
||||
if protocol in self.application_instances:
|
||||
self.application_instances[protocol].cancel()
|
||||
del self.application_instances[protocol]
|
||||
def protocol_disconnected(self, protocol):
|
||||
# Set its disconnected time (the loops will come and clean it up)
|
||||
self.connections[protocol]["disconnected"] = time.time()
|
||||
|
||||
### Internal event/message handling
|
||||
|
||||
|
@ -173,12 +177,12 @@ class Server(object):
|
|||
return you the application's input queue
|
||||
"""
|
||||
# Make sure the protocol has not had another application made for it
|
||||
assert protocol not in self.application_instances
|
||||
assert "application_instance" not in self.connections[protocol]
|
||||
# Make an instance of the application
|
||||
input_queue = asyncio.Queue()
|
||||
application_instance = self.application(scope=scope)
|
||||
# Run it, and stash the future for later checking
|
||||
self.application_instances[protocol] = asyncio.ensure_future(application_instance(
|
||||
self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
|
||||
receive=input_queue.get,
|
||||
send=lambda message: self.handle_reply(protocol, message),
|
||||
), loop=asyncio.get_event_loop())
|
||||
|
@ -197,29 +201,50 @@ class Server(object):
|
|||
Goes through the set of current application Futures and cleans up
|
||||
any that are done/prints exceptions for any that errored.
|
||||
"""
|
||||
for protocol, application_instance in list(self.application_instances.items()):
|
||||
if application_instance.done():
|
||||
exception = application_instance.exception()
|
||||
if exception:
|
||||
if isinstance(exception, KeyboardInterrupt):
|
||||
# Protocol is asking the server to exit (likely during test)
|
||||
self.stop()
|
||||
else:
|
||||
logger.error(
|
||||
"Exception inside application: {}\n{}{}".format(
|
||||
for protocol, details in list(self.connections.items()):
|
||||
disconnected = details.get("disconnected", None)
|
||||
application_instance = details.get("application_instance", None)
|
||||
# First, see if the protocol disconnected and the app has taken
|
||||
# too long to close up
|
||||
if disconnected and time.time() - disconnected > self.application_close_timeout:
|
||||
if not application_instance.done():
|
||||
logger.warning(
|
||||
"Application instance %r for connection %s took too long to shut down and was killed.",
|
||||
application_instance,
|
||||
repr(protocol),
|
||||
)
|
||||
application_instance.cancel()
|
||||
# Then see if the app is done and we should reap it
|
||||
if application_instance and application_instance.done():
|
||||
try:
|
||||
exception = application_instance.exception()
|
||||
except CancelledError:
|
||||
# Future cancellation. We can ignore this.
|
||||
pass
|
||||
else:
|
||||
if exception:
|
||||
if isinstance(exception, KeyboardInterrupt):
|
||||
# Protocol is asking the server to exit (likely during test)
|
||||
self.stop()
|
||||
else:
|
||||
exception_output = "{}\n{}{}".format(
|
||||
exception,
|
||||
"".join(traceback.format_tb(
|
||||
exception.__traceback__,
|
||||
)),
|
||||
" {}".format(exception),
|
||||
)
|
||||
)
|
||||
protocol.handle_exception(exception)
|
||||
try:
|
||||
del self.application_instances[protocol]
|
||||
except KeyError:
|
||||
# The protocol might have already got here before us. That's fine.
|
||||
pass
|
||||
logger.error(
|
||||
"Exception inside application: %s",
|
||||
exception_output,
|
||||
)
|
||||
if not disconnected:
|
||||
protocol.handle_exception(exception)
|
||||
del self.connections[protocol]["application_instance"]
|
||||
application_instance = None
|
||||
# Check to see if protocol is closed and app is closed so we can remove it
|
||||
if not application_instance and disconnected:
|
||||
del self.connections[protocol]
|
||||
reactor.callLater(1, self.application_checker)
|
||||
|
||||
def kill_all_applications(self):
|
||||
|
@ -228,7 +253,8 @@ class Server(object):
|
|||
"""
|
||||
# Send cancel to all coroutines
|
||||
wait_for = []
|
||||
for application_instance in self.application_instances.values():
|
||||
for details in self.connections.values():
|
||||
application_instance = details["application_instance"]
|
||||
if not application_instance.done():
|
||||
application_instance.cancel()
|
||||
wait_for.append(application_instance)
|
||||
|
@ -243,7 +269,7 @@ class Server(object):
|
|||
Called periodically to enforce timeout rules on all connections.
|
||||
Also checks pings at the same time.
|
||||
"""
|
||||
for protocol in list(self.protocols):
|
||||
for protocol in list(self.connections.keys()):
|
||||
protocol.check_timeouts()
|
||||
reactor.callLater(2, self.timeout_checker)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
|
||||
def onConnect(self, request):
|
||||
self.server = self.factory.server_class
|
||||
self.server.add_protocol(self)
|
||||
self.server.protocol_connected(self)
|
||||
self.request = request
|
||||
self.protocol_to_accept = None
|
||||
self.socket_opened = time.time()
|
||||
|
@ -124,7 +124,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"""
|
||||
Called when Twisted closes the socket.
|
||||
"""
|
||||
self.server.discard_protocol(self)
|
||||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||
if not self.muted:
|
||||
self.application_queue.put_nowait({
|
||||
|
@ -187,7 +187,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"""
|
||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||
del self.handshake_deferred
|
||||
self.server.discard_protocol(self)
|
||||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||
self.server.log_action("websocket", "rejected", {
|
||||
"path": self.request.path,
|
||||
|
@ -245,6 +245,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
def __eq__(self, other):
|
||||
return id(self) == id(other)
|
||||
|
||||
def __repr__(self):
|
||||
return "<WebSocketProtocol client=%r path=%r>" % (self.client_addr, self.path)
|
||||
|
||||
|
||||
class WebSocketFactory(WebSocketServerFactory):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user