diff --git a/channels/interfaces/http_twisted.py b/channels/interfaces/http_twisted.py new file mode 100644 index 0000000..fcee728 --- /dev/null +++ b/channels/interfaces/http_twisted.py @@ -0,0 +1,249 @@ +import time + +from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory +from twisted.python.compat import _PY3 +from twisted.web.http import HTTPFactory, HTTPChannel, Request, _respondToBadRequestAndDisconnect, parse_qs, _parseHeader +from twisted.protocols.policies import ProtocolWrapper +from twisted.internet import reactor + +from channels import Channel, channel_backends, DEFAULT_CHANNEL_BACKEND +from .websocket_autobahn import get_protocol, get_factory + + +WebsocketProtocol = get_protocol(WebSocketServerProtocol) + + +class WebRequest(Request): + """ + Request that either hands off information to channels, or offloads + to a WebSocket class. + + Does some extra processing over the normal Twisted Web request to separate + GET and POST out. + """ + + def __init__(self, *args, **kwargs): + Request.__init__(self, *args, **kwargs) + self.reply_channel = Channel.new_name("!http.response") + self.channel.factory.reply_protocols[self.reply_channel] = self + + def process(self): + # Get upgrade header + upgrade_header = None + if self.requestHeaders.hasHeader("Upgrade"): + upgrade_header = self.requestHeaders.getRawHeaders("Upgrade")[0] + # Is it WebSocket? IS IT?! + if upgrade_header == "websocket": + # Make WebSocket protocol to hand off to + protocol = self.channel.factory.ws_factory.buildProtocol(self.transport.getPeer()) + if not protocol: + # If protocol creation fails, we signal "internal server error" + self.setResponseCode(500) + self.finish() + # Port across transport + transport, self.transport = self.transport, None + if isinstance(transport, ProtocolWrapper): + # i.e. TLS is a wrapping protocol + transport.wrappedProtocol = protocol + else: + transport.protocol = protocol + protocol.makeConnection(transport) + # Re-inject request + if _PY3: + data = self.method + b' ' + self.uri + b' HTTP/1.1\x0d\x0a' + for h in self.requestHeaders.getAllRawHeaders(): + data += h[0] + b': ' + b",".join(h[1]) + b'\x0d\x0a' + data += b"\x0d\x0a" + data += self.content.read() + else: + data = "%s %s HTTP/1.1\x0d\x0a" % (self.method, self.uri) + for h in self.requestHeaders.getAllRawHeaders(): + data += "%s: %s\x0d\x0a" % (h[0], ",".join(h[1])) + data += "\x0d\x0a" + protocol.dataReceived(data) + # Remove our HTTP reply channel association + self.channel.factory.reply_protocols[self.reply_channel] = None + self.reply_channel = None + # Boring old HTTP. + else: + # Send request message + Channel("http.request").send({ + "reply_channel": self.reply_channel, + "method": self.method, + "get": self.get, + "post": self.post, + "cookies": self.received_cookies, + "headers": {k: v[0] for k, v in self.requestHeaders.getAllRawHeaders()}, + "client": [self.client.host, self.client.port], + "server": [self.host.host, self.host.port], + "path": self.path, + }) + + def connectionLost(self, reason): + """ + Cleans up reply channel on close. + """ + if self.reply_channel: + del self.channel.factory.reply_protocols[self.reply_channel] + Request.connectionLost(self, reason) + + def serverResponse(self, message): + """ + Writes a received HTTP response back out to the transport. + """ + # Write code + self.setResponseCode(message['status']) + # Write headers + for header, value in message.get("headers", {}): + self.setHeader(header.encode("utf8"), value.encode("utf8")) + # Write cookies + for cookie in message.get("cookies"): + self.cookies.append(cookie.encode("utf8")) + # Write out body + if "content" in message: + Request.write(self, message['content'].encode("utf8")) + self.finish() + + def requestReceived(self, command, path, version): + """ + Called by channel when all data has been received. + Overridden because Twisted merges GET and POST into one thing by default. + """ + self.content.seek(0,0) + self.get = {} + self.post = {} + + self.method, self.uri = command, path + self.clientproto = version + x = self.uri.split(b'?', 1) + + print self.method + + # URI and GET args assignment + if len(x) == 1: + self.path = self.uri + else: + self.path, argstring = x + self.get = parse_qs(argstring, 1) + + # cache the client and server information, we'll need this later to be + # serialized and sent with the request so CGIs will work remotely + self.client = self.channel.transport.getPeer() + self.host = self.channel.transport.getHost() + + # Argument processing + ctype = self.requestHeaders.getRawHeaders(b'content-type') + if ctype is not None: + ctype = ctype[0] + + # Process POST data if present + if self.method == b"POST" and ctype: + mfd = b'multipart/form-data' + key, pdict = _parseHeader(ctype) + if key == b'application/x-www-form-urlencoded': + self.post.update(parse_qs(self.content.read(), 1)) + elif key == mfd: + try: + cgiArgs = cgi.parse_multipart(self.content, pdict) + + if _PY3: + # parse_multipart on Python 3 decodes the header bytes + # as iso-8859-1 and returns a str key -- we want bytes + # so encode it back + self.post.update({x.encode('iso-8859-1'): y + for x, y in cgiArgs.items()}) + else: + self.post.update(cgiArgs) + except: + # It was a bad request. + _respondToBadRequestAndDisconnect(self.channel.transport) + return + self.content.seek(0, 0) + + # Continue with rest of request handling + self.process() + + +class WebProtocol(HTTPChannel): + + requestFactory = WebRequest + + +class WebFactory(HTTPFactory): + + protocol = WebProtocol + + def __init__(self): + HTTPFactory.__init__(self) + # We track all sub-protocols for response channel mapping + self.reply_protocols = {} + # Make a factory for WebSocket protocols + self.ws_factory = WebSocketServerFactory("ws://127.0.0.1:8000") + self.ws_factory.protocol = WebsocketProtocol + self.ws_factory.reply_protocols = self.reply_protocols + + def reply_channels(self): + return self.reply_protocols.keys() + + def dispatch_reply(self, channel, message): + if channel.startswith("!http") and isinstance(self.reply_protocols[channel], WebRequest): + self.reply_protocols[channel].serverResponse(message) + elif channel.startswith("!websocket") and isinstance(self.reply_protocols[channel], WebsocketProtocol): + if message.get("content", None): + self.reply_protocols[channel].serverSend(**message) + if message.get("close", False): + self.reply_protocols[channel].serverClose() + else: + raise ValueError("Cannot dispatch message on channel %r" % channel) + + +class HttpTwistedInterface(object): + """ + Easy API to run a HTTP1 & WebSocket interface server using Twisted. + Integrates the channel backend by running it in a separate thread, using + the always-compatible polling style. + """ + + def __init__(self, channel_backend, port=8000): + self.channel_backend = channel_backend + self.port = port + + def run(self): + self.factory = WebFactory() + reactor.listenTCP(self.port, self.factory) + reactor.callInThread(self.backend_reader) + #reactor.callLater(1, self.keepalive_sender) + reactor.run() + + def backend_reader(self): + """ + Run in a separate thread; reads messages from the backend. + """ + while True: + channels = self.factory.reply_channels() + # Quit if reactor is stopping + if not reactor.running: + return + # Don't do anything if there's no channels to listen on + if channels: + channel, message = self.channel_backend.receive_many(channels) + else: + time.sleep(0.1) + continue + # Wait around if there's nothing received + if channel is None: + time.sleep(0.05) + continue + # Deal with the message + self.factory.dispatch_reply(channel, message) + + def keepalive_sender(self): + """ + Sends keepalive messages for open WebSockets every + (channel_backend expiry / 2) seconds. + """ + expiry_window = int(self.channel_backend.expiry / 2) + for protocol in self.factory.reply_protocols.values(): + if time.time() - protocol.last_keepalive > expiry_window: + protocol.sendKeepalive() + reactor.callLater(1, self.keepalive_sender) diff --git a/channels/interfaces/websocket_asyncio.py b/channels/interfaces/websocket_asyncio.py index 899bfcf..c41871e 100644 --- a/channels/interfaces/websocket_asyncio.py +++ b/channels/interfaces/websocket_asyncio.py @@ -65,7 +65,7 @@ class WebsocketAsyncioInterface(object): (channel_backend expiry / 2) seconds. """ expiry_window = int(self.channel_backend.expiry / 2) - for protocol in self.factory.protocols.values(): + for protocol in self.factory.reply_protocols.values(): if time.time() - protocol.last_keepalive > expiry_window: protocol.sendKeepalive() if self.loop.is_running(): diff --git a/channels/interfaces/websocket_autobahn.py b/channels/interfaces/websocket_autobahn.py index eff1bc6..cd3dd3e 100644 --- a/channels/interfaces/websocket_autobahn.py +++ b/channels/interfaces/websocket_autobahn.py @@ -26,7 +26,7 @@ def get_protocol(base): self.reply_channel = Channel.new_name("!websocket.send") self.request_info["reply_channel"] = self.reply_channel self.last_keepalive = time.time() - self.factory.protocols[self.reply_channel] = self + self.factory.reply_protocols[self.reply_channel] = self # Send news that this channel is open Channel("websocket.connect").send(self.request_info) @@ -61,7 +61,7 @@ def get_protocol(base): def onClose(self, wasClean, code, reason): if hasattr(self, "reply_channel"): - del self.factory.protocols[self.reply_channel] + del self.factory.reply_protocols[self.reply_channel] Channel("websocket.disconnect").send({ "reply_channel": self.reply_channel, }) @@ -90,15 +90,15 @@ def get_factory(base): def __init__(self, *args, **kwargs): super(InterfaceFactory, self).__init__(*args, **kwargs) - self.protocols = {} + self.reply_protocols = {} def reply_channels(self): - return self.protocols.keys() + return self.reply_protocols.keys() def dispatch_send(self, channel, message): if message.get("content", None): - self.protocols[channel].serverSend(**message) + self.reply_protocols[channel].serverSend(**message) if message.get("close", False): - self.protocols[channel].serverClose() + self.reply_protocols[channel].serverClose() return InterfaceFactory diff --git a/channels/interfaces/websocket_twisted.py b/channels/interfaces/websocket_twisted.py index e8122a7..f8a4c4b 100644 --- a/channels/interfaces/websocket_twisted.py +++ b/channels/interfaces/websocket_twisted.py @@ -55,7 +55,7 @@ class WebsocketTwistedInterface(object): (channel_backend expiry / 2) seconds. """ expiry_window = int(self.channel_backend.expiry / 2) - for protocol in self.factory.protocols.values(): + for protocol in self.factory.reply_protocols.values(): if time.time() - protocol.last_keepalive > expiry_window: protocol.sendKeepalive() reactor.callLater(1, self.keepalive_sender) diff --git a/channels/management/commands/runallserver.py b/channels/management/commands/runallserver.py new file mode 100644 index 0000000..1762be5 --- /dev/null +++ b/channels/management/commands/runallserver.py @@ -0,0 +1,26 @@ +import time +from django.core.management import BaseCommand, CommandError +from channels import channel_backends, DEFAULT_CHANNEL_BACKEND + + +class Command(BaseCommand): + + def add_arguments(self, parser): + parser.add_argument('port', nargs='?', + help='Optional port number') + + def handle(self, *args, **options): + # Get the backend to use + channel_backend = channel_backends[DEFAULT_CHANNEL_BACKEND] + if channel_backend.local_only: + raise CommandError( + "You have a process-local channel backend configured, and so cannot run separate interface servers.\n" + "Configure a network-based backend in CHANNEL_BACKENDS to use this command." + ) + # Run the interface + port = int(options.get("port", None) or 8000) + from channels.interfaces.http_twisted import HttpTwistedInterface + self.stdout.write("Running twisted/Autobahn HTTP & WebSocket interface server") + self.stdout.write(" Channel backend: %s" % channel_backend) + self.stdout.write(" Listening on: 0.0.0.0:%i" % port) + HttpTwistedInterface(channel_backend=channel_backend, port=port).run() diff --git a/channels/request.py b/channels/request.py index 0c5de55..01be5b5 100644 --- a/channels/request.py +++ b/channels/request.py @@ -20,6 +20,14 @@ def encode_request(request): "path": request.path, "method": request.method, "reply_channel": request.reply_channel, + "server": [ + request.META.get("SERVER_NAME", None), + request.META.get("SERVER_PORT", None), + ], + "client": [ + request.META.get("REMOTE_ADDR", None), + request.META.get("REMOTE_PORT", None), + ], } return value diff --git a/docs/message-standards.rst b/docs/message-standards.rst index 665ac2c..44d4852 100644 --- a/docs/message-standards.rst +++ b/docs/message-standards.rst @@ -33,8 +33,8 @@ Standard channel name is ``http.request``. Contains the following keys: -* get: List of (key, value) tuples of GET variables (keys and values are strings) -* post: List of (key, value) tuples of POST variables (keys and values are strings) +* get: Dict of {key: [value, ...]} of GET variables (keys and values are strings) +* post: Dict of {key: [value, ...]} of POST variables (keys and values are strings) * cookies: Dict of cookies as {cookie_name: cookie_value} (names and values are strings) * meta: Dict of HTTP headers and info as defined in the Django Request docs (names and values are strings) * path: String, full path to the requested page, without query string or domain