add support for websocket protocols

This commit is contained in:
Flavio Curella 2016-04-07 09:49:55 -05:00
parent 7b13995dee
commit 1d1b397aa2
3 changed files with 9 additions and 2 deletions

View File

@ -222,7 +222,7 @@ class HTTPFactory(http.HTTPFactory):
protocol = HTTPProtocol protocol = HTTPProtocol
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20): def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ws_protocols=None):
http.HTTPFactory.__init__(self) http.HTTPFactory.__init__(self)
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.action_logger = action_logger self.action_logger = action_logger
@ -232,7 +232,7 @@ class HTTPFactory(http.HTTPFactory):
# We track all sub-protocols for response channel mapping # We track all sub-protocols for response channel mapping
self.reply_protocols = {} self.reply_protocols = {}
# Make a factory for WebSocket protocols # Make a factory for WebSocket protocols
self.ws_factory = WebSocketFactory(self) self.ws_factory = WebSocketFactory(self, protocols=ws_protocols)
self.ws_factory.protocol = WebSocketProtocol self.ws_factory.protocol = WebSocketProtocol
self.ws_factory.reply_protocols = self.reply_protocols self.ws_factory.reply_protocols = self.reply_protocols

View File

@ -19,6 +19,7 @@ class Server(object):
http_timeout=120, http_timeout=120,
websocket_timeout=None, websocket_timeout=None,
ping_interval=20, ping_interval=20,
ws_protocols=None,
): ):
self.channel_layer = channel_layer self.channel_layer = channel_layer
self.host = host self.host = host
@ -31,6 +32,7 @@ class Server(object):
# If they did not provide a websocket timeout, default it to the # If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not. # channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
self.ws_protocols = ws_protocols
def run(self): def run(self):
self.factory = HTTPFactory( self.factory = HTTPFactory(
@ -39,6 +41,7 @@ class Server(object):
timeout=self.http_timeout, timeout=self.http_timeout,
websocket_timeout=self.websocket_timeout, websocket_timeout=self.websocket_timeout,
ping_interval=self.ping_interval, ping_interval=self.ping_interval,
ws_protocols=self.ws_protocols,
) )
if self.unix_socket: if self.unix_socket:
reactor.listenUNIX(self.unix_socket, self.factory) reactor.listenUNIX(self.unix_socket, self.factory)

View File

@ -57,6 +57,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
ws_protocol = clean_headers.get('sec-websocket-protocol')
if ws_protocol and ws_protocol in self.factory.protocols:
return ws_protocol
def onOpen(self): def onOpen(self):
# Send news that this channel is open # Send news that this channel is open
logger.debug("WebSocket open for %s", self.reply_channel) logger.debug("WebSocket open for %s", self.reply_channel)