mirror of
https://github.com/django/daphne.git
synced 2024-11-24 17:03:42 +03:00
Fixed #152: Give ASGI apps a grace period after close before killing
Also adds a warning when they do not correctly exit before killing.
This commit is contained in:
parent
d46429247f
commit
678a97ec7f
|
@ -105,6 +105,12 @@ class CommandLineInterface(object):
|
||||||
help="The number of seconds before a WebSocket is closed if no response to a keepalive ping",
|
help="The number of seconds before a WebSocket is closed if no response to a keepalive ping",
|
||||||
default=30,
|
default=30,
|
||||||
)
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--application-close-timeout",
|
||||||
|
type=int,
|
||||||
|
help="The number of seconds an ASGI application has to exit after client disconnect before it is killed",
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--ws-protocol",
|
"--ws-protocol",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
|
@ -201,6 +207,7 @@ class CommandLineInterface(object):
|
||||||
ping_timeout=args.ping_timeout,
|
ping_timeout=args.ping_timeout,
|
||||||
websocket_timeout=args.websocket_timeout,
|
websocket_timeout=args.websocket_timeout,
|
||||||
websocket_connect_timeout=args.websocket_connect_timeout,
|
websocket_connect_timeout=args.websocket_connect_timeout,
|
||||||
|
application_close_timeout=args.application_close_timeout,
|
||||||
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
|
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
|
||||||
ws_protocols=args.ws_protocols,
|
ws_protocols=args.ws_protocols,
|
||||||
root_path=args.root_path,
|
root_path=args.root_path,
|
||||||
|
|
|
@ -48,7 +48,7 @@ class WebRequest(http.Request):
|
||||||
self.server = self.channel.factory.server
|
self.server = self.channel.factory.server
|
||||||
self.application_queue = None
|
self.application_queue = None
|
||||||
self._response_started = False
|
self._response_started = False
|
||||||
self.server.add_protocol(self)
|
self.server.protocol_connected(self)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
@ -118,7 +118,7 @@ class WebRequest(http.Request):
|
||||||
protocol.dataReceived(data)
|
protocol.dataReceived(data)
|
||||||
# Remove our HTTP reply channel association
|
# Remove our HTTP reply channel association
|
||||||
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
|
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
|
||||||
self.server.discard_protocol(self)
|
self.server.protocol_disconnected(self)
|
||||||
# Resume the producer so we keep getting data, if it's available as a method
|
# Resume the producer so we keep getting data, if it's available as a method
|
||||||
self.channel._networkProducer.resumeProducing()
|
self.channel._networkProducer.resumeProducing()
|
||||||
|
|
||||||
|
@ -171,7 +171,7 @@ class WebRequest(http.Request):
|
||||||
self.send_disconnect()
|
self.send_disconnect()
|
||||||
logger.debug("HTTP disconnect for %s", self.client_addr)
|
logger.debug("HTTP disconnect for %s", self.client_addr)
|
||||||
http.Request.connectionLost(self, reason)
|
http.Request.connectionLost(self, reason)
|
||||||
self.server.discard_protocol(self)
|
self.server.protocol_disconnected(self)
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
"""
|
"""
|
||||||
|
@ -181,7 +181,7 @@ class WebRequest(http.Request):
|
||||||
self.send_disconnect()
|
self.send_disconnect()
|
||||||
logger.debug("HTTP close for %s", self.client_addr)
|
logger.debug("HTTP close for %s", self.client_addr)
|
||||||
http.Request.finish(self)
|
http.Request.finish(self)
|
||||||
self.server.discard_protocol(self)
|
self.server.protocol_disconnected(self)
|
||||||
|
|
||||||
### Server reply callbacks
|
### Server reply callbacks
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,9 @@ else:
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from concurrent.futures import CancelledError
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.internet.endpoints import serverFromString
|
from twisted.internet.endpoints import serverFromString
|
||||||
|
@ -50,12 +52,11 @@ class Server(object):
|
||||||
proxy_forwarded_port_header=None,
|
proxy_forwarded_port_header=None,
|
||||||
verbosity=1,
|
verbosity=1,
|
||||||
websocket_handshake_timeout=5,
|
websocket_handshake_timeout=5,
|
||||||
|
application_close_timeout=10,
|
||||||
ready_callable=None,
|
ready_callable=None,
|
||||||
):
|
):
|
||||||
self.application = application
|
self.application = application
|
||||||
self.endpoints = endpoints or []
|
self.endpoints = endpoints or []
|
||||||
if not self.endpoints:
|
|
||||||
raise UserWarning("No endpoints. This server will not listen on anything.")
|
|
||||||
self.listeners = []
|
self.listeners = []
|
||||||
self.listening_addresses = []
|
self.listening_addresses = []
|
||||||
self.signal_handlers = signal_handlers
|
self.signal_handlers = signal_handlers
|
||||||
|
@ -68,16 +69,20 @@ class Server(object):
|
||||||
self.websocket_timeout = websocket_timeout
|
self.websocket_timeout = websocket_timeout
|
||||||
self.websocket_connect_timeout = websocket_connect_timeout
|
self.websocket_connect_timeout = websocket_connect_timeout
|
||||||
self.websocket_handshake_timeout = websocket_handshake_timeout
|
self.websocket_handshake_timeout = websocket_handshake_timeout
|
||||||
|
self.application_close_timeout = application_close_timeout
|
||||||
self.websocket_protocols = ws_protocols
|
self.websocket_protocols = ws_protocols
|
||||||
self.root_path = root_path
|
self.root_path = root_path
|
||||||
self.verbosity = verbosity
|
self.verbosity = verbosity
|
||||||
self.abort_start = False
|
self.abort_start = False
|
||||||
self.ready_callable = ready_callable
|
self.ready_callable = ready_callable
|
||||||
|
# Check our construction is actually sensible
|
||||||
|
if not self.endpoints:
|
||||||
|
logger.error("No endpoints. This server will not listen on anything.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
# A set of current Twisted protocol instances to manage
|
# A dict of protocol: {"application_instance":, "connected":, "disconnected":} dicts
|
||||||
self.protocols = set()
|
self.connections = {}
|
||||||
self.application_instances = {}
|
|
||||||
# Make the factory
|
# Make the factory
|
||||||
self.http_factory = HTTPFactory(self)
|
self.http_factory = HTTPFactory(self)
|
||||||
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server="Daphne")
|
self.ws_factory = WebSocketFactory(self, protocols=self.websocket_protocols, server="Daphne")
|
||||||
|
@ -150,18 +155,17 @@ class Server(object):
|
||||||
|
|
||||||
### Protocol handling
|
### Protocol handling
|
||||||
|
|
||||||
def add_protocol(self, protocol):
|
def protocol_connected(self, protocol):
|
||||||
if protocol in self.protocols:
|
"""
|
||||||
|
Adds a protocol as a current connection.
|
||||||
|
"""
|
||||||
|
if protocol in self.connections:
|
||||||
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
|
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
|
||||||
self.protocols.add(protocol)
|
self.connections[protocol] = {"connected": time.time()}
|
||||||
|
|
||||||
def discard_protocol(self, protocol):
|
def protocol_disconnected(self, protocol):
|
||||||
# Ensure it's not in the protocol-tracking set
|
# Set its disconnected time (the loops will come and clean it up)
|
||||||
self.protocols.discard(protocol)
|
self.connections[protocol]["disconnected"] = time.time()
|
||||||
# Make sure any application future that's running is cancelled
|
|
||||||
if protocol in self.application_instances:
|
|
||||||
self.application_instances[protocol].cancel()
|
|
||||||
del self.application_instances[protocol]
|
|
||||||
|
|
||||||
### Internal event/message handling
|
### Internal event/message handling
|
||||||
|
|
||||||
|
@ -173,12 +177,12 @@ class Server(object):
|
||||||
return you the application's input queue
|
return you the application's input queue
|
||||||
"""
|
"""
|
||||||
# Make sure the protocol has not had another application made for it
|
# Make sure the protocol has not had another application made for it
|
||||||
assert protocol not in self.application_instances
|
assert "application_instance" not in self.connections[protocol]
|
||||||
# Make an instance of the application
|
# Make an instance of the application
|
||||||
input_queue = asyncio.Queue()
|
input_queue = asyncio.Queue()
|
||||||
application_instance = self.application(scope=scope)
|
application_instance = self.application(scope=scope)
|
||||||
# Run it, and stash the future for later checking
|
# Run it, and stash the future for later checking
|
||||||
self.application_instances[protocol] = asyncio.ensure_future(application_instance(
|
self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
|
||||||
receive=input_queue.get,
|
receive=input_queue.get,
|
||||||
send=lambda message: self.handle_reply(protocol, message),
|
send=lambda message: self.handle_reply(protocol, message),
|
||||||
), loop=asyncio.get_event_loop())
|
), loop=asyncio.get_event_loop())
|
||||||
|
@ -197,29 +201,50 @@ class Server(object):
|
||||||
Goes through the set of current application Futures and cleans up
|
Goes through the set of current application Futures and cleans up
|
||||||
any that are done/prints exceptions for any that errored.
|
any that are done/prints exceptions for any that errored.
|
||||||
"""
|
"""
|
||||||
for protocol, application_instance in list(self.application_instances.items()):
|
for protocol, details in list(self.connections.items()):
|
||||||
if application_instance.done():
|
disconnected = details.get("disconnected", None)
|
||||||
exception = application_instance.exception()
|
application_instance = details.get("application_instance", None)
|
||||||
if exception:
|
# First, see if the protocol disconnected and the app has taken
|
||||||
if isinstance(exception, KeyboardInterrupt):
|
# too long to close up
|
||||||
# Protocol is asking the server to exit (likely during test)
|
if disconnected and time.time() - disconnected > self.application_close_timeout:
|
||||||
self.stop()
|
if not application_instance.done():
|
||||||
else:
|
logger.warning(
|
||||||
logger.error(
|
"Application instance %r for connection %s took too long to shut down and was killed.",
|
||||||
"Exception inside application: {}\n{}{}".format(
|
application_instance,
|
||||||
|
repr(protocol),
|
||||||
|
)
|
||||||
|
application_instance.cancel()
|
||||||
|
# Then see if the app is done and we should reap it
|
||||||
|
if application_instance and application_instance.done():
|
||||||
|
try:
|
||||||
|
exception = application_instance.exception()
|
||||||
|
except CancelledError:
|
||||||
|
# Future cancellation. We can ignore this.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if exception:
|
||||||
|
if isinstance(exception, KeyboardInterrupt):
|
||||||
|
# Protocol is asking the server to exit (likely during test)
|
||||||
|
self.stop()
|
||||||
|
else:
|
||||||
|
exception_output = "{}\n{}{}".format(
|
||||||
exception,
|
exception,
|
||||||
"".join(traceback.format_tb(
|
"".join(traceback.format_tb(
|
||||||
exception.__traceback__,
|
exception.__traceback__,
|
||||||
)),
|
)),
|
||||||
" {}".format(exception),
|
" {}".format(exception),
|
||||||
)
|
)
|
||||||
)
|
logger.error(
|
||||||
protocol.handle_exception(exception)
|
"Exception inside application: %s",
|
||||||
try:
|
exception_output,
|
||||||
del self.application_instances[protocol]
|
)
|
||||||
except KeyError:
|
if not disconnected:
|
||||||
# The protocol might have already got here before us. That's fine.
|
protocol.handle_exception(exception)
|
||||||
pass
|
del self.connections[protocol]["application_instance"]
|
||||||
|
application_instance = None
|
||||||
|
# Check to see if protocol is closed and app is closed so we can remove it
|
||||||
|
if not application_instance and disconnected:
|
||||||
|
del self.connections[protocol]
|
||||||
reactor.callLater(1, self.application_checker)
|
reactor.callLater(1, self.application_checker)
|
||||||
|
|
||||||
def kill_all_applications(self):
|
def kill_all_applications(self):
|
||||||
|
@ -228,7 +253,8 @@ class Server(object):
|
||||||
"""
|
"""
|
||||||
# Send cancel to all coroutines
|
# Send cancel to all coroutines
|
||||||
wait_for = []
|
wait_for = []
|
||||||
for application_instance in self.application_instances.values():
|
for details in self.connections.values():
|
||||||
|
application_instance = details["application_instance"]
|
||||||
if not application_instance.done():
|
if not application_instance.done():
|
||||||
application_instance.cancel()
|
application_instance.cancel()
|
||||||
wait_for.append(application_instance)
|
wait_for.append(application_instance)
|
||||||
|
@ -243,7 +269,7 @@ class Server(object):
|
||||||
Called periodically to enforce timeout rules on all connections.
|
Called periodically to enforce timeout rules on all connections.
|
||||||
Also checks pings at the same time.
|
Also checks pings at the same time.
|
||||||
"""
|
"""
|
||||||
for protocol in list(self.protocols):
|
for protocol in list(self.connections.keys()):
|
||||||
protocol.check_timeouts()
|
protocol.check_timeouts()
|
||||||
reactor.callLater(2, self.timeout_checker)
|
reactor.callLater(2, self.timeout_checker)
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
|
|
||||||
def onConnect(self, request):
|
def onConnect(self, request):
|
||||||
self.server = self.factory.server_class
|
self.server = self.factory.server_class
|
||||||
self.server.add_protocol(self)
|
self.server.protocol_connected(self)
|
||||||
self.request = request
|
self.request = request
|
||||||
self.protocol_to_accept = None
|
self.protocol_to_accept = None
|
||||||
self.socket_opened = time.time()
|
self.socket_opened = time.time()
|
||||||
|
@ -124,7 +124,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
"""
|
"""
|
||||||
Called when Twisted closes the socket.
|
Called when Twisted closes the socket.
|
||||||
"""
|
"""
|
||||||
self.server.discard_protocol(self)
|
self.server.protocol_disconnected(self)
|
||||||
logger.debug("WebSocket closed for %s", self.client_addr)
|
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||||
if not self.muted:
|
if not self.muted:
|
||||||
self.application_queue.put_nowait({
|
self.application_queue.put_nowait({
|
||||||
|
@ -187,7 +187,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
"""
|
"""
|
||||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||||
del self.handshake_deferred
|
del self.handshake_deferred
|
||||||
self.server.discard_protocol(self)
|
self.server.protocol_disconnected(self)
|
||||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||||
self.server.log_action("websocket", "rejected", {
|
self.server.log_action("websocket", "rejected", {
|
||||||
"path": self.request.path,
|
"path": self.request.path,
|
||||||
|
@ -245,6 +245,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return id(self) == id(other)
|
return id(self) == id(other)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<WebSocketProtocol client=%r path=%r>" % (self.client_addr, self.path)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketFactory(WebSocketServerFactory):
|
class WebSocketFactory(WebSocketServerFactory):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user