Update Daphne for new process-local channel style

This commit is contained in:
Andrew Godwin 2017-03-27 19:49:50 -07:00
parent bd9b8d0068
commit 3937489c4a
4 changed files with 83 additions and 51 deletions

View File

@ -101,10 +101,6 @@ class CommandLineInterface(object):
help='The number of seconds before a WeSocket is closed if no response to a keepalive ping', help='The number of seconds before a WeSocket is closed if no response to a keepalive ping',
default=30, default=30,
) )
self.parser.add_argument(
'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path',
)
self.parser.add_argument( self.parser.add_argument(
'--ws-protocol', '--ws-protocol',
nargs='*', nargs='*',
@ -126,6 +122,17 @@ class CommandLineInterface(object):
default=False, default=False,
action='store_true', action='store_true',
) )
self.parser.add_argument(
'--force-sync',
dest='force_sync',
action='store_true',
help='Force the server to use synchronous mode on its ASGI channel layer',
default=False,
)
self.parser.add_argument(
'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path',
)
self.server = None self.server = None
@ -206,5 +213,6 @@ class CommandLineInterface(object):
verbosity=args.verbosity, verbosity=args.verbosity,
proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None,
proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None,
force_sync=args.force_sync,
) )
self.server.run() self.server.run()

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import logging import logging
import random
import six import six
import string
import time import time
import traceback import traceback
@ -51,7 +53,7 @@ class WebRequest(http.Request):
# Easy factory link # Easy factory link
self.factory = self.channel.factory self.factory = self.channel.factory
# Make a name for our reply channel # Make a name for our reply channel
self.reply_channel = self.factory.channel_layer.new_channel("http.response!") self.reply_channel = self.factory.make_send_channel()
# Tell factory we're that channel's client # Tell factory we're that channel's client
self.last_keepalive = time.time() self.last_keepalive = time.time()
self.factory.reply_protocols[self.reply_channel] = self self.factory.reply_protocols[self.reply_channel] = self
@ -300,10 +302,11 @@ class HTTPFactory(http.HTTPFactory):
routed appropriately. routed appropriately.
""" """
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None): def __init__(self, channel_layer, action_logger=None, send_channel=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None):
http.HTTPFactory.__init__(self) http.HTTPFactory.__init__(self)
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.action_logger = action_logger self.action_logger = action_logger
self.send_channel = send_channel
self.timeout = timeout self.timeout = timeout
self.websocket_timeout = websocket_timeout self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout self.websocket_connect_timeout = websocket_connect_timeout
@ -327,17 +330,31 @@ class HTTPFactory(http.HTTPFactory):
Builds protocol instances. This override is used to ensure we use our Builds protocol instances. This override is used to ensure we use our
own Request object instead of the default. own Request object instead of the default.
""" """
protocol = http.HTTPFactory.buildProtocol(self, addr) try:
protocol.requestFactory = WebRequest protocol = http.HTTPFactory.buildProtocol(self, addr)
return protocol protocol.requestFactory = WebRequest
return protocol
except Exception as e:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise
def make_send_channel(self):
"""
Makes a new send channel for a protocol with our process prefix.
"""
protocol_id = "".join(random.choice(string.ascii_letters) for i in range(10))
return self.send_channel + protocol_id
def reply_channels(self): def reply_channels(self):
return self.reply_protocols.keys() return self.reply_protocols.keys()
def dispatch_reply(self, channel, message): def dispatch_reply(self, channel, message):
if channel.startswith("http") and isinstance(self.reply_protocols[channel], WebRequest): if channel not in self.reply_protocols:
raise ValueError("Cannot dispatch message on channel %r (unknown)" % channel)
if isinstance(self.reply_protocols[channel], WebRequest):
self.reply_protocols[channel].serverResponse(message) self.reply_protocols[channel].serverResponse(message)
elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol): elif isinstance(self.reply_protocols[channel], WebSocketProtocol):
# Switch depending on current socket state # Switch depending on current socket state
protocol = self.reply_protocols[channel] protocol = self.reply_protocols[channel]
# See if the message is valid # See if the message is valid
@ -376,7 +393,7 @@ class HTTPFactory(http.HTTPFactory):
else: else:
protocol.serverClose(code=closing_code) protocol.serverClose(code=closing_code)
else: else:
raise ValueError("Cannot dispatch message on channel %r" % channel) raise ValueError("Unknown protocol class")
def log_action(self, protocol, action, details): def log_action(self, protocol, action, details):
""" """

View File

@ -1,4 +1,6 @@
import logging import logging
import random
import string
import warnings import warnings
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
@ -25,13 +27,14 @@ class Server(object):
action_logger=None, action_logger=None,
http_timeout=120, http_timeout=120,
websocket_timeout=None, websocket_timeout=None,
websocket_connect_timeout=None, websocket_connect_timeout=20,
ping_interval=20, ping_interval=20,
ping_timeout=30, ping_timeout=30,
ws_protocols=None, ws_protocols=None,
root_path="", root_path="",
proxy_forwarded_address_header=None, proxy_forwarded_address_header=None,
proxy_forwarded_port_header=None, proxy_forwarded_port_header=None,
force_sync=False,
verbosity=1 verbosity=1
): ):
self.channel_layer = channel_layer self.channel_layer = channel_layer
@ -63,14 +66,22 @@ class Server(object):
# If they did not provide a websocket timeout, default it to the # If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not. # channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
self.websocket_connect_timeout = websocket_connect_timeout
self.ws_protocols = ws_protocols self.ws_protocols = ws_protocols
self.root_path = root_path self.root_path = root_path
self.force_sync = force_sync
self.verbosity = verbosity self.verbosity = verbosity
def run(self): def run(self):
# Create process-local channel prefixes
# TODO: Can we guarantee non-collision better?
process_id = "".join(random.choice(string.ascii_letters) for i in range(10))
self.send_channel = "daphne.response.%s!" % process_id
# Make the factory
self.factory = HTTPFactory( self.factory = HTTPFactory(
self.channel_layer, self.channel_layer,
self.action_logger, action_logger=self.action_logger,
send_channel=self.send_channel,
timeout=self.http_timeout, timeout=self.http_timeout,
websocket_timeout=self.websocket_timeout, websocket_timeout=self.websocket_timeout,
websocket_connect_timeout=self.websocket_connect_timeout, websocket_connect_timeout=self.websocket_connect_timeout,
@ -93,8 +104,7 @@ class Server(object):
else: else:
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
# Disabled deliberately for the moment as it's worse performing if "twisted" in self.channel_layer.extensions and not self.force_sync:
if "twisted" in self.channel_layer.extensions and False:
logger.info("Using native Twisted mode on channel layer") logger.info("Using native Twisted mode on channel layer")
reactor.callLater(0, self.backend_reader_twisted) reactor.callLater(0, self.backend_reader_twisted)
else: else:
@ -115,28 +125,30 @@ class Server(object):
Runs as an-often-as-possible task with the reactor, unless there was Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay. no result previously in which case we add a small delay.
""" """
channels = self.factory.reply_channels() channels = [self.send_channel]
delay = 0.05 delay = 0
# Quit if reactor is stopping # Quit if reactor is stopping
if not reactor.running: if not reactor.running:
logger.debug("Backend reader quitting due to reactor stop") logger.debug("Backend reader quitting due to reactor stop")
return return
# Don't do anything if there's no channels to listen on # Try to receive a message
if channels: try:
delay = 0.01 channel, message = self.channel_layer.receive(channels, block=False)
try: except Exception as e:
channel, message = self.channel_layer.receive(channels, block=False) # Log the error and wait a bit to retry
except Exception as e: logger.error('Error trying to receive messages: %s' % e)
logger.error('Error at trying to receive messages: %s' % e) delay = 5.00
delay = 5.00 else:
if channel:
# Deal with the message
try:
self.factory.dispatch_reply(channel, message)
except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e)
else: else:
if channel: # If there's no messages, idle a little bit.
delay = 0.00 delay = 0.05
# Deal with the message # We can't loop inside here as this is synchronous code.
try:
self.factory.dispatch_reply(channel, message)
except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e)
reactor.callLater(delay, self.backend_reader_sync) reactor.callLater(delay, self.backend_reader_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -145,28 +157,23 @@ class Server(object):
Runs as an-often-as-possible task with the reactor, unless there was Runs as an-often-as-possible task with the reactor, unless there was
no result previously in which case we add a small delay. no result previously in which case we add a small delay.
""" """
channels = [self.send_channel]
while True: while True:
if not reactor.running: if not reactor.running:
logging.debug("Backend reader quitting due to reactor stop") logging.debug("Backend reader quitting due to reactor stop")
return return
channels = self.factory.reply_channels() try:
if channels: channel, message = yield self.channel_layer.receive_twisted(channels)
try: except Exception as e:
channel, message = yield self.channel_layer.receive_twisted(channels) logger.error('Error trying to receive messages: %s' % e)
except Exception as e: yield self.sleep(5.00)
logger.error('Error at trying to receive messages: %s' % e)
yield self.sleep(5.00)
else:
# Deal with the message
if channel:
try:
self.factory.dispatch_reply(channel, message)
except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e)
else:
yield self.sleep(0.01)
else: else:
yield self.sleep(0.05) # Deal with the message
if channel:
try:
self.factory.dispatch_reply(channel, message)
except Exception as e:
logger.error("HTTP/WS send decode error: %s" % e)
def sleep(self, delay): def sleep(self, delay):
d = defer.Deferred() d = defer.Deferred()

View File

@ -46,7 +46,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
# TODO: get autobahn to provide it raw # TODO: get autobahn to provide it raw
query_string = urlencode(request.params, doseq=True).encode("ascii") query_string = urlencode(request.params, doseq=True).encode("ascii")
# Make sending channel # Make sending channel
self.reply_channel = self.channel_layer.new_channel("websocket.send!") self.reply_channel = self.main_factory.make_send_channel()
# Tell main factory about it # Tell main factory about it
self.main_factory.reply_protocols[self.reply_channel] = self self.main_factory.reply_protocols[self.reply_channel] = self
# Get client address if possible # Get client address if possible