diff --git a/daphne/cli.py b/daphne/cli.py index 8763c87..708f7f4 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -70,6 +70,13 @@ class CommandLineInterface(object): 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', ) + self.parser.add_argument( + '--ws-protocol', + nargs='*', + dest='ws_protocols', + help='The WebSocket protocols you wish to support', + default=None, + ) @classmethod def entrypoint(cls): @@ -123,4 +130,5 @@ class CommandLineInterface(object): http_timeout=args.http_timeout, ping_interval=args.ping_interval, action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, + ws_protocols=args.ws_protocols, ).run() diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 5e596e8..0fa544b 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -222,7 +222,7 @@ class HTTPFactory(http.HTTPFactory): protocol = HTTPProtocol - def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20): + def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ws_protocols=None): http.HTTPFactory.__init__(self) self.channel_layer = channel_layer self.action_logger = action_logger @@ -232,7 +232,7 @@ class HTTPFactory(http.HTTPFactory): # We track all sub-protocols for response channel mapping self.reply_protocols = {} # Make a factory for WebSocket protocols - self.ws_factory = WebSocketFactory(self) + self.ws_factory = WebSocketFactory(self, protocols=ws_protocols) self.ws_factory.protocol = WebSocketProtocol self.ws_factory.reply_protocols = self.reply_protocols diff --git a/daphne/server.py b/daphne/server.py index 90bc6f2..2bb2622 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -19,6 +19,7 @@ class Server(object): http_timeout=120, websocket_timeout=None, ping_interval=20, + ws_protocols=None, ): self.channel_layer = channel_layer self.host = host @@ -31,6 +32,7 @@ class Server(object): # If they did not provide a websocket timeout, default it to the # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) + self.ws_protocols = ws_protocols def run(self): self.factory = HTTPFactory( @@ -39,6 +41,7 @@ class Server(object): timeout=self.http_timeout, websocket_timeout=self.websocket_timeout, ping_interval=self.ping_interval, + ws_protocols=self.ws_protocols, ) if self.unix_socket: reactor.listenUNIX(self.unix_socket, self.factory) diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index fb7d095..9fbfb83 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -57,6 +57,10 @@ class WebSocketProtocol(WebSocketServerProtocol): logger.error(traceback.format_exc()) raise + ws_protocol = clean_headers.get('sec-websocket-protocol') + if ws_protocol and ws_protocol in self.factory.protocols: + return ws_protocol + def onOpen(self): # Send news that this channel is open logger.debug("WebSocket open for %s", self.reply_channel)