From 48889827ea53d69266ea96b4b0e75a4ad9d13ccb Mon Sep 17 00:00:00 2001 From: Florent D'halluin Date: Mon, 4 Apr 2016 22:19:17 +0200 Subject: [PATCH] streaming body for both req and response --- daphne/http2_protocol.py | 157 ++++++++++++++++++++++++++++++--------- 1 file changed, 121 insertions(+), 36 deletions(-) diff --git a/daphne/http2_protocol.py b/daphne/http2_protocol.py index 4be6faf..449ed43 100644 --- a/daphne/http2_protocol.py +++ b/daphne/http2_protocol.py @@ -25,19 +25,70 @@ class H2Request(object): self.start_time = time.time() self.reply_channel = reply_channel self.body_channel = body_channel + self.response_started = False + self.headers = {} + self._header_sent = False # have header message been sent to channel layer ? + + def setHeaders(self, headers) : + self.headers = headers + self.body = b"" + + def sendHeaders(self): + + path = self.headers[':path'] + query_string = b"" + if "?" in path: # h2 makes path a unicode + path, query_string = path.encode().split(b"?", 1) + + # clean up ':' prefixed headers + headers_ = {} + for k,v in self.headers.items() : + if not k.startswith(':'): + headers_[k] = v + + # not post : wait for body before sending message + self.protocol.factory.channel_layer.send("http.request", { + "reply_channel": self.reply_channel, + "http_version": "2.0", # \o/ + "scheme": "http", # should be read from env/proxys headers ?? + "method" : self.headers[':method'], + "path" : path, # asgi expects these as bytes + "query_string" : query_string, + "headers": headers_, + "body": self.body, # this is populated on DataReceived event + "client": [self.protocol.transport.getHost().host, + self.protocol.transport.getHost().port], + }) + + self._header_send = True def serverResponse(self, message ): - print(message) - self.protocol.makeResponse(self.stream_id, message) - del self.protocol.factory.reply_protocols[self.reply_channel] + if "status" in message : + assert(not self.response_started) + self.response_started = True + self.protocol.makeResponse(self.stream_id, message) + # only if we are done + else : + assert(self.response_started) + self.protocol.sendData(self.stream_id, + message["content"], + message["more_content"]) - def DataReceived(self, data) : - """ chunk of body """ - self.protocol.factory.channel_layer.send(self.body_channel, { - content: data, - closed: False, # send a True to signal interruption of requests - more_content: False, - }) + if(not message.get("more_content", False)) : + del self.protocol.factory.reply_protocols[self.reply_channel] + + + + def dataReceived(self, data) : + """ chunk of body received """ + if(self._header_sent and self.body_channel) : + self.protocol.factory.channel_layer.send(self.body_channel, { + "content": data, + "closed": False, # send a True to signal interruption of requests + "more_content": False, # we just can't know that .. + }) + else : + print("Barf!") def duration(self): return time.time() - self.start_time @@ -71,11 +122,11 @@ class H2Protocol(Protocol): self.requestReceived(event.headers, event.stream_id) elif isinstance(event, DataReceived): self.dataFrameReceived(event.stream_id, event.data) - #elif isinstance(event, WindowUpdated): - # self.windowUpdated(event) + elif isinstance(event, WindowUpdated): + self.windowUpdated(event) def makeResponse(self, stream_id, message) : - + print("responding", message) response_headers = [ (':status', str(message["status"])), ('server', 'twisted-h2'), @@ -88,9 +139,11 @@ class H2Protocol(Protocol): self.transport.write(self.conn.data_to_send()) # write content .. Chnk this !! - self.conn.send_data(stream_id, message["content"], True) - self.transport.write(self.conn.data_to_send()) - + more_content = message.get('more_content', False) + # that's a twisted deferred, if you don't add a call back, + # this gets discarded + d = self.sendData(stream_id, message["content"], more_content) + d.addErrback(lambda e: print("error in send data", e)) def requestReceived(self, headers, stream_id): @@ -98,36 +151,68 @@ class H2Protocol(Protocol): reply_channel = self.factory.channel_layer.new_channel("http.response!") - # how do we know if there's a pending body ?? - # body_channel = self.factory.channel_layer.new_channel("http.request.body!") - req = H2Request(self, stream_id, reply_channel, None) + body_channel = None + if(headers[':method'] == 'POST'): + body_channel = self.factory.channel_layer.new_channel("http.request.body!") + # body_channel = + req = H2Request(self, stream_id, reply_channel, body_channel) + req.setHeaders(headers) self.requests[stream_id] = req self.factory.reply_protocols[reply_channel] = req - path = headers[':path'] - query_string = b"" - if "?" in path: # h2 makes path a unicode - path, query_string = path.encode().split(b"?", 1) - - self.factory.channel_layer.send("http.request", { - "reply_channel": reply_channel, - "http_version": "2.0", # \o/ - "scheme": "http", # should be read from env/proxys headers ?? - "method" : headers[':method'], - "path" : path, # asgi expects these as bytes - "query_string" : query_string, - "headers": headers, - "body": b"", # this is populated on DataReceived event - "client": [self.transport.getHost().host, self.transport.getHost().port], - }) + # send the request to channel layer, or wait for body + req.sendHeaders() + @inlineCallbacks + def sendData(self, stream_id, data, more_content=False): + # chunks and enqueue data + send_more = True + msg_size = len(data) + offset = 0 + while send_more : + print("waigint for flow control") + while not self.conn.remote_flow_control_window(stream_id) : + # do we have a flow window ? + yield self.wait_for_flow_control(stream_id) + + chunk_size = min(self.conn.remote_flow_control_window(stream_id),READ_CHUNK_SIZE) + + # hopefully, both are bigger than message data + if (msg_size - offset) < chunk_size : + send_more = False + end_chunk = offset + chunk_size + 1 + else : + end_chunk = msg_size + 1 + + chunk = data[offset:end_chunk] + # if more_content, keep request active + done = not ( send_more or more_content) + self.conn.send_data(stream_id, chunk, done) + self.transport.write(self.conn.data_to_send()) + + + def wait_for_flow_control(self, stream_id): + d = Deferred() + self._flow_control_deferreds[stream_id] = d + return d def dataFrameReceived(self, stream_id, data): self.requests[stream_id].dataReceived(data) - + def windowUpdated(self, event): + stream_id = event.stream_id + print("window flow ctrl", stream_id) + if stream_id and stream_id in self._flow_control_deferreds: + d = self._flow_control_deferreds.pop(stream_id) + d.callback(event.delta) + elif not stream_id: + # fire them all.. + for d in self._flow_control_deferreds.values(): + d.callback(event.delta) + self._flow_control_deferreds = {} + return class H2Factory(Factory):