mirror of
https://github.com/django/daphne.git
synced 2025-07-13 09:22:17 +03:00
initial proof of concept, works on simple GET request (no streaming body, no streaming reply, no SSL )
This commit is contained in:
parent
7b13995dee
commit
854f25b7bd
|
@ -5,6 +5,8 @@ import importlib
|
|||
from .server import Server
|
||||
from .access import AccessLogGenerator
|
||||
|
||||
from .http_protocol import HTTPFactory
|
||||
from .http2_protocol import H2Factory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -66,6 +68,12 @@ class CommandLineInterface(object):
|
|||
help='The number of seconds a WebSocket must be idle before a keepalive ping is sent',
|
||||
default=20,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
'--h2',
|
||||
action='store_true',
|
||||
help="enable HTTP/2"
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
'channel_layer',
|
||||
help='The ASGI channel layer instance to use as path.to.module:instance.path',
|
||||
|
@ -115,8 +123,14 @@ class CommandLineInterface(object):
|
|||
(args.unix_socket if args.unix_socket else "%s:%s" % (args.host, args.port)),
|
||||
args.channel_layer,
|
||||
)
|
||||
if args.h2 :
|
||||
factory_class = H2Factory
|
||||
else :
|
||||
factory_class = HTTPFactory
|
||||
|
||||
Server(
|
||||
channel_layer=channel_layer,
|
||||
factory_class=factory_class,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
unix_socket=args.unix_socket,
|
||||
|
|
194
daphne/http2_protocol.py
Normal file
194
daphne/http2_protocol.py
Normal file
|
@ -0,0 +1,194 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import functools
|
||||
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks
|
||||
from twisted.internet.protocol import Protocol, Factory
|
||||
from twisted.internet import endpoints
|
||||
|
||||
from h2.connection import H2Connection
|
||||
from h2.events import (
|
||||
RequestReceived, DataReceived, WindowUpdated
|
||||
)
|
||||
import time
|
||||
|
||||
|
||||
def close_file(file, d):
|
||||
file.close()
|
||||
|
||||
|
||||
READ_CHUNK_SIZE = 8192
|
||||
|
||||
class H2Request(object):
|
||||
def __init__(self, protocol, id, reply_channel, body_channel=None) :
|
||||
self.protocol = protocol
|
||||
self.stream_id = id
|
||||
self.start_time = time.time()
|
||||
self.reply_channel = reply_channel
|
||||
self.body_channel = body_channel
|
||||
|
||||
def serverResponse(self, message ):
|
||||
print(message)
|
||||
self.protocol.makeResponse(self.stream_id, message)
|
||||
del self.protocol.factory.reply_protocols[self.reply_channel]
|
||||
|
||||
def DataReceived(self, data) :
|
||||
""" chunk of body """
|
||||
self.protocol.factory.channel_layer.send(self.body_channel, {
|
||||
content: data,
|
||||
closed: False, # send a True to signal interruption of requests
|
||||
more_content: False,
|
||||
})
|
||||
|
||||
def duration(self):
|
||||
return time.time() - self.start_time
|
||||
|
||||
def basic_error(self):
|
||||
pass
|
||||
|
||||
class H2Protocol(Protocol):
|
||||
def __init__(self, factory):
|
||||
self.conn = H2Connection(client_side=False)
|
||||
self.factory = factory
|
||||
self.known_proto = None
|
||||
#self.root = root
|
||||
self.requests = {} # ongoing requests
|
||||
self._flow_control_deferreds = {}
|
||||
|
||||
def connectionMade(self):
|
||||
self.conn.initiate_connection()
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
|
||||
def dataReceived(self, data):
|
||||
if not self.known_proto:
|
||||
self.known_proto = True
|
||||
|
||||
events = self.conn.receive_data(data)
|
||||
if self.conn.data_to_send:
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, RequestReceived):
|
||||
self.requestReceived(event.headers, event.stream_id)
|
||||
elif isinstance(event, DataReceived):
|
||||
self.dataFrameReceived(event.stream_id, event.data)
|
||||
#elif isinstance(event, WindowUpdated):
|
||||
# self.windowUpdated(event)
|
||||
|
||||
def makeResponse(self, stream_id, message) :
|
||||
|
||||
response_headers = [
|
||||
(':status', str(message["status"])),
|
||||
('server', 'twisted-h2'),
|
||||
("status_text", message.get("status_text", "")),
|
||||
]
|
||||
for header, value in message.get("headers", []) :
|
||||
response_headers.append((header, value))
|
||||
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
|
||||
# write content .. Chnk this !!
|
||||
self.conn.send_data(stream_id, message["content"], True)
|
||||
self.transport.write(self.conn.data_to_send())
|
||||
|
||||
|
||||
|
||||
def requestReceived(self, headers, stream_id):
|
||||
headers = dict(headers) # Invalid conversion, fix later.
|
||||
|
||||
reply_channel = self.factory.channel_layer.new_channel("http.response!")
|
||||
|
||||
# how do we know if there's a pending body ??
|
||||
# body_channel = self.factory.channel_layer.new_channel("http.request.body!")
|
||||
req = H2Request(self, stream_id, reply_channel, None)
|
||||
|
||||
self.requests[stream_id] = req
|
||||
self.factory.reply_protocols[reply_channel] = req
|
||||
|
||||
path = headers[':path']
|
||||
query_string = b""
|
||||
if "?" in path: # h2 makes path a unicode
|
||||
path, query_string = path.encode().split(b"?", 1)
|
||||
|
||||
self.factory.channel_layer.send("http.request", {
|
||||
"reply_channel": reply_channel,
|
||||
"http_version": "2.0", # \o/
|
||||
"scheme": "http", # should be read from env/proxys headers ??
|
||||
"method" : headers[':method'],
|
||||
"path" : path, # asgi expects these as bytes
|
||||
"query_string" : query_string,
|
||||
"headers": headers,
|
||||
"body": b"", # this is populated on DataReceived event
|
||||
"client": [self.transport.getHost().host, self.transport.getHost().port],
|
||||
})
|
||||
|
||||
|
||||
|
||||
def dataFrameReceived(self, stream_id, data):
|
||||
self.requests[stream_id].dataReceived(data)
|
||||
|
||||
|
||||
|
||||
|
||||
class H2Factory(Factory):
|
||||
|
||||
def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20):
|
||||
self.channel_layer = channel_layer
|
||||
self.action_logger = action_logger
|
||||
self.timeout = timeout
|
||||
self.websocket_timeout = websocket_timeout
|
||||
self.ping_interval = ping_interval
|
||||
# We track all sub-protocols for response channel mapping
|
||||
self.reply_protocols = {}
|
||||
# Make a factory for WebSocket protocols
|
||||
# self.ws_factory = WebSocketFactory(self)
|
||||
# self.ws_factory.protocol = WebSocketProtocol
|
||||
# self.ws_factory.reply_protocols = self.reply_protocols
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return H2Protocol(self)
|
||||
|
||||
|
||||
# copy pasta from http_protocol
|
||||
def dispatch_reply(self, channel, message):
|
||||
if channel.startswith("http") and isinstance(self.reply_protocols[channel], H2Request):
|
||||
self.reply_protocols[channel].serverResponse(message)
|
||||
# elif channel.startswith("websocket") and isinstance(self.reply_protocols[channel], WebSocketProtocol):
|
||||
# if message.get("bytes", None):
|
||||
# self.reply_protocols[channel].serverSend(message["bytes"], True)
|
||||
# if message.get("text", None):
|
||||
# self.reply_protocols[channel].serverSend(message["text"], False)
|
||||
# if message.get("close", False):
|
||||
# self.reply_protocols[channel].serverClose()
|
||||
else:
|
||||
raise ValueError("Cannot dispatch message on channel %r" % channel)
|
||||
|
||||
# copy pasta from http protocol
|
||||
def reply_channels(self):
|
||||
return self.reply_protocols.keys()
|
||||
|
||||
|
||||
def log_action(self, protocol, action, details):
|
||||
"""
|
||||
Dispatches to any registered action logger, if there is one.
|
||||
"""
|
||||
if self.action_logger:
|
||||
self.action_logger(protocol, action, details)
|
||||
|
||||
def check_timeouts(self):
|
||||
"""
|
||||
Runs through all HTTP protocol instances and times them out if they've
|
||||
taken too long (and so their message is probably expired)
|
||||
"""
|
||||
for protocol in list(self.reply_protocols.values()):
|
||||
# Web timeout checking
|
||||
if isinstance(protocol, H2Request) and protocol.duration() > self.timeout:
|
||||
protocol.basic_error(503, b"Service Unavailable", "Worker server failed to respond within time limit.")
|
||||
# WebSocket timeout checking and keepalive ping sending
|
||||
#elif isinstance(protocol, WebSocketProtocol):
|
||||
# Timeout check
|
||||
# if protocol.duration() > self.websocket_timeout:
|
||||
# protocol.serverClose()
|
||||
# Ping check
|
||||
# else:
|
||||
# protocol.check_ping()
|
|
@ -11,6 +11,7 @@ class Server(object):
|
|||
def __init__(
|
||||
self,
|
||||
channel_layer,
|
||||
factory_class, # that's a factory factory ... meeeh
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
unix_socket=None,
|
||||
|
@ -31,9 +32,10 @@ class Server(object):
|
|||
# If they did not provide a websocket timeout, default it to the
|
||||
# channel layer's group_expiry value if present, or one day if not.
|
||||
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
|
||||
|
||||
self.factory_class = factory_class
|
||||
|
||||
def run(self):
|
||||
self.factory = HTTPFactory(
|
||||
self.factory = self.factory_class(
|
||||
self.channel_layer,
|
||||
self.action_logger,
|
||||
timeout=self.http_timeout,
|
||||
|
|
83
daphne/tests/test_h2.py
Normal file
83
daphne/tests/test_h2.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
from unittest import TestCase
|
||||
from asgiref.inmemory import ChannelLayer
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
from ..http2_protocol import H2Factory
|
||||
from h2.connection import H2Connection
|
||||
import h2.events
|
||||
|
||||
class TestH2Protocol(TestCase):
|
||||
"""
|
||||
Tests that the HTTP protocol class correctly generates and parses messages.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.channel_layer = ChannelLayer()
|
||||
self.factory = H2Factory(self.channel_layer)
|
||||
self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
|
||||
self.tr = proto_helpers.StringTransport()
|
||||
self.proto.makeConnection(self.tr)
|
||||
|
||||
|
||||
def assertStartsWith(self, data, prefix):
|
||||
real_prefix = data[:len(prefix)]
|
||||
self.assertEqual(real_prefix, prefix)
|
||||
|
||||
def test_basic(self):
|
||||
"""
|
||||
Tests basic HTTP parsing
|
||||
"""
|
||||
# Send a simple request to the protocol
|
||||
|
||||
conn = H2Connection()
|
||||
conn.initiate_connection()
|
||||
#self.tr.write(conn.data_to_send())
|
||||
self.proto.dataReceived(conn.data_to_send())
|
||||
conn.send_headers(1, [
|
||||
(':method', 'GET'),
|
||||
(':path', '/test/?foo=bar'),
|
||||
('user-agent', 'hyper-h2/yo'),
|
||||
], end_stream=True)
|
||||
self.proto.dataReceived(conn.data_to_send())
|
||||
|
||||
_, message = self.channel_layer.receive_many(["http.request"])
|
||||
self.assertEqual(message['http_version'], "2.0")
|
||||
self.assertEqual(message['method'], "GET")
|
||||
self.assertEqual(message['scheme'], "http")
|
||||
self.assertEqual(message['path'], b"/test/")
|
||||
self.assertEqual(message['query_string'], b"foo=bar")
|
||||
# self.assertEqual(message['headers'], [(b"user-agent", b"hyper-h2/yo")])
|
||||
self.assertFalse(message.get("body", None))
|
||||
self.assertTrue(message['reply_channel'])
|
||||
|
||||
# Send back an example response
|
||||
self.factory.dispatch_reply(
|
||||
message['reply_channel'],
|
||||
{
|
||||
"status": 201,
|
||||
"status_text": b"Created",
|
||||
"content": b"OH HAI",
|
||||
"headers": [[b"X-Test", b"Boom!"]],
|
||||
}
|
||||
)
|
||||
# Make sure that comes back right on the protocol
|
||||
data = self.tr.value()
|
||||
evs = conn.receive_data(data)
|
||||
# we should see a ResponseReceived
|
||||
hasResponse = False
|
||||
for e in evs :
|
||||
if isinstance(e, h2.events.ResponseReceived) :
|
||||
hasResponse = True
|
||||
headers = dict(e.headers)
|
||||
self.assertEqual(headers["x-test"],"Boom!")
|
||||
self.assertEqual(headers[":status"], "201")
|
||||
|
||||
if isinstance(e, h2.events.DataReceived):
|
||||
self.assertEqual(e.data, b"OH HAI")
|
||||
self.assertTrue(hasResponse)
|
||||
|
||||
# a DataReceived
|
||||
# a StreamEnded ??
|
||||
|
||||
# self.assertEqual(evs)
|
||||
# self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n")
|
Loading…
Reference in New Issue
Block a user