From 854f25b7bd56032c535a52937d7e0cfe64ff19e0 Mon Sep 17 00:00:00 2001 From: Florent D'halluin Date: Sun, 3 Apr 2016 16:17:04 +0200 Subject: [PATCH 1/5] initial proof of concept, works on simple GET request (no streaming body, no streaming reply, no SSL ) --- daphne/cli.py | 14 +++ daphne/http2_protocol.py | 194 +++++++++++++++++++++++++++++++++++++++ daphne/server.py | 6 +- daphne/tests/test_h2.py | 83 +++++++++++++++++ 4 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 daphne/http2_protocol.py create mode 100644 daphne/tests/test_h2.py diff --git a/daphne/cli.py b/daphne/cli.py index 8763c87..415f547 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -5,6 +5,8 @@ import importlib from .server import Server from .access import AccessLogGenerator +from .http_protocol import HTTPFactory +from .http2_protocol import H2Factory logger = logging.getLogger(__name__) @@ -66,6 +68,12 @@ class CommandLineInterface(object): help='The number of seconds a WebSocket must be idle before a keepalive ping is sent', default=20, ) + self.parser.add_argument( + '--h2', + action='store_true', + help="enable HTTP/2" + ) + self.parser.add_argument( 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', @@ -115,8 +123,14 @@ class CommandLineInterface(object): (args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)), args.channel_layer, ) + if args.h2 : + factory_class = H2Factory + else : + factory_class = HTTPFactory + Server( channel_layer=channel_layer, + factory_class=factory_class, host=args.host, port=args.port, unix_socket=args.unix_socket, diff --git a/daphne/http2_protocol.py b/daphne/http2_protocol.py new file mode 100644 index 0000000..4be6faf --- /dev/null +++ b/daphne/http2_protocol.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +import functools + +from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.protocol import Protocol, Factory +from twisted.internet import endpoints + +from h2.connection import H2Connection +from h2.events import ( + RequestReceived, DataReceived, WindowUpdated +) +import time + + +def close_file(file, d): + file.close() + + +READ_CHUNK_SIZE = 8192 + +class H2Request(object): + def __init__(self, protocol, id, reply_channel, body_channel=None) : + self.protocol = protocol + self.stream_id = id + self.start_time = time.time() + self.reply_channel = reply_channel + self.body_channel = body_channel + + def serverResponse(self, message ): + print(message) + self.protocol.makeResponse(self.stream_id, message) + del self.protocol.factory.reply_protocols[self.reply_channel] + + def DataReceived(self, data) : + """ chunk of body """ + self.protocol.factory.channel_layer.send(self.body_channel, { + content: data, + closed: False, # send a True to signal interruption of requests + more_content: False, + }) + + def duration(self): + return time.time() - self.start_time + + def basic_error(self): + pass + +class H2Protocol(Protocol): + def __init__(self, factory): + self.conn = H2Connection(client_side=False) + self.factory = factory + self.known_proto = None + #self.root = root + self.requests = {} # ongoing requests + self._flow_control_deferreds = {} + + def connectionMade(self): + self.conn.initiate_connection() + self.transport.write(self.conn.data_to_send()) + + def dataReceived(self, data): + if not self.known_proto: + self.known_proto = True + + events = self.conn.receive_data(data) + if self.conn.data_to_send: + self.transport.write(self.conn.data_to_send()) + + for event in events: + if isinstance(event, RequestReceived): + self.requestReceived(event.headers, event.stream_id) + elif isinstance(event, DataReceived): + self.dataFrameReceived(event.stream_id, event.data) + #elif isinstance(event, WindowUpdated): + # self.windowUpdated(event) + + def makeResponse(self, stream_id, message) : + + response_headers = [ + (':status', str(message["status"])), + ('server', 'twisted-h2'), + ("status_text", message.get("status_text", "")), + ] + for header, value in message.get("headers", []) : + response_headers.append((header, value)) + + self.conn.send_headers(stream_id, response_headers) + self.transport.write(self.conn.data_to_send()) + + # write content .. Chnk this !! + self.conn.send_data(stream_id, message["content"], True) + self.transport.write(self.conn.data_to_send()) + + + + def requestReceived(self, headers, stream_id): + headers = dict(headers) # Invalid conversion, fix later. + + reply_channel = self.factory.channel_layer.new_channel("http.response!") + + # how do we know if there's a pending body ?? + # body_channel = self.factory.channel_layer.new_channel("http.request.body!") + req = H2Request(self, stream_id, reply_channel, None) + + self.requests[stream_id] = req + self.factory.reply_protocols[reply_channel] = req + + path = headers[':path'] + query_string = b"" + if "?" in path: # h2 makes path a unicode + path, query_string = path.encode().split(b"?", 1) + + self.factory.channel_layer.send("http.request", { + "reply_channel": reply_channel, + "http_version": "2.0", # \o/ + "scheme": "http", # should be read from env/proxys headers ?? + "method" : headers[':method'], + "path" : path, # asgi expects these as bytes + "query_string" : query_string, + "headers": headers, + "body": b"", # this is populated on DataReceived event + "client": [self.transport.getHost().host, self.transport.getHost().port], + }) + + + + def dataFrameReceived(self, stream_id, data): + self.requests[stream_id].dataReceived(data) + + + + +class H2Factory(Factory): + + def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20): + self.channel_layer = channel_layer + self.action_logger = action_logger + self.timeout = timeout + self.websocket_timeout = websocket_timeout + self.ping_interval = ping_interval + # We track all sub-protocols for response channel mapping + self.reply_protocols = {} + # Make a factory for WebSocket protocols + # self.ws_factory = WebSocketFactory(self) + # self.ws_factory.protocol = WebSocketProtocol + # self.ws_factory.reply_protocols = self.reply_protocols + + def buildProtocol(self, addr): + return H2Protocol(self) + + + # copy pasta from http_protocol + def dispatch_reply(self, channel, message): + if channel.startswith("http") and isinstance(self.reply_protocols[channel], H2Request): + self.reply_protocols[channel].serverResponse(message) + # elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol): + # if message.get("bytes", None): + # self.reply_protocols[channel].serverSend(message["bytes"], True) + # if message.get("text", None): + # self.reply_protocols[channel].serverSend(message["text"], False) + # if message.get("close", False): + # self.reply_protocols[channel].serverClose() + else: + raise ValueError("Cannot dispatch message on channel %r" % channel) + + # copy pasta from http protocol + def reply_channels(self): + return self.reply_protocols.keys() + + + def log_action(self, protocol, action, details): + """ + Dispatches to any registered action logger, if there is one. + """ + if self.action_logger: + self.action_logger(protocol, action, details) + + def check_timeouts(self): + """ + Runs through all HTTP protocol instances and times them out if they've + taken too long (and so their message is probably expired) + """ + for protocol in list(self.reply_protocols.values()): + # Web timeout checking + if isinstance(protocol, H2Request) and protocol.duration() > self.timeout: + protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") + # WebSocket timeout checking and keepalive ping sending + #elif isinstance(protocol, WebSocketProtocol): + # Timeout check + # if protocol.duration() > self.websocket_timeout: + # protocol.serverClose() + # Ping check + # else: + # protocol.check_ping() diff --git a/daphne/server.py b/daphne/server.py index 90bc6f2..1188f66 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -11,6 +11,7 @@ class Server(object): def __init__( self, channel_layer, + factory_class, # that's a factory factory ... meeeh host="127.0.0.1", port=8000, unix_socket=None, @@ -31,9 +32,10 @@ class Server(object): # If they did not provide a websocket timeout, default it to the # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) - + self.factory_class = factory_class + def run(self): - self.factory = HTTPFactory( + self.factory = self.factory_class( self.channel_layer, self.action_logger, timeout=self.http_timeout, diff --git a/daphne/tests/test_h2.py b/daphne/tests/test_h2.py new file mode 100644 index 0000000..ee93a30 --- /dev/null +++ b/daphne/tests/test_h2.py @@ -0,0 +1,83 @@ +from unittest import TestCase +from asgiref.inmemory import ChannelLayer +from twisted.test import proto_helpers + +from ..http2_protocol import H2Factory +from h2.connection import H2Connection +import h2.events + +class TestH2Protocol(TestCase): + """ + Tests that the HTTP protocol class correctly generates and parses messages. + """ + + def setUp(self): + self.channel_layer = ChannelLayer() + self.factory = H2Factory(self.channel_layer) + self.proto = self.factory.buildProtocol(('127.0.0.1', 0)) + self.tr = proto_helpers.StringTransport() + self.proto.makeConnection(self.tr) + + + def assertStartsWith(self, data, prefix): + real_prefix = data[:len(prefix)] + self.assertEqual(real_prefix, prefix) + + def test_basic(self): + """ + Tests basic HTTP parsing + """ + # Send a simple request to the protocol + + conn = H2Connection() + conn.initiate_connection() + #self.tr.write(conn.data_to_send()) + self.proto.dataReceived(conn.data_to_send()) + conn.send_headers(1, [ + (':method', 'GET'), + (':path', '/test/?foo=bar'), + ('user-agent', 'hyper-h2/yo'), + ], end_stream=True) + self.proto.dataReceived(conn.data_to_send()) + + _, message = self.channel_layer.receive_many(["http.request"]) + self.assertEqual(message['http_version'], "2.0") + self.assertEqual(message['method'], "GET") + self.assertEqual(message['scheme'], "http") + self.assertEqual(message['path'], b"/test/") + self.assertEqual(message['query_string'], b"foo=bar") + # self.assertEqual(message['headers'], [(b"user-agent", b"hyper-h2/yo")]) + self.assertFalse(message.get("body", None)) + self.assertTrue(message['reply_channel']) + + # Send back an example response + self.factory.dispatch_reply( + message['reply_channel'], + { + "status": 201, + "status_text": b"Created", + "content": b"OH HAI", + "headers": [[b"X-Test", b"Boom!"]], + } + ) + # Make sure that comes back right on the protocol + data = self.tr.value() + evs = conn.receive_data(data) + # we should see a ResponseReceived + hasResponse = False + for e in evs : + if isinstance(e, h2.events.ResponseReceived) : + hasResponse = True + headers = dict(e.headers) + self.assertEqual(headers["x-test"],"Boom!") + self.assertEqual(headers[":status"], "201") + + if isinstance(e, h2.events.DataReceived): + self.assertEqual(e.data, b"OH HAI") + self.assertTrue(hasResponse) + + # a DataReceived + # a StreamEnded ?? + + # self.assertEqual(evs) + # self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n") From d259d8e57f0ba0ce92ca93409c3f3a5fd3caee47 Mon Sep 17 00:00:00 2001 From: Florent D'halluin Date: Mon, 4 Apr 2016 22:18:58 +0200 Subject: [PATCH 2/5] add h2 deps on setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 5a58390..2be72dd 100755 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ setup( 'asgiref>=0.10', 'twisted>=15.5', 'autobahn>=0.12', + 'h2>=2.2' ], entry_points={'console_scripts': [ 'daphne = daphne.cli:CommandLineInterface.entrypoint', From 48889827ea53d69266ea96b4b0e75a4ad9d13ccb Mon Sep 17 00:00:00 2001 From: Florent D'halluin Date: Mon, 4 Apr 2016 22:19:17 +0200 Subject: [PATCH 3/5] streaming body for both req and response --- daphne/http2_protocol.py | 157 ++++++++++++++++++++++++++++++--------- 1 file changed, 121 insertions(+), 36 deletions(-) diff --git a/daphne/http2_protocol.py b/daphne/http2_protocol.py index 4be6faf..449ed43 100644 --- a/daphne/http2_protocol.py +++ b/daphne/http2_protocol.py @@ -25,19 +25,70 @@ class H2Request(object): self.start_time = time.time() self.reply_channel = reply_channel self.body_channel = body_channel + self.response_started = False + self.headers = {} + self._header_sent = False # have header message been sent to channel layer ? + + def setHeaders(self, headers) : + self.headers = headers + self.body = b"" + + def sendHeaders(self): + + path = self.headers[':path'] + query_string = b"" + if "?" in path: # h2 makes path a unicode + path, query_string = path.encode().split(b"?", 1) + + # clean up ':' prefixed headers + headers_ = {} + for k,v in self.headers.items() : + if not k.startswith(':'): + headers_[k] = v + + # not post : wait for body before sending message + self.protocol.factory.channel_layer.send("http.request", { + "reply_channel": self.reply_channel, + "http_version": "2.0", # \o/ + "scheme": "http", # should be read from env/proxys headers ?? + "method" : self.headers[':method'], + "path" : path, # asgi expects these as bytes + "query_string" : query_string, + "headers": headers_, + "body": self.body, # this is populated on DataReceived event + "client": [self.protocol.transport.getHost().host, + self.protocol.transport.getHost().port], + }) + + self._header_send = True def serverResponse(self, message ): - print(message) - self.protocol.makeResponse(self.stream_id, message) - del self.protocol.factory.reply_protocols[self.reply_channel] + if "status" in message : + assert(not self.response_started) + self.response_started = True + self.protocol.makeResponse(self.stream_id, message) + # only if we are done + else : + assert(self.response_started) + self.protocol.sendData(self.stream_id, + message["content"], + message["more_content"]) - def DataReceived(self, data) : - """ chunk of body """ - self.protocol.factory.channel_layer.send(self.body_channel, { - content: data, - closed: False, # send a True to signal interruption of requests - more_content: False, - }) + if(not message.get("more_content", False)) : + del self.protocol.factory.reply_protocols[self.reply_channel] + + + + def dataReceived(self, data) : + """ chunk of body received """ + if(self._header_sent and self.body_channel) : + self.protocol.factory.channel_layer.send(self.body_channel, { + "content": data, + "closed": False, # send a True to signal interruption of requests + "more_content": False, # we just can't know that .. + }) + else : + print("Barf!") def duration(self): return time.time() - self.start_time @@ -71,11 +122,11 @@ class H2Protocol(Protocol): self.requestReceived(event.headers, event.stream_id) elif isinstance(event, DataReceived): self.dataFrameReceived(event.stream_id, event.data) - #elif isinstance(event, WindowUpdated): - # self.windowUpdated(event) + elif isinstance(event, WindowUpdated): + self.windowUpdated(event) def makeResponse(self, stream_id, message) : - + print("responding", message) response_headers = [ (':status', str(message["status"])), ('server', 'twisted-h2'), @@ -88,9 +139,11 @@ class H2Protocol(Protocol): self.transport.write(self.conn.data_to_send()) # write content .. Chnk this !! - self.conn.send_data(stream_id, message["content"], True) - self.transport.write(self.conn.data_to_send()) - + more_content = message.get('more_content', False) + # that's a twisted deferred, if you don't add a call back, + # this gets discarded + d = self.sendData(stream_id, message["content"], more_content) + d.addErrback(lambda e: print("error in send data", e)) def requestReceived(self, headers, stream_id): @@ -98,36 +151,68 @@ class H2Protocol(Protocol): reply_channel = self.factory.channel_layer.new_channel("http.response!") - # how do we know if there's a pending body ?? - # body_channel = self.factory.channel_layer.new_channel("http.request.body!") - req = H2Request(self, stream_id, reply_channel, None) + body_channel = None + if(headers[':method'] == 'POST'): + body_channel = self.factory.channel_layer.new_channel("http.request.body!") + # body_channel = + req = H2Request(self, stream_id, reply_channel, body_channel) + req.setHeaders(headers) self.requests[stream_id] = req self.factory.reply_protocols[reply_channel] = req - path = headers[':path'] - query_string = b"" - if "?" in path: # h2 makes path a unicode - path, query_string = path.encode().split(b"?", 1) - - self.factory.channel_layer.send("http.request", { - "reply_channel": reply_channel, - "http_version": "2.0", # \o/ - "scheme": "http", # should be read from env/proxys headers ?? - "method" : headers[':method'], - "path" : path, # asgi expects these as bytes - "query_string" : query_string, - "headers": headers, - "body": b"", # this is populated on DataReceived event - "client": [self.transport.getHost().host, self.transport.getHost().port], - }) + # send the request to channel layer, or wait for body + req.sendHeaders() + @inlineCallbacks + def sendData(self, stream_id, data, more_content=False): + # chunks and enqueue data + send_more = True + msg_size = len(data) + offset = 0 + while send_more : + print("waigint for flow control") + while not self.conn.remote_flow_control_window(stream_id) : + # do we have a flow window ? + yield self.wait_for_flow_control(stream_id) + + chunk_size = min(self.conn.remote_flow_control_window(stream_id),READ_CHUNK_SIZE) + + # hopefully, both are bigger than message data + if (msg_size - offset) < chunk_size : + send_more = False + end_chunk = offset + chunk_size + 1 + else : + end_chunk = msg_size + 1 + + chunk = data[offset:end_chunk] + # if more_content, keep request active + done = not ( send_more or more_content) + self.conn.send_data(stream_id, chunk, done) + self.transport.write(self.conn.data_to_send()) + + + def wait_for_flow_control(self, stream_id): + d = Deferred() + self._flow_control_deferreds[stream_id] = d + return d def dataFrameReceived(self, stream_id, data): self.requests[stream_id].dataReceived(data) - + def windowUpdated(self, event): + stream_id = event.stream_id + print("window flow ctrl", stream_id) + if stream_id and stream_id in self._flow_control_deferreds: + d = self._flow_control_deferreds.pop(stream_id) + d.callback(event.delta) + elif not stream_id: + # fire them all.. + for d in self._flow_control_deferreds.values(): + d.callback(event.delta) + self._flow_control_deferreds = {} + return class H2Factory(Factory): From 6e3c69eaf731279382aa39117f79fdb1bfbbf878 Mon Sep 17 00:00:00 2001 From: Florent D'halluin Date: Tue, 5 Apr 2016 00:13:30 +0200 Subject: [PATCH 4/5] add ssl support (required for any browser use of h2) --- daphne/cli.py | 17 +++++++++++++++++ daphne/server.py | 33 +++++++++++++++++++++++++++------ setup.py | 3 ++- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/daphne/cli.py b/daphne/cli.py index 415f547..d075286 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -74,11 +74,25 @@ class CommandLineInterface(object): help="enable HTTP/2" ) + self.parser.add_argument( + '--sslcert', + action="store", + help="path to ssl certificate file" + ) + + self.parser.add_argument( + '--sslkey', + action="store", + help="path to ssl private key file" + ) + + self.parser.add_argument( 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', ) + @classmethod def entrypoint(cls): """ @@ -137,4 +151,7 @@ class CommandLineInterface(object): http_timeout=args.http_timeout, ping_interval=args.ping_interval, action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, + ssl_certificate=args.sslcert, + ssl_key=args.sslkey + ).run() diff --git a/daphne/server.py b/daphne/server.py index 1188f66..bac05b9 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -1,5 +1,6 @@ import logging -from twisted.internet import reactor +from twisted.internet import reactor, ssl, endpoints +from OpenSSL import crypto from .http_protocol import HTTPFactory @@ -20,6 +21,8 @@ class Server(object): http_timeout=120, websocket_timeout=None, ping_interval=20, + ssl_certificate = None, + ssl_key = None ): self.channel_layer = channel_layer self.host = host @@ -33,7 +36,9 @@ class Server(object): # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) self.factory_class = factory_class - + self.ssl_certificate = ssl_certificate + self.ssl_key = ssl_key + def run(self): self.factory = self.factory_class( self.channel_layer, @@ -42,10 +47,26 @@ class Server(object): websocket_timeout=self.websocket_timeout, ping_interval=self.ping_interval, ) - if self.unix_socket: - reactor.listenUNIX(self.unix_socket, self.factory) - else: - reactor.listenTCP(self.port, self.factory, interface=self.host) + + if self.ssl_certificate : + with open(self.ssl_certificate, 'r') as f: + cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) + with open(self.ssl_key, 'r') as f: + key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) + + opts = ssl.CertificateOptions( + privateKey= key, + certificate=cert, + acceptableProtocols=[b'h2'] + ) + + endpt = endpoints.SSL4ServerEndpoint(reactor, self.port, opts, backlog=128) + endpt.listen(self.factory) + else : + if self.unix_socket: + reactor.listenUNIX(self.unix_socket, self.factory) + else: + reactor.listenTCP(self.port, self.factory, interface=self.host) reactor.callLater(0, self.backend_reader) reactor.callLater(2, self.timeout_checker) reactor.run(installSignalHandlers=self.signal_handlers) diff --git a/setup.py b/setup.py index 2be72dd..eddd2fb 100755 --- a/setup.py +++ b/setup.py @@ -24,7 +24,8 @@ setup( 'asgiref>=0.10', 'twisted>=15.5', 'autobahn>=0.12', - 'h2>=2.2' + 'h2>=2.2', + 'pyOpenSSL' # optionnal ?? ], entry_points={'console_scripts': [ 'daphne = daphne.cli:CommandLineInterface.entrypoint', From 0c9cd481f624f7bc6d7581c6d2295daea0d4db02 Mon Sep 17 00:00:00 2001 From: Florent D'halluin Date: Tue, 5 Apr 2016 21:26:44 +0200 Subject: [PATCH 5/5] make travis install h2, and optionnal import PyOpenSSL --- .travis.yml | 2 +- daphne/server.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6f34248..3c6acac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,5 +5,5 @@ python: - "3.5" install: - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install unittest2; fi - - pip install asgiref twisted autobahn + - pip install asgiref twisted autobahn h2 script: if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then python -m unittest2; else python -m unittest; fi diff --git a/daphne/server.py b/daphne/server.py index bac05b9..7eb321e 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -1,6 +1,5 @@ import logging from twisted.internet import reactor, ssl, endpoints -from OpenSSL import crypto from .http_protocol import HTTPFactory @@ -49,6 +48,9 @@ class Server(object): ) if self.ssl_certificate : + + from OpenSSL import crypto + with open(self.ssl_certificate, 'r') as f: cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) with open(self.ssl_key, 'r') as f: