diff --git a/channels/interfaces/websocket_asyncio.py b/channels/interfaces/websocket_asyncio.py new file mode 100644 index 0000000..5dc3fe3 --- /dev/null +++ b/channels/interfaces/websocket_asyncio.py @@ -0,0 +1,155 @@ +import asyncio +import django +import time + +from autobahn.asyncio.websocket import WebSocketServerProtocol, WebSocketServerFactory +from collections import deque + +from channels import Channel, channel_backends, DEFAULT_CHANNEL_BACKEND + + +class InterfaceProtocol(WebSocketServerProtocol): + """ + Protocol which supports WebSockets and forwards incoming messages to + the websocket channels. + """ + + def onConnect(self, request): + self.channel_backend = channel_backends[DEFAULT_CHANNEL_BACKEND] + self.request_info = { + "path": request.path, + "get": request.params, + } + + def onOpen(self): + # Make sending channel + self.reply_channel = Channel.new_name("!websocket.send") + self.request_info["reply_channel"] = self.reply_channel + self.last_keepalive = time.time() + self.factory.protocols[self.reply_channel] = self + # Send news that this channel is open + Channel("websocket.connect").send(self.request_info) + + def onMessage(self, payload, isBinary): + if isBinary: + Channel("websocket.receive").send(dict( + self.request_info, + content = payload, + binary = True, + )) + else: + Channel("websocket.receive").send(dict( + self.request_info, + content = payload.decode("utf8"), + binary = False, + )) + + def serverSend(self, content, binary=False, **kwargs): + """ + Server-side channel message to send a message. + """ + if binary: + self.sendMessage(content, binary) + else: + self.sendMessage(content.encode("utf8"), binary) + + def serverClose(self): + """ + Server-side channel message to close the socket + """ + self.sendClose() + + def onClose(self, wasClean, code, reason): + if hasattr(self, "reply_channel"): + del self.factory.protocols[self.reply_channel] + Channel("websocket.disconnect").send(self.request_info) + + def sendKeepalive(self): + """ + Sends a keepalive packet on the keepalive channel. + """ + Channel("websocket.keepalive").send(self.request_info) + self.last_keepalive = time.time() + + +class InterfaceFactory(WebSocketServerFactory): + """ + Factory which keeps track of its open protocols' receive channels + and can dispatch to them. + """ + + # TODO: Clean up dead protocols if needed? + + def __init__(self, *args, **kwargs): + super(InterfaceFactory, self).__init__(*args, **kwargs) + self.protocols = {} + + def reply_channels(self): + return self.protocols.keys() + + def dispatch_send(self, channel, message): + if message.get("close", False): + self.protocols[channel].serverClose() + else: + self.protocols[channel].serverSend(**message) + + +class WebsocketAsyncioInterface(object): + """ + Easy API to run a WebSocket interface server using Twisted. + Integrates the channel backend by running it in a separate thread, using + the always-compatible polling style. + """ + + def __init__(self, channel_backend, port=9000): + self.channel_backend = channel_backend + self.port = port + + def run(self): + self.factory = InterfaceFactory("ws://0.0.0.0:%i" % self.port, debug=False) + self.factory.protocol = InterfaceProtocol + self.loop = asyncio.get_event_loop() + coro = self.loop.create_server(self.factory, '0.0.0.0', 9000) + server = self.loop.run_until_complete(coro) + self.loop.run_in_executor(None, self.backend_reader) + self.loop.call_later(1, self.keepalive_sender) + try: + self.loop.run_forever() + except KeyboardInterrupt: + pass + finally: + server.close() + self.loop.close() + + def backend_reader(self): + """ + Run in a separate thread; reads messages from the backend. + """ + while True: + channels = self.factory.reply_channels() + # Quit if reactor is stopping + if not self.loop.is_running(): + return + # Don't do anything if there's no channels to listen on + if channels: + channel, message = self.channel_backend.receive_many(channels) + else: + time.sleep(0.1) + continue + # Wait around if there's nothing received + if channel is None: + time.sleep(0.05) + continue + # Deal with the message + self.factory.dispatch_send(channel, message) + + def keepalive_sender(self): + """ + Sends keepalive messages for open WebSockets every + (channel_backend expiry / 2) seconds. + """ + expiry_window = int(self.channel_backend.expiry / 2) + for protocol in self.factory.protocols.values(): + if time.time() - protocol.last_keepalive > expiry_window: + protocol.sendKeepalive() + self.loop.call_later(1, self.keepalive_sender) diff --git a/channels/management/commands/runwsserver.py b/channels/management/commands/runwsserver.py index 7b4d296..b1525d9 100644 --- a/channels/management/commands/runwsserver.py +++ b/channels/management/commands/runwsserver.py @@ -1,7 +1,6 @@ import time from django.core.management import BaseCommand, CommandError from channels import channel_backends, DEFAULT_CHANNEL_BACKEND -from channels.interfaces.websocket_twisted import WebsocketTwistedInterface class Command(BaseCommand): @@ -20,7 +19,17 @@ class Command(BaseCommand): ) # Run the interface port = options.get("port", None) or 9000 - self.stdout.write("Running Twisted/Autobahn WebSocket interface server") - self.stdout.write(" Channel backend: %s" % channel_backend) - self.stdout.write(" Listening on: ws://0.0.0.0:%i" % port) - WebsocketTwistedInterface(channel_backend=channel_backend, port=port).run() + try: + import asyncio + except ImportError: + from channels.interfaces.websocket_twisted import WebsocketTwistedInterface + self.stdout.write("Running Twisted/Autobahn WebSocket interface server") + self.stdout.write(" Channel backend: %s" % channel_backend) + self.stdout.write(" Listening on: ws://0.0.0.0:%i" % port) + WebsocketTwistedInterface(channel_backend=channel_backend, port=port).run() + else: + from channels.interfaces.websocket_asyncio import WebsocketAsyncioInterface + self.stdout.write("Running asyncio/Autobahn WebSocket interface server") + self.stdout.write(" Channel backend: %s" % channel_backend) + self.stdout.write(" Listening on: ws://0.0.0.0:%i" % port) + WebsocketAsyncioInterface(channel_backend=channel_backend, port=port).run()