Add timeout to WebSockets as per ASGI spec

This commit is contained in:
Andrew Godwin 2016-03-28 03:28:15 -07:00
parent 813ef1c27c
commit 037f129117
3 changed files with 25 additions and 5 deletions

View File

@ -219,11 +219,12 @@ class HTTPFactory(http.HTTPFactory):
protocol = HTTPProtocol protocol = HTTPProtocol
def __init__(self, channel_layer, action_logger=None, timeout=120): def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400):
http.HTTPFactory.__init__(self) http.HTTPFactory.__init__(self)
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.action_logger = action_logger self.action_logger = action_logger
self.timeout = timeout self.timeout = timeout
self.websocket_timeout = websocket_timeout
# We track all sub-protocols for response channel mapping # We track all sub-protocols for response channel mapping
self.reply_protocols = {} self.reply_protocols = {}
# Make a factory for WebSocket protocols # Make a factory for WebSocket protocols
@ -260,5 +261,9 @@ class HTTPFactory(http.HTTPFactory):
taken too long (and so their message is probably expired) taken too long (and so their message is probably expired)
""" """
for protocol in list(self.reply_protocols.values()): for protocol in list(self.reply_protocols.values()):
# Web timeout checking
if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout: if isinstance(protocol, WebRequest) and protocol.duration() > self.timeout:
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
# WebSocket timeout checking
elif isinstance(protocol, WebSocketProtocol) and protocol.duration() > self.websocket_timeout:
protocol.serverClose()

View File

@ -15,7 +15,8 @@ class Server(object):
port=8000, port=8000,
signal_handlers=True, signal_handlers=True,
action_logger=None, action_logger=None,
http_timeout=120 http_timeout=120,
websocket_timeout=None,
): ):
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.host = host self.host = host
@ -23,9 +24,17 @@ class Server(object):
self.signal_handlers = signal_handlers self.signal_handlers = signal_handlers
self.action_logger = action_logger self.action_logger = action_logger
self.http_timeout = http_timeout self.http_timeout = http_timeout
# If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
def run(self): def run(self):
self.factory = HTTPFactory(self.channel_layer, self.action_logger, timeout=self.http_timeout) self.factory = HTTPFactory(
self.channel_layer,
self.action_logger,
timeout=self.http_timeout,
websocket_timeout=self.websocket_timeout,
)
reactor.listenTCP(self.port, self.factory, interface=self.host) reactor.listenTCP(self.port, self.factory, interface=self.host)
reactor.callLater(0, self.backend_reader) reactor.callLater(0, self.backend_reader)
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)
@ -54,8 +63,7 @@ class Server(object):
def timeout_checker(self): def timeout_checker(self):
""" """
Called periodically to enforce timeout rules on HTTP connections Called periodically to enforce timeout rules on all connections
(but not WebSocket)
""" """
self.factory.check_timeouts() self.factory.check_timeouts()
reactor.callLater(2, self.timeout_checker) reactor.callLater(2, self.timeout_checker)

View File

@ -23,6 +23,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onConnect(self, request): def onConnect(self, request):
self.request = request self.request = request
self.packets_received = 0 self.packets_received = 0
self.socket_opened = time.time()
try: try:
# Sanitize and decode headers # Sanitize and decode headers
clean_headers = {} clean_headers = {}
@ -114,6 +115,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
else: else:
logger.debug("WebSocket closed before handshake established") logger.debug("WebSocket closed before handshake established")
def duration(self):
"""
Returns the time since the socket was opened
"""
return time.time() - self.socket_opened
class WebSocketFactory(WebSocketServerFactory): class WebSocketFactory(WebSocketServerFactory):
""" """