initial proof of concept, works on simple GET request (no streaming body, no streaming reply, no SSL )

This commit is contained in:
Florent D'halluin 2016-04-03 16:17:04 +02:00
parent 7b13995dee
commit 854f25b7bd
4 changed files with 295 additions and 2 deletions

View File

@ -5,6 +5,8 @@ import importlib
from .server import Server from .server import Server
from .access import AccessLogGenerator from .access import AccessLogGenerator
from .http_protocol import HTTPFactory
from .http2_protocol import H2Factory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,6 +68,12 @@ class CommandLineInterface(object):
help='The number of seconds a WebSocket must be idle before a keepalive ping is sent', help='The number of seconds a WebSocket must be idle before a keepalive ping is sent',
default=20, default=20,
) )
self.parser.add_argument(
'--h2',
action='store_true',
help="enable HTTP/2"
)
self.parser.add_argument( self.parser.add_argument(
'channel_layer', 'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path', help='The ASGI channel layer instance to use as path.to.module:instance.path',
@ -115,8 +123,14 @@ class CommandLineInterface(object):
(args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)), (args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)),
args.channel_layer, args.channel_layer,
) )
if args.h2 :
factory_class = H2Factory
else :
factory_class = HTTPFactory
Server( Server(
channel_layer=channel_layer, channel_layer=channel_layer,
factory_class=factory_class,
host=args.host, host=args.host,
port=args.port, port=args.port,
unix_socket=args.unix_socket, unix_socket=args.unix_socket,

194
daphne/http2_protocol.py Normal file
View File

@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
import functools
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.protocol import Protocol, Factory
from twisted.internet import endpoints
from h2.connection import H2Connection
from h2.events import (
RequestReceived, DataReceived, WindowUpdated
)
import time
def close_file(file, d):
file.close()
READ_CHUNK_SIZE = 8192
class H2Request(object):
def __init__(self, protocol, id, reply_channel, body_channel=None) :
self.protocol = protocol
self.stream_id = id
self.start_time = time.time()
self.reply_channel = reply_channel
self.body_channel = body_channel
def serverResponse(self, message ):
print(message)
self.protocol.makeResponse(self.stream_id, message)
del self.protocol.factory.reply_protocols[self.reply_channel]
def DataReceived(self, data) :
""" chunk of body """
self.protocol.factory.channel_layer.send(self.body_channel, {
content: data,
closed: False, # send a True to signal interruption of requests
more_content: False,
})
def duration(self):
return time.time() - self.start_time
def basic_error(self):
pass
class H2Protocol(Protocol):
def __init__(self, factory):
self.conn = H2Connection(client_side=False)
self.factory = factory
self.known_proto = None
#self.root = root
self.requests = {} # ongoing requests
self._flow_control_deferreds = {}
def connectionMade(self):
self.conn.initiate_connection()
self.transport.write(self.conn.data_to_send())
def dataReceived(self, data):
if not self.known_proto:
self.known_proto = True
events = self.conn.receive_data(data)
if self.conn.data_to_send:
self.transport.write(self.conn.data_to_send())
for event in events:
if isinstance(event, RequestReceived):
self.requestReceived(event.headers, event.stream_id)
elif isinstance(event, DataReceived):
self.dataFrameReceived(event.stream_id, event.data)
#elif isinstance(event, WindowUpdated):
# self.windowUpdated(event)
def makeResponse(self, stream_id, message) :
response_headers = [
(':status', str(message["status"])),
('server', 'twisted-h2'),
("status_text", message.get("status_text", "")),
]
for header, value in message.get("headers", []) :
response_headers.append((header, value))
self.conn.send_headers(stream_id, response_headers)
self.transport.write(self.conn.data_to_send())
# write content .. Chnk this !!
self.conn.send_data(stream_id, message["content"], True)
self.transport.write(self.conn.data_to_send())
def requestReceived(self, headers, stream_id):
headers = dict(headers) # Invalid conversion, fix later.
reply_channel = self.factory.channel_layer.new_channel("http.response!")
# how do we know if there's a pending body ??
# body_channel = self.factory.channel_layer.new_channel("http.request.body!")
req = H2Request(self, stream_id, reply_channel, None)
self.requests[stream_id] = req
self.factory.reply_protocols[reply_channel] = req
path = headers[':path']
query_string = b""
if "?" in path: # h2 makes path a unicode
path, query_string = path.encode().split(b"?", 1)
self.factory.channel_layer.send("http.request", {
"reply_channel": reply_channel,
"http_version": "2.0", # \o/
"scheme": "http", # should be read from env/proxys headers ??
"method" : headers[':method'],
"path" : path, # asgi expects these as bytes
"query_string" : query_string,
"headers": headers,
"body": b"", # this is populated on DataReceived event
"client": [self.transport.getHost().host, self.transport.getHost().port],
})
def dataFrameReceived(self, stream_id, data):
self.requests[stream_id].dataReceived(data)
class H2Factory(Factory):
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20):
self.channel_layer = channel_layer
self.action_logger = action_logger
self.timeout = timeout
self.websocket_timeout = websocket_timeout
self.ping_interval = ping_interval
# We track all sub-protocols for response channel mapping
self.reply_protocols = {}
# Make a factory for WebSocket protocols
# self.ws_factory = WebSocketFactory(self)
# self.ws_factory.protocol = WebSocketProtocol
# self.ws_factory.reply_protocols = self.reply_protocols
def buildProtocol(self, addr):
return H2Protocol(self)
# copy pasta from http_protocol
def dispatch_reply(self, channel, message):
if channel.startswith("http") and isinstance(self.reply_protocols[channel], H2Request):
self.reply_protocols[channel].serverResponse(message)
# elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol):
# if message.get("bytes", None):
# self.reply_protocols[channel].serverSend(message["bytes"], True)
# if message.get("text", None):
# self.reply_protocols[channel].serverSend(message["text"], False)
# if message.get("close", False):
# self.reply_protocols[channel].serverClose()
else:
raise ValueError("Cannot dispatch message on channel %r" % channel)
# copy pasta from http protocol
def reply_channels(self):
return self.reply_protocols.keys()
def log_action(self, protocol, action, details):
"""
Dispatches to any registered action logger, if there is one.
"""
if self.action_logger:
self.action_logger(protocol, action, details)
def check_timeouts(self):
"""
Runs through all HTTP protocol instances and times them out if they've
taken too long (and so their message is probably expired)
"""
for protocol in list(self.reply_protocols.values()):
# Web timeout checking
if isinstance(protocol, H2Request) and protocol.duration() > self.timeout:
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
# WebSocket timeout checking and keepalive ping sending
#elif isinstance(protocol, WebSocketProtocol):
# Timeout check
# if protocol.duration() > self.websocket_timeout:
# protocol.serverClose()
# Ping check
# else:
# protocol.check_ping()

View File

@ -11,6 +11,7 @@ class Server(object):
def __init__( def __init__(
self, self,
channel_layer, channel_layer,
factory_class, # that's a factory factory ... meeeh
host="127.0.0.1", host="127.0.0.1",
port=8000, port=8000,
unix_socket=None, unix_socket=None,
@ -31,9 +32,10 @@ class Server(object):
# If they did not provide a websocket timeout, default it to the # If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not. # channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
self.factory_class = factory_class
def run(self): def run(self):
self.factory = HTTPFactory( self.factory = self.factory_class(
self.channel_layer, self.channel_layer,
self.action_logger, self.action_logger,
timeout=self.http_timeout, timeout=self.http_timeout,

83
daphne/tests/test_h2.py Normal file
View File

@ -0,0 +1,83 @@
from unittest import TestCase
from asgiref.inmemory import ChannelLayer
from twisted.test import proto_helpers
from ..http2_protocol import H2Factory
from h2.connection import H2Connection
import h2.events
class TestH2Protocol(TestCase):
"""
Tests that the HTTP protocol class correctly generates and parses messages.
"""
def setUp(self):
self.channel_layer = ChannelLayer()
self.factory = H2Factory(self.channel_layer)
self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)
def assertStartsWith(self, data, prefix):
real_prefix = data[:len(prefix)]
self.assertEqual(real_prefix, prefix)
def test_basic(self):
"""
Tests basic HTTP parsing
"""
# Send a simple request to the protocol
conn = H2Connection()
conn.initiate_connection()
#self.tr.write(conn.data_to_send())
self.proto.dataReceived(conn.data_to_send())
conn.send_headers(1, [
(':method', 'GET'),
(':path', '/test/?foo=bar'),
('user-agent', 'hyper-h2/yo'),
], end_stream=True)
self.proto.dataReceived(conn.data_to_send())
_, message = self.channel_layer.receive_many(["http.request"])
self.assertEqual(message['http_version'], "2.0")
self.assertEqual(message['method'], "GET")
self.assertEqual(message['scheme'], "http")
self.assertEqual(message['path'], b"/test/")
self.assertEqual(message['query_string'], b"foo=bar")
# self.assertEqual(message['headers'], [(b"user-agent", b"hyper-h2/yo")])
self.assertFalse(message.get("body", None))
self.assertTrue(message['reply_channel'])
# Send back an example response
self.factory.dispatch_reply(
message['reply_channel'],
{
"status": 201,
"status_text": b"Created",
"content": b"OH HAI",
"headers": [[b"X-Test", b"Boom!"]],
}
)
# Make sure that comes back right on the protocol
data = self.tr.value()
evs = conn.receive_data(data)
# we should see a ResponseReceived
hasResponse = False
for e in evs :
if isinstance(e, h2.events.ResponseReceived) :
hasResponse = True
headers = dict(e.headers)
self.assertEqual(headers["x-test"],"Boom!")
self.assertEqual(headers[":status"], "201")
if isinstance(e, h2.events.DataReceived):
self.assertEqual(e.data, b"OH HAI")
self.assertTrue(hasResponse)
# a DataReceived
# a StreamEnded ??
# self.assertEqual(evs)
# self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n")