mirror of
https://github.com/django/daphne.git
synced 2024-11-22 07:56:34 +03:00
Merge pull request #7 from fcurella/ws_protocols
add support for websocket protocols
This commit is contained in:
commit
3d069c0427
|
@ -70,6 +70,13 @@ class CommandLineInterface(object):
|
||||||
'channel_layer',
|
'channel_layer',
|
||||||
help='The ASGI channel layer instance to use as path.to.module:instance.path',
|
help='The ASGI channel layer instance to use as path.to.module:instance.path',
|
||||||
)
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--ws-protocol',
|
||||||
|
nargs='*',
|
||||||
|
dest='ws_protocols',
|
||||||
|
help='The WebSocket protocols you wish to support',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def entrypoint(cls):
|
def entrypoint(cls):
|
||||||
|
@ -123,4 +130,5 @@ class CommandLineInterface(object):
|
||||||
http_timeout=args.http_timeout,
|
http_timeout=args.http_timeout,
|
||||||
ping_interval=args.ping_interval,
|
ping_interval=args.ping_interval,
|
||||||
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
|
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
|
||||||
|
ws_protocols=args.ws_protocols,
|
||||||
).run()
|
).run()
|
||||||
|
|
|
@ -222,7 +222,7 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
|
|
||||||
protocol = HTTPProtocol
|
protocol = HTTPProtocol
|
||||||
|
|
||||||
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20):
|
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ws_protocols=None):
|
||||||
http.HTTPFactory.__init__(self)
|
http.HTTPFactory.__init__(self)
|
||||||
self.channel_layer = channel_layer
|
self.channel_layer = channel_layer
|
||||||
self.action_logger = action_logger
|
self.action_logger = action_logger
|
||||||
|
@ -232,7 +232,7 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
# We track all sub-protocols for response channel mapping
|
# We track all sub-protocols for response channel mapping
|
||||||
self.reply_protocols = {}
|
self.reply_protocols = {}
|
||||||
# Make a factory for WebSocket protocols
|
# Make a factory for WebSocket protocols
|
||||||
self.ws_factory = WebSocketFactory(self)
|
self.ws_factory = WebSocketFactory(self, protocols=ws_protocols)
|
||||||
self.ws_factory.protocol = WebSocketProtocol
|
self.ws_factory.protocol = WebSocketProtocol
|
||||||
self.ws_factory.reply_protocols = self.reply_protocols
|
self.ws_factory.reply_protocols = self.reply_protocols
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ class Server(object):
|
||||||
http_timeout=120,
|
http_timeout=120,
|
||||||
websocket_timeout=None,
|
websocket_timeout=None,
|
||||||
ping_interval=20,
|
ping_interval=20,
|
||||||
|
ws_protocols=None,
|
||||||
):
|
):
|
||||||
self.channel_layer = channel_layer
|
self.channel_layer = channel_layer
|
||||||
self.host = host
|
self.host = host
|
||||||
|
@ -31,6 +32,7 @@ class Server(object):
|
||||||
# If they did not provide a websocket timeout, default it to the
|
# If they did not provide a websocket timeout, default it to the
|
||||||
# channel layer's group_expiry value if present, or one day if not.
|
# channel layer's group_expiry value if present, or one day if not.
|
||||||
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
|
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
|
||||||
|
self.ws_protocols = ws_protocols
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.factory = HTTPFactory(
|
self.factory = HTTPFactory(
|
||||||
|
@ -39,6 +41,7 @@ class Server(object):
|
||||||
timeout=self.http_timeout,
|
timeout=self.http_timeout,
|
||||||
websocket_timeout=self.websocket_timeout,
|
websocket_timeout=self.websocket_timeout,
|
||||||
ping_interval=self.ping_interval,
|
ping_interval=self.ping_interval,
|
||||||
|
ws_protocols=self.ws_protocols,
|
||||||
)
|
)
|
||||||
if self.unix_socket:
|
if self.unix_socket:
|
||||||
reactor.listenUNIX(self.unix_socket, self.factory)
|
reactor.listenUNIX(self.unix_socket, self.factory)
|
||||||
|
|
|
@ -57,6 +57,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
ws_protocol = clean_headers.get('sec-websocket-protocol')
|
||||||
|
if ws_protocol and ws_protocol in self.factory.protocols:
|
||||||
|
return ws_protocol
|
||||||
|
|
||||||
def onOpen(self):
|
def onOpen(self):
|
||||||
# Send news that this channel is open
|
# Send news that this channel is open
|
||||||
logger.debug("WebSocket open for %s", self.reply_channel)
|
logger.debug("WebSocket open for %s", self.reply_channel)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user