diff --git a/daphne/cli.py b/daphne/cli.py index 8763c87..415f547 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -5,6 +5,8 @@ import importlib from .server import Server from .access import AccessLogGenerator +from .http_protocol import HTTPFactory +from .http2_protocol import H2Factory logger = logging.getLogger(__name__) @@ -66,6 +68,12 @@ class CommandLineInterface(object): help='The number of seconds a WebSocket must be idle before a keepalive ping is sent', default=20, ) + self.parser.add_argument( + '--h2', + action='store_true', + help="enable HTTP/2" + ) + self.parser.add_argument( 'channel_layer', help='The ASGI channel layer instance to use as path.to.module:instance.path', @@ -115,8 +123,14 @@ class CommandLineInterface(object): (args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)), args.channel_layer, ) + if args.h2 : + factory_class = H2Factory + else : + factory_class = HTTPFactory + Server( channel_layer=channel_layer, + factory_class=factory_class, host=args.host, port=args.port, unix_socket=args.unix_socket, diff --git a/daphne/http2_protocol.py b/daphne/http2_protocol.py new file mode 100644 index 0000000..4be6faf --- /dev/null +++ b/daphne/http2_protocol.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +import functools + +from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.protocol import Protocol, Factory +from twisted.internet import endpoints + +from h2.connection import H2Connection +from h2.events import ( + RequestReceived, DataReceived, WindowUpdated +) +import time + + +def close_file(file, d): + file.close() + + +READ_CHUNK_SIZE = 8192 + +class H2Request(object): + def __init__(self, protocol, id, reply_channel, body_channel=None) : + self.protocol = protocol + self.stream_id = id + self.start_time = time.time() + self.reply_channel = reply_channel + self.body_channel = body_channel + + def serverResponse(self, message ): + print(message) + self.protocol.makeResponse(self.stream_id, message) + del self.protocol.factory.reply_protocols[self.reply_channel] + + def DataReceived(self, data) : + """ chunk of body """ + self.protocol.factory.channel_layer.send(self.body_channel, { + content: data, + closed: False, # send a True to signal interruption of requests + more_content: False, + }) + + def duration(self): + return time.time() - self.start_time + + def basic_error(self): + pass + +class H2Protocol(Protocol): + def __init__(self, factory): + self.conn = H2Connection(client_side=False) + self.factory = factory + self.known_proto = None + #self.root = root + self.requests = {} # ongoing requests + self._flow_control_deferreds = {} + + def connectionMade(self): + self.conn.initiate_connection() + self.transport.write(self.conn.data_to_send()) + + def dataReceived(self, data): + if not self.known_proto: + self.known_proto = True + + events = self.conn.receive_data(data) + if self.conn.data_to_send: + self.transport.write(self.conn.data_to_send()) + + for event in events: + if isinstance(event, RequestReceived): + self.requestReceived(event.headers, event.stream_id) + elif isinstance(event, DataReceived): + self.dataFrameReceived(event.stream_id, event.data) + #elif isinstance(event, WindowUpdated): + # self.windowUpdated(event) + + def makeResponse(self, stream_id, message) : + + response_headers = [ + (':status', str(message["status"])), + ('server', 'twisted-h2'), + ("status_text", message.get("status_text", "")), + ] + for header, value in message.get("headers", []) : + response_headers.append((header, value)) + + self.conn.send_headers(stream_id, response_headers) + self.transport.write(self.conn.data_to_send()) + + # write content .. Chnk this !! + self.conn.send_data(stream_id, message["content"], True) + self.transport.write(self.conn.data_to_send()) + + + + def requestReceived(self, headers, stream_id): + headers = dict(headers) # Invalid conversion, fix later. + + reply_channel = self.factory.channel_layer.new_channel("http.response!") + + # how do we know if there's a pending body ?? + # body_channel = self.factory.channel_layer.new_channel("http.request.body!") + req = H2Request(self, stream_id, reply_channel, None) + + self.requests[stream_id] = req + self.factory.reply_protocols[reply_channel] = req + + path = headers[':path'] + query_string = b"" + if "?" in path: # h2 makes path a unicode + path, query_string = path.encode().split(b"?", 1) + + self.factory.channel_layer.send("http.request", { + "reply_channel": reply_channel, + "http_version": "2.0", # \o/ + "scheme": "http", # should be read from env/proxys headers ?? + "method" : headers[':method'], + "path" : path, # asgi expects these as bytes + "query_string" : query_string, + "headers": headers, + "body": b"", # this is populated on DataReceived event + "client": [self.transport.getHost().host, self.transport.getHost().port], + }) + + + + def dataFrameReceived(self, stream_id, data): + self.requests[stream_id].dataReceived(data) + + + + +class H2Factory(Factory): + + def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20): + self.channel_layer = channel_layer + self.action_logger = action_logger + self.timeout = timeout + self.websocket_timeout = websocket_timeout + self.ping_interval = ping_interval + # We track all sub-protocols for response channel mapping + self.reply_protocols = {} + # Make a factory for WebSocket protocols + # self.ws_factory = WebSocketFactory(self) + # self.ws_factory.protocol = WebSocketProtocol + # self.ws_factory.reply_protocols = self.reply_protocols + + def buildProtocol(self, addr): + return H2Protocol(self) + + + # copy pasta from http_protocol + def dispatch_reply(self, channel, message): + if channel.startswith("http") and isinstance(self.reply_protocols[channel], H2Request): + self.reply_protocols[channel].serverResponse(message) + # elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol): + # if message.get("bytes", None): + # self.reply_protocols[channel].serverSend(message["bytes"], True) + # if message.get("text", None): + # self.reply_protocols[channel].serverSend(message["text"], False) + # if message.get("close", False): + # self.reply_protocols[channel].serverClose() + else: + raise ValueError("Cannot dispatch message on channel %r" % channel) + + # copy pasta from http protocol + def reply_channels(self): + return self.reply_protocols.keys() + + + def log_action(self, protocol, action, details): + """ + Dispatches to any registered action logger, if there is one. + """ + if self.action_logger: + self.action_logger(protocol, action, details) + + def check_timeouts(self): + """ + Runs through all HTTP protocol instances and times them out if they've + taken too long (and so their message is probably expired) + """ + for protocol in list(self.reply_protocols.values()): + # Web timeout checking + if isinstance(protocol, H2Request) and protocol.duration() > self.timeout: + protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.") + # WebSocket timeout checking and keepalive ping sending + #elif isinstance(protocol, WebSocketProtocol): + # Timeout check + # if protocol.duration() > self.websocket_timeout: + # protocol.serverClose() + # Ping check + # else: + # protocol.check_ping() diff --git a/daphne/server.py b/daphne/server.py index 90bc6f2..1188f66 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -11,6 +11,7 @@ class Server(object): def __init__( self, channel_layer, + factory_class, # that's a factory factory ... meeeh host="127.0.0.1", port=8000, unix_socket=None, @@ -31,9 +32,10 @@ class Server(object): # If they did not provide a websocket timeout, default it to the # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) - + self.factory_class = factory_class + def run(self): - self.factory = HTTPFactory( + self.factory = self.factory_class( self.channel_layer, self.action_logger, timeout=self.http_timeout, diff --git a/daphne/tests/test_h2.py b/daphne/tests/test_h2.py new file mode 100644 index 0000000..ee93a30 --- /dev/null +++ b/daphne/tests/test_h2.py @@ -0,0 +1,83 @@ +from unittest import TestCase +from asgiref.inmemory import ChannelLayer +from twisted.test import proto_helpers + +from ..http2_protocol import H2Factory +from h2.connection import H2Connection +import h2.events + +class TestH2Protocol(TestCase): + """ + Tests that the HTTP protocol class correctly generates and parses messages. + """ + + def setUp(self): + self.channel_layer = ChannelLayer() + self.factory = H2Factory(self.channel_layer) + self.proto = self.factory.buildProtocol(('127.0.0.1', 0)) + self.tr = proto_helpers.StringTransport() + self.proto.makeConnection(self.tr) + + + def assertStartsWith(self, data, prefix): + real_prefix = data[:len(prefix)] + self.assertEqual(real_prefix, prefix) + + def test_basic(self): + """ + Tests basic HTTP parsing + """ + # Send a simple request to the protocol + + conn = H2Connection() + conn.initiate_connection() + #self.tr.write(conn.data_to_send()) + self.proto.dataReceived(conn.data_to_send()) + conn.send_headers(1, [ + (':method', 'GET'), + (':path', '/test/?foo=bar'), + ('user-agent', 'hyper-h2/yo'), + ], end_stream=True) + self.proto.dataReceived(conn.data_to_send()) + + _, message = self.channel_layer.receive_many(["http.request"]) + self.assertEqual(message['http_version'], "2.0") + self.assertEqual(message['method'], "GET") + self.assertEqual(message['scheme'], "http") + self.assertEqual(message['path'], b"/test/") + self.assertEqual(message['query_string'], b"foo=bar") + # self.assertEqual(message['headers'], [(b"user-agent", b"hyper-h2/yo")]) + self.assertFalse(message.get("body", None)) + self.assertTrue(message['reply_channel']) + + # Send back an example response + self.factory.dispatch_reply( + message['reply_channel'], + { + "status": 201, + "status_text": b"Created", + "content": b"OH HAI", + "headers": [[b"X-Test", b"Boom!"]], + } + ) + # Make sure that comes back right on the protocol + data = self.tr.value() + evs = conn.receive_data(data) + # we should see a ResponseReceived + hasResponse = False + for e in evs : + if isinstance(e, h2.events.ResponseReceived) : + hasResponse = True + headers = dict(e.headers) + self.assertEqual(headers["x-test"],"Boom!") + self.assertEqual(headers[":status"], "201") + + if isinstance(e, h2.events.DataReceived): + self.assertEqual(e.data, b"OH HAI") + self.assertTrue(hasResponse) + + # a DataReceived + # a StreamEnded ?? + + # self.assertEqual(evs) + # self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n")