diff --git a/.travis.yml b/.travis.yml index 6f34248..3c6acac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,5 +5,5 @@ python: - "3.5" install: - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install unittest2; fi - - pip install asgiref twisted autobahn + - pip install asgiref twisted autobahn h2 script: if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then python -m unittest2; else python -m unittest; fi diff --git a/daphne/cli.py b/daphne/cli.py index 8763c87..d075286 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -5,6 +5,8 @@ import importlib from .server import Server from .access import AccessLogGenerator +from .http_protocol import HTTPFactory +from .http2_protocol import H2Factory logger = logging.getLogger(__name__) @@ -66,11 +68,31 @@ class CommandLineInterface(object): help='The number of seconds a WebSocket must be idle before a keepalive ping is sent', default=20, ) + self.parser.add_argument( + '--h2', + action='store_true', + help="enable HTTP/2" + ) + + self.parser.add_argument( + '--sslcert', + action="store", + help="path to ssl certificate file" + ) + + self.parser.add_argument( + '--sslkey', + action="store", + help="path to ssl private key file" + ) + + self.parser.add_argument( 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', ) + @classmethod def entrypoint(cls): """ @@ -115,12 +137,21 @@ class CommandLineInterface(object): (args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)), args.channel_layer, ) + if args.h2 : + factory_class = H2Factory + else : + factory_class = HTTPFactory + Server( channel_layer=channel_layer, + factory_class=factory_class, host=args.host, port=args.port, unix_socket=args.unix_socket, http_timeout=args.http_timeout, ping_interval=args.ping_interval, action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, + ssl_certificate=args.sslcert, + ssl_key=args.sslkey + ).run() diff --git a/daphne/http2_protocol.py b/daphne/http2_protocol.py new file mode 100644 index 0000000..449ed43 --- /dev/null +++ b/daphne/http2_protocol.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +import functools + +from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.protocol import Protocol, Factory +from twisted.internet import endpoints + +from h2.connection import H2Connection +from h2.events import ( + RequestReceived, DataReceived, WindowUpdated +) +import time + + +def close_file(file, d): + file.close() + + +READ_CHUNK_SIZE = 8192 + +class H2Request(object): + def __init__(self, protocol, id, reply_channel, body_channel=None) : + self.protocol = protocol + self.stream_id = id + self.start_time = time.time() + self.reply_channel = reply_channel + self.body_channel = body_channel + self.response_started = False + self.headers = {} + self._header_sent = False # have header message been sent to channel layer ? + + def setHeaders(self, headers) : + self.headers = headers + self.body = b"" + + def sendHeaders(self): + + path = self.headers[':path'] + query_string = b"" + if "?" in path: # h2 makes path a unicode + path, query_string = path.encode().split(b"?", 1) + + # clean up ':' prefixed headers + headers_ = {} + for k,v in self.headers.items() : + if not k.startswith(':'): + headers_[k] = v + + # not post : wait for body before sending message + self.protocol.factory.channel_layer.send("http.request", { + "reply_channel": self.reply_channel, + "http_version": "2.0", # \o/ + "scheme": "http", # should be read from env/proxys headers ?? + "method" : self.headers[':method'], + "path" : path, # asgi expects these as bytes + "query_string" : query_string, + "headers": headers_, + "body": self.body, # this is populated on DataReceived event + "client": [self.protocol.transport.getHost().host, + self.protocol.transport.getHost().port], + }) + + self._header_send = True + + def serverResponse(self, message ): + if "status" in message : + assert(not self.response_started) + self.response_started = True + self.protocol.makeResponse(self.stream_id, message) + # only if we are done + else : + assert(self.response_started) + self.protocol.sendData(self.stream_id, + message["content"], + message["more_content"]) + + if(not message.get("more_content", False)) : + del self.protocol.factory.reply_protocols[self.reply_channel] + + + + def dataReceived(self, data) : + """ chunk of body received """ + if(self._header_sent and self.body_channel) : + self.protocol.factory.channel_layer.send(self.body_channel, { + "content": data, + "closed": False, # send a True to signal interruption of requests + "more_content": False, # we just can't know that .. + }) + else : + print("Barf!") + + def duration(self): + return time.time() - self.start_time + + def basic_error(self): + pass + +class H2Protocol(Protocol): + def __init__(self, factory): + self.conn = H2Connection(client_side=False) + self.factory = factory + self.known_proto = None + #self.root = root + self.requests = {} # ongoing requests + self._flow_control_deferreds = {} + + def connectionMade(self): + self.conn.initiate_connection() + self.transport.write(self.conn.data_to_send()) + + def dataReceived(self, data): + if not self.known_proto: + self.known_proto = True + + events = self.conn.receive_data(data) + if self.conn.data_to_send: + self.transport.write(self.conn.data_to_send()) + + for event in events: + if isinstance(event, RequestReceived): + self.requestReceived(event.headers, event.stream_id) + elif isinstance(event, DataReceived): + self.dataFrameReceived(event.stream_id, event.data) + elif isinstance(event, WindowUpdated): + self.windowUpdated(event) + + def makeResponse(self, stream_id, message) : + print("responding", message) + response_headers = [ + (':status', str(message["status"])), + ('server', 'twisted-h2'), + ("status_text", message.get("status_text", "")), + ] + for header, value in message.get("headers", []) : + response_headers.append((header, value)) + + self.conn.send_headers(stream_id, response_headers) + self.transport.write(self.conn.data_to_send()) + + # write content .. Chnk this !! + more_content = message.get('more_content', False) + # that's a twisted deferred, if you don't add a call back, + # this gets discarded + d = self.sendData(stream_id, message["content"], more_content) + d.addErrback(lambda e: print("error in send data", e)) + + + def requestReceived(self, headers, stream_id): + headers = dict(headers) # Invalid conversion, fix later. + + reply_channel = self.factory.channel_layer.new_channel("http.response!") + + body_channel = None + if(headers[':method'] == 'POST'): + body_channel = self.factory.channel_layer.new_channel("http.request.body!") + # body_channel = + req = H2Request(self, stream_id, reply_channel, body_channel) + req.setHeaders(headers) + + self.requests[stream_id] = req + self.factory.reply_protocols[reply_channel] = req + + # send the request to channel layer, or wait for body + req.sendHeaders() + + + @inlineCallbacks + def sendData(self, stream_id, data, more_content=False): + # chunks and enqueue data + send_more = True + msg_size = len(data) + offset = 0 + while send_more : + print("waigint for flow control") + while not self.conn.remote_flow_control_window(stream_id) : + # do we have a flow window ? + yield self.wait_for_flow_control(stream_id) + + chunk_size = min(self.conn.remote_flow_control_window(stream_id),READ_CHUNK_SIZE) + + # hopefully, both are bigger than message data + if (msg_size - offset) < chunk_size : + send_more = False + end_chunk = offset + chunk_size + 1 + else : + end_chunk = msg_size + 1 + + chunk = data[offset:end_chunk] + # if more_content, keep request active + done = not ( send_more or more_content) + self.conn.send_data(stream_id, chunk, done) + self.transport.write(self.conn.data_to_send()) + + + def wait_for_flow_control(self, stream_id): + d = Deferred() + self._flow_control_deferreds[stream_id] = d + return d + + def dataFrameReceived(self, stream_id, data): + self.requests[stream_id].dataReceived(data) + + def windowUpdated(self, event): + stream_id = event.stream_id + print("window flow ctrl", stream_id) + if stream_id and stream_id in self._flow_control_deferreds: + d = self._flow_control_deferreds.pop(stream_id) + d.callback(event.delta) + elif not stream_id: + # fire them all.. + for d in self._flow_control_deferreds.values(): + d.callback(event.delta) + self._flow_control_deferreds = {} + return + + +class H2Factory(Factory): + + def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20): + self.channel_layer = channel_layer + self.action_logger = action_logger + self.timeout = timeout + self.websocket_timeout = websocket_timeout + self.ping_interval = ping_interval + # We track all sub-protocols for response channel mapping + self.reply_protocols = {} + # Make a factory for WebSocket protocols + # self.ws_factory = WebSocketFactory(self) + # self.ws_factory.protocol = WebSocketProtocol + # self.ws_factory.reply_protocols = self.reply_protocols + + def buildProtocol(self, addr): + return H2Protocol(self) + + + # copy pasta from http_protocol + def dispatch_reply(self, channel, message): + if channel.startswith("http") and isinstance(self.reply_protocols[channel], H2Request): + self.reply_protocols[channel].serverResponse(message) + # elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol): + # if message.get("bytes", None): + # self.reply_protocols[channel].serverSend(message["bytes"], True) + # if message.get("text", None): + # self.reply_protocols[channel].serverSend(message["text"], False) + # if message.get("close", False): + # self.reply_protocols[channel].serverClose() + else: + raise ValueError("Cannot dispatch message on channel %r" % channel) + + # copy pasta from http protocol + def reply_channels(self): + return self.reply_protocols.keys() + + + def log_action(self, protocol, action, details): + """ + Dispatches to any registered action logger, if there is one. + """ + if self.action_logger: + self.action_logger(protocol, action, details) + + def check_timeouts(self): + """ + Runs through all HTTP protocol instances and times them out if they've + taken too long (and so their message is probably expired) + """ + for protocol in list(self.reply_protocols.values()): + # Web timeout checking + if isinstance(protocol, H2Request) and protocol.duration() > self.timeout: + protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") + # WebSocket timeout checking and keepalive ping sending + #elif isinstance(protocol, WebSocketProtocol): + # Timeout check + # if protocol.duration() > self.websocket_timeout: + # protocol.serverClose() + # Ping check + # else: + # protocol.check_ping() diff --git a/daphne/server.py b/daphne/server.py index 90bc6f2..7eb321e 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -1,5 +1,5 @@ import logging -from twisted.internet import reactor +from twisted.internet import reactor, ssl, endpoints from .http_protocol import HTTPFactory @@ -11,6 +11,7 @@ class Server(object): def __init__( self, channel_layer, + factory_class, # that's a factory factory ... meeeh host="127.0.0.1", port=8000, unix_socket=None, @@ -19,6 +20,8 @@ class Server(object): http_timeout=120, websocket_timeout=None, ping_interval=20, + ssl_certificate = None, + ssl_key = None ): self.channel_layer = channel_layer self.host = host @@ -31,19 +34,41 @@ class Server(object): # If they did not provide a websocket timeout, default it to the # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) + self.factory_class = factory_class + self.ssl_certificate = ssl_certificate + self.ssl_key = ssl_key def run(self): - self.factory = HTTPFactory( + self.factory = self.factory_class( self.channel_layer, self.action_logger, timeout=self.http_timeout, websocket_timeout=self.websocket_timeout, ping_interval=self.ping_interval, ) - if self.unix_socket: - reactor.listenUNIX(self.unix_socket, self.factory) - else: - reactor.listenTCP(self.port, self.factory, interface=self.host) + + if self.ssl_certificate : + + from OpenSSL import crypto + + with open(self.ssl_certificate, 'r') as f: + cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) + with open(self.ssl_key, 'r') as f: + key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) + + opts = ssl.CertificateOptions( + privateKey= key, + certificate=cert, + acceptableProtocols=[b'h2'] + ) + + endpt = endpoints.SSL4ServerEndpoint(reactor, self.port, opts, backlog=128) + endpt.listen(self.factory) + else : + if self.unix_socket: + reactor.listenUNIX(self.unix_socket, self.factory) + else: + reactor.listenTCP(self.port, self.factory, interface=self.host) reactor.callLater(0, self.backend_reader) reactor.callLater(2, self.timeout_checker) reactor.run(installSignalHandlers=self.signal_handlers) diff --git a/daphne/tests/test_h2.py b/daphne/tests/test_h2.py new file mode 100644 index 0000000..ee93a30 --- /dev/null +++ b/daphne/tests/test_h2.py @@ -0,0 +1,83 @@ +from unittest import TestCase +from asgiref.inmemory import ChannelLayer +from twisted.test import proto_helpers + +from ..http2_protocol import H2Factory +from h2.connection import H2Connection +import h2.events + +class TestH2Protocol(TestCase): + """ + Tests that the HTTP protocol class correctly generates and parses messages. + """ + + def setUp(self): + self.channel_layer = ChannelLayer() + self.factory = H2Factory(self.channel_layer) + self.proto = self.factory.buildProtocol(('127.0.0.1', 0)) + self.tr = proto_helpers.StringTransport() + self.proto.makeConnection(self.tr) + + + def assertStartsWith(self, data, prefix): + real_prefix = data[:len(prefix)] + self.assertEqual(real_prefix, prefix) + + def test_basic(self): + """ + Tests basic HTTP parsing + """ + # Send a simple request to the protocol + + conn = H2Connection() + conn.initiate_connection() + #self.tr.write(conn.data_to_send()) + self.proto.dataReceived(conn.data_to_send()) + conn.send_headers(1, [ + (':method', 'GET'), + (':path', '/test/?foo=bar'), + ('user-agent', 'hyper-h2/yo'), + ], end_stream=True) + self.proto.dataReceived(conn.data_to_send()) + + _, message = self.channel_layer.receive_many(["http.request"]) + self.assertEqual(message['http_version'], "2.0") + self.assertEqual(message['method'], "GET") + self.assertEqual(message['scheme'], "http") + self.assertEqual(message['path'], b"/test/") + self.assertEqual(message['query_string'], b"foo=bar") + # self.assertEqual(message['headers'], [(b"user-agent", b"hyper-h2/yo")]) + self.assertFalse(message.get("body", None)) + self.assertTrue(message['reply_channel']) + + # Send back an example response + self.factory.dispatch_reply( + message['reply_channel'], + { + "status": 201, + "status_text": b"Created", + "content": b"OH HAI", + "headers": [[b"X-Test", b"Boom!"]], + } + ) + # Make sure that comes back right on the protocol + data = self.tr.value() + evs = conn.receive_data(data) + # we should see a ResponseReceived + hasResponse = False + for e in evs : + if isinstance(e, h2.events.ResponseReceived) : + hasResponse = True + headers = dict(e.headers) + self.assertEqual(headers["x-test"],"Boom!") + self.assertEqual(headers[":status"], "201") + + if isinstance(e, h2.events.DataReceived): + self.assertEqual(e.data, b"OH HAI") + self.assertTrue(hasResponse) + + # a DataReceived + # a StreamEnded ?? + + # self.assertEqual(evs) + # self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n") diff --git a/setup.py b/setup.py index 5a58390..eddd2fb 100755 --- a/setup.py +++ b/setup.py @@ -24,6 +24,8 @@ setup( 'asgiref>=0.10', 'twisted>=15.5', 'autobahn>=0.12', + 'h2>=2.2', + 'pyOpenSSL' # optionnal ?? ], entry_points={'console_scripts': [ 'daphne = daphne.cli:CommandLineInterface.entrypoint',