From 0ad7f1c2a2514bc2e034960aa5bf5b980925017e Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 2 Mar 2016 11:24:19 -0800 Subject: [PATCH] Add timeout 503 responses with configurable delay. --- daphne/cli.py | 8 ++++++ daphne/http_protocol.py | 57 ++++++++++++++++++++++++++++++++++++++--- daphne/server.py | 22 ++++++++++++++-- 3 files changed, 82 insertions(+), 5 deletions(-) diff --git a/daphne/cli.py b/daphne/cli.py index dcfc8e1..3d8b1ab 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -40,6 +40,13 @@ class CommandLineInterface(object): help='How verbose to make the output', default=1, ) + self.parser.add_argument( + '-t', + '--http-timeout', + type=int, + help='How long to wait for worker server before timing out HTTP connections', + default=120, + ) self.parser.add_argument( 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', @@ -85,4 +92,5 @@ class CommandLineInterface(object): channel_layer=channel_layer, host=args.host, port=args.port, + http_timeout=args.http_timeout, ).run() diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 666ebf7..6fd1153 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -4,7 +4,6 @@ import logging import six import time -from twisted.python.compat import _PY3 from twisted.web import http from twisted.protocols.policies import ProtocolWrapper @@ -22,6 +21,25 @@ class WebRequest(http.Request): GET and POST out. """ + error_template = """ + + + %(title)s + + + +

%(title)s

+

%(body)s

+ + + + """.replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise + def __init__(self, *args, **kwargs): http.Request.__init__(self, *args, **kwargs) # Easy factory link @@ -157,11 +175,34 @@ class WebRequest(http.Request): "status": self.code, "method": self.method.decode("ascii"), "client": "%s:%s" % (self.client.host, self.client.port), - "time_taken": time.time() - self.request_start, + "time_taken": self.duration(), }) else: logger.debug("HTTP response chunk for %s", self.reply_channel) + def duration(self): + """ + Returns the time since the start of the request. + """ + return time.time() - self.request_start + + def basic_error(self, status, status_text, body): + """ + Responds with a server-level error page (very basic) + """ + self.serverResponse({ + "status": status, + "status_text": status_text, + "headers": [ + ("Content-Type", b"text/html; charset=utf-8"), + ], + "content": (self.error_template % { + "title": str(status) + " " + status_text.decode("ascii"), + "body": body, + }).encode("utf8"), + }) + + class HTTPProtocol(http.HTTPChannel): @@ -178,10 +219,11 @@ class HTTPFactory(http.HTTPFactory): protocol = HTTPProtocol - def __init__(self, channel_layer, action_logger=None): + def __init__(self, channel_layer, action_logger=None, timeout=120): http.HTTPFactory.__init__(self) self.channel_layer = channel_layer self.action_logger = action_logger + self.timeout = timeout # We track all sub-protocols for response channel mapping self.reply_protocols = {} # Make a factory for WebSocket protocols @@ -211,3 +253,12 @@ class HTTPFactory(http.HTTPFactory): """ if self.action_logger: self.action_logger(protocol, action, details) + + def check_timeouts(self): + """ + Runs through all HTTP protocol instances and times them out if they've + taken too long (and so their message is probably expired) + """ + for protocol in list(self.reply_protocols.values()): + if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout: + protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") diff --git a/daphne/server.py b/daphne/server.py index 88ed54c..baf0301 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -8,17 +8,27 @@ logger = logging.getLogger(__name__) class Server(object): - def __init__(self, channel_layer, host="127.0.0.1", port=8000, signal_handlers=True, action_logger=None): + def __init__( + self, + channel_layer, + host="127.0.0.1", + port=8000, + signal_handlers=True, + action_logger=None, + http_timeout=120 + ): self.channel_layer = channel_layer self.host = host self.port = port self.signal_handlers = signal_handlers self.action_logger = action_logger + self.http_timeout = http_timeout def run(self): - self.factory = HTTPFactory(self.channel_layer, self.action_logger) + self.factory = HTTPFactory(self.channel_layer, self.action_logger, timeout=self.http_timeout) reactor.listenTCP(self.port, self.factory, interface=self.host) reactor.callLater(0, self.backend_reader) + reactor.callLater(2, self.timeout_checker) reactor.run(installSignalHandlers=self.signal_handlers) def backend_reader(self): @@ -41,3 +51,11 @@ class Server(object): # Deal with the message self.factory.dispatch_reply(channel, message) reactor.callLater(delay, self.backend_reader) + + def timeout_checker(self): + """ + Called periodically to enforce timeout rules on HTTP connections + (but not WebSocket) + """ + self.factory.check_timeouts() + reactor.callLater(2, self.timeout_checker)