This commit is contained in:
Florent D'halluin 2016-04-05 19:18:59 +00:00
commit 63feaf20ce
6 changed files with 427 additions and 7 deletions

View File

@ -5,5 +5,5 @@ python:
- "3.5"
install:
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install unittest2; fi
- pip install asgiref twisted autobahn
- pip install asgiref twisted autobahn h2
script: if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then python -m unittest2; else python -m unittest; fi

View File

@ -5,6 +5,8 @@ import importlib
from .server import Server
from .access import AccessLogGenerator
from .http_protocol import HTTPFactory
from .http2_protocol import H2Factory
logger = logging.getLogger(__name__)
@ -66,11 +68,31 @@ class CommandLineInterface(object):
help='The number of seconds a WebSocket must be idle before a keepalive ping is sent',
default=20,
)
self.parser.add_argument(
'--h2',
action='store_true',
help="enable HTTP/2"
)
self.parser.add_argument(
'--sslcert',
action="store",
help="path to ssl certificate file"
)
self.parser.add_argument(
'--sslkey',
action="store",
help="path to ssl private key file"
)
self.parser.add_argument(
'channel_layer',
help='The ASGI channel layer instance to use as path.to.module:instance.path',
)
@classmethod
def entrypoint(cls):
"""
@ -115,12 +137,21 @@ class CommandLineInterface(object):
(args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)),
args.channel_layer,
)
if args.h2 :
factory_class = H2Factory
else :
factory_class = HTTPFactory
Server(
channel_layer=channel_layer,
factory_class=factory_class,
host=args.host,
port=args.port,
unix_socket=args.unix_socket,
http_timeout=args.http_timeout,
ping_interval=args.ping_interval,
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
ssl_certificate=args.sslcert,
ssl_key=args.sslkey
).run()

279
daphne/http2_protocol.py Normal file
View File

@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-
import functools
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.protocol import Protocol, Factory
from twisted.internet import endpoints
from h2.connection import H2Connection
from h2.events import (
RequestReceived, DataReceived, WindowUpdated
)
import time
def close_file(file, d):
file.close()
READ_CHUNK_SIZE = 8192
class H2Request(object):
def __init__(self, protocol, id, reply_channel, body_channel=None) :
self.protocol = protocol
self.stream_id = id
self.start_time = time.time()
self.reply_channel = reply_channel
self.body_channel = body_channel
self.response_started = False
self.headers = {}
self._header_sent = False # have header message been sent to channel layer ?
def setHeaders(self, headers) :
self.headers = headers
self.body = b""
def sendHeaders(self):
path = self.headers[':path']
query_string = b""
if "?" in path: # h2 makes path a unicode
path, query_string = path.encode().split(b"?", 1)
# clean up ':' prefixed headers
headers_ = {}
for k,v in self.headers.items() :
if not k.startswith(':'):
headers_[k] = v
# not post : wait for body before sending message
self.protocol.factory.channel_layer.send("http.request", {
"reply_channel": self.reply_channel,
"http_version": "2.0", # \o/
"scheme": "http", # should be read from env/proxys headers ??
"method" : self.headers[':method'],
"path" : path, # asgi expects these as bytes
"query_string" : query_string,
"headers": headers_,
"body": self.body, # this is populated on DataReceived event
"client": [self.protocol.transport.getHost().host,
self.protocol.transport.getHost().port],
})
self._header_send = True
def serverResponse(self, message ):
if "status" in message :
assert(not self.response_started)
self.response_started = True
self.protocol.makeResponse(self.stream_id, message)
# only if we are done
else :
assert(self.response_started)
self.protocol.sendData(self.stream_id,
message["content"],
message["more_content"])
if(not message.get("more_content", False)) :
del self.protocol.factory.reply_protocols[self.reply_channel]
def dataReceived(self, data) :
""" chunk of body received """
if(self._header_sent and self.body_channel) :
self.protocol.factory.channel_layer.send(self.body_channel, {
"content": data,
"closed": False, # send a True to signal interruption of requests
"more_content": False, # we just can't know that ..
})
else :
print("Barf!")
def duration(self):
return time.time() - self.start_time
def basic_error(self):
pass
class H2Protocol(Protocol):
def __init__(self, factory):
self.conn = H2Connection(client_side=False)
self.factory = factory
self.known_proto = None
#self.root = root
self.requests = {} # ongoing requests
self._flow_control_deferreds = {}
def connectionMade(self):
self.conn.initiate_connection()
self.transport.write(self.conn.data_to_send())
def dataReceived(self, data):
if not self.known_proto:
self.known_proto = True
events = self.conn.receive_data(data)
if self.conn.data_to_send:
self.transport.write(self.conn.data_to_send())
for event in events:
if isinstance(event, RequestReceived):
self.requestReceived(event.headers, event.stream_id)
elif isinstance(event, DataReceived):
self.dataFrameReceived(event.stream_id, event.data)
elif isinstance(event, WindowUpdated):
self.windowUpdated(event)
def makeResponse(self, stream_id, message) :
print("responding", message)
response_headers = [
(':status', str(message["status"])),
('server', 'twisted-h2'),
("status_text", message.get("status_text", "")),
]
for header, value in message.get("headers", []) :
response_headers.append((header, value))
self.conn.send_headers(stream_id, response_headers)
self.transport.write(self.conn.data_to_send())
# write content .. Chnk this !!
more_content = message.get('more_content', False)
# that's a twisted deferred, if you don't add a call back,
# this gets discarded
d = self.sendData(stream_id, message["content"], more_content)
d.addErrback(lambda e: print("error in send data", e))
def requestReceived(self, headers, stream_id):
headers = dict(headers) # Invalid conversion, fix later.
reply_channel = self.factory.channel_layer.new_channel("http.response!")
body_channel = None
if(headers[':method'] == 'POST'):
body_channel = self.factory.channel_layer.new_channel("http.request.body!")
# body_channel =
req = H2Request(self, stream_id, reply_channel, body_channel)
req.setHeaders(headers)
self.requests[stream_id] = req
self.factory.reply_protocols[reply_channel] = req
# send the request to channel layer, or wait for body
req.sendHeaders()
@inlineCallbacks
def sendData(self, stream_id, data, more_content=False):
# chunks and enqueue data
send_more = True
msg_size = len(data)
offset = 0
while send_more :
print("waigint for flow control")
while not self.conn.remote_flow_control_window(stream_id) :
# do we have a flow window ?
yield self.wait_for_flow_control(stream_id)
chunk_size = min(self.conn.remote_flow_control_window(stream_id),READ_CHUNK_SIZE)
# hopefully, both are bigger than message data
if (msg_size - offset) < chunk_size :
send_more = False
end_chunk = offset + chunk_size + 1
else :
end_chunk = msg_size + 1
chunk = data[offset:end_chunk]
# if more_content, keep request active
done = not ( send_more or more_content)
self.conn.send_data(stream_id, chunk, done)
self.transport.write(self.conn.data_to_send())
def wait_for_flow_control(self, stream_id):
d = Deferred()
self._flow_control_deferreds[stream_id] = d
return d
def dataFrameReceived(self, stream_id, data):
self.requests[stream_id].dataReceived(data)
def windowUpdated(self, event):
stream_id = event.stream_id
print("window flow ctrl", stream_id)
if stream_id and stream_id in self._flow_control_deferreds:
d = self._flow_control_deferreds.pop(stream_id)
d.callback(event.delta)
elif not stream_id:
# fire them all..
for d in self._flow_control_deferreds.values():
d.callback(event.delta)
self._flow_control_deferreds = {}
return
class H2Factory(Factory):
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20):
self.channel_layer = channel_layer
self.action_logger = action_logger
self.timeout = timeout
self.websocket_timeout = websocket_timeout
self.ping_interval = ping_interval
# We track all sub-protocols for response channel mapping
self.reply_protocols = {}
# Make a factory for WebSocket protocols
# self.ws_factory = WebSocketFactory(self)
# self.ws_factory.protocol = WebSocketProtocol
# self.ws_factory.reply_protocols = self.reply_protocols
def buildProtocol(self, addr):
return H2Protocol(self)
# copy pasta from http_protocol
def dispatch_reply(self, channel, message):
if channel.startswith("http") and isinstance(self.reply_protocols[channel], H2Request):
self.reply_protocols[channel].serverResponse(message)
# elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol):
# if message.get("bytes", None):
# self.reply_protocols[channel].serverSend(message["bytes"], True)
# if message.get("text", None):
# self.reply_protocols[channel].serverSend(message["text"], False)
# if message.get("close", False):
# self.reply_protocols[channel].serverClose()
else:
raise ValueError("Cannot dispatch message on channel %r" % channel)
# copy pasta from http protocol
def reply_channels(self):
return self.reply_protocols.keys()
def log_action(self, protocol, action, details):
"""
Dispatches to any registered action logger, if there is one.
"""
if self.action_logger:
self.action_logger(protocol, action, details)
def check_timeouts(self):
"""
Runs through all HTTP protocol instances and times them out if they've
taken too long (and so their message is probably expired)
"""
for protocol in list(self.reply_protocols.values()):
# Web timeout checking
if isinstance(protocol, H2Request) and protocol.duration() > self.timeout:
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
# WebSocket timeout checking and keepalive ping sending
#elif isinstance(protocol, WebSocketProtocol):
# Timeout check
# if protocol.duration() > self.websocket_timeout:
# protocol.serverClose()
# Ping check
# else:
# protocol.check_ping()

View File

@ -1,5 +1,5 @@
import logging
from twisted.internet import reactor
from twisted.internet import reactor, ssl, endpoints
from .http_protocol import HTTPFactory
@ -11,6 +11,7 @@ class Server(object):
def __init__(
self,
channel_layer,
factory_class, # that's a factory factory ... meeeh
host="127.0.0.1",
port=8000,
unix_socket=None,
@ -19,6 +20,8 @@ class Server(object):
http_timeout=120,
websocket_timeout=None,
ping_interval=20,
ssl_certificate = None,
ssl_key = None
):
self.channel_layer = channel_layer
self.host = host
@ -31,19 +34,41 @@ class Server(object):
# If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
self.factory_class = factory_class
self.ssl_certificate = ssl_certificate
self.ssl_key = ssl_key
def run(self):
self.factory = HTTPFactory(
self.factory = self.factory_class(
self.channel_layer,
self.action_logger,
timeout=self.http_timeout,
websocket_timeout=self.websocket_timeout,
ping_interval=self.ping_interval,
)
if self.unix_socket:
reactor.listenUNIX(self.unix_socket, self.factory)
else:
reactor.listenTCP(self.port, self.factory, interface=self.host)
if self.ssl_certificate :
from OpenSSL import crypto
with open(self.ssl_certificate, 'r') as f:
cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
with open(self.ssl_key, 'r') as f:
key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
opts = ssl.CertificateOptions(
privateKey= key,
certificate=cert,
acceptableProtocols=[b'h2']
)
endpt = endpoints.SSL4ServerEndpoint(reactor, self.port, opts, backlog=128)
endpt.listen(self.factory)
else :
if self.unix_socket:
reactor.listenUNIX(self.unix_socket, self.factory)
else:
reactor.listenTCP(self.port, self.factory, interface=self.host)
reactor.callLater(0, self.backend_reader)
reactor.callLater(2, self.timeout_checker)
reactor.run(installSignalHandlers=self.signal_handlers)

83
daphne/tests/test_h2.py Normal file
View File

@ -0,0 +1,83 @@
from unittest import TestCase
from asgiref.inmemory import ChannelLayer
from twisted.test import proto_helpers
from ..http2_protocol import H2Factory
from h2.connection import H2Connection
import h2.events
class TestH2Protocol(TestCase):
"""
Tests that the HTTP protocol class correctly generates and parses messages.
"""
def setUp(self):
self.channel_layer = ChannelLayer()
self.factory = H2Factory(self.channel_layer)
self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)
def assertStartsWith(self, data, prefix):
real_prefix = data[:len(prefix)]
self.assertEqual(real_prefix, prefix)
def test_basic(self):
"""
Tests basic HTTP parsing
"""
# Send a simple request to the protocol
conn = H2Connection()
conn.initiate_connection()
#self.tr.write(conn.data_to_send())
self.proto.dataReceived(conn.data_to_send())
conn.send_headers(1, [
(':method', 'GET'),
(':path', '/test/?foo=bar'),
('user-agent', 'hyper-h2/yo'),
], end_stream=True)
self.proto.dataReceived(conn.data_to_send())
_, message = self.channel_layer.receive_many(["http.request"])
self.assertEqual(message['http_version'], "2.0")
self.assertEqual(message['method'], "GET")
self.assertEqual(message['scheme'], "http")
self.assertEqual(message['path'], b"/test/")
self.assertEqual(message['query_string'], b"foo=bar")
# self.assertEqual(message['headers'], [(b"user-agent", b"hyper-h2/yo")])
self.assertFalse(message.get("body", None))
self.assertTrue(message['reply_channel'])
# Send back an example response
self.factory.dispatch_reply(
message['reply_channel'],
{
"status": 201,
"status_text": b"Created",
"content": b"OH HAI",
"headers": [[b"X-Test", b"Boom!"]],
}
)
# Make sure that comes back right on the protocol
data = self.tr.value()
evs = conn.receive_data(data)
# we should see a ResponseReceived
hasResponse = False
for e in evs :
if isinstance(e, h2.events.ResponseReceived) :
hasResponse = True
headers = dict(e.headers)
self.assertEqual(headers["x-test"],"Boom!")
self.assertEqual(headers[":status"], "201")
if isinstance(e, h2.events.DataReceived):
self.assertEqual(e.data, b"OH HAI")
self.assertTrue(hasResponse)
# a DataReceived
# a StreamEnded ??
# self.assertEqual(evs)
# self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n")

View File

@ -24,6 +24,8 @@ setup(
'asgiref>=0.10',
'twisted>=15.5',
'autobahn>=0.12',
'h2>=2.2',
'pyOpenSSL' # optionnal ??
],
entry_points={'console_scripts': [
'daphne = daphne.cli:CommandLineInterface.entrypoint',