diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index db18fb8..649fc55 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import logging import six import time +import traceback from six.moves.urllib_parse import unquote from twisted.protocols.policies import ProtocolWrapper @@ -53,56 +54,12 @@ class WebRequest(http.Request): self._got_response_start = False def process(self): - self.request_start = time.time() - # Get upgrade header - upgrade_header = None - if self.requestHeaders.hasHeader(b"Upgrade"): - upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0] - # Calculate query string - self.query_string = b"" - if b"?" in self.uri: - self.query_string = self.uri.split(b"?", 1)[1] - # Is it WebSocket? IS IT?! - if upgrade_header and upgrade_header.lower() == b"websocket": - # Make WebSocket protocol to hand off to - protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer()) - if not protocol: - # If protocol creation fails, we signal "internal server error" - self.setResponseCode(500) - logger.warn("Could not make WebSocket protocol") - self.finish() - # Port across transport - protocol.set_main_factory(self.factory) - transport, self.transport = self.transport, None - if isinstance(transport, ProtocolWrapper): - # i.e. TLS is a wrapping protocol - transport.wrappedProtocol = protocol - else: - transport.protocol = protocol - protocol.makeConnection(transport) - # Re-inject request - data = self.method + b' ' + self.uri + b' HTTP/1.1\x0d\x0a' - for h in self.requestHeaders.getAllRawHeaders(): - data += h[0] + b': ' + b",".join(h[1]) + b'\x0d\x0a' - data += b"\x0d\x0a" - data += self.content.read() - protocol.dataReceived(data) - # Remove our HTTP reply channel association - logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel) - del self.factory.reply_protocols[self.reply_channel] - self.reply_channel = None - # Boring old HTTP. - else: - # Sanitize and decode headers - self.clean_headers = [] - for name, values in self.requestHeaders.getAllRawHeaders(): - # Prevent CVE-2015-0219 - if b"_" in name: - continue - for value in values: - self.clean_headers.append((name.lower(), value)) - logger.debug("HTTP %s request for %s", self.method, self.reply_channel) - self.content.seek(0, 0) + try: + self.request_start = time.time() + # Get upgrade header + upgrade_header = None + if self.requestHeaders.hasHeader(b"Upgrade"): + upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0] # Get client address if possible if hasattr(self.client, "host") and hasattr(self.client, "port"): self.client_addr = [self.client.host, self.client.port] @@ -110,24 +67,79 @@ class WebRequest(http.Request): else: self.client_addr = None self.server_addr = None - # Send message + # Check for unicodeish path (or it'll crash when trying to parse) try: - self.factory.channel_layer.send("http.request", { - "reply_channel": self.reply_channel, - # TODO: Correctly say if it's 1.1 or 1.0 - "http_version": "1.1", - "method": self.method.decode("ascii"), - "path": self.unquote(self.path), - "scheme": "http", - "query_string": self.unquote(self.query_string), - "headers": self.clean_headers, - "body": self.content.read(), - "client": self.client_addr, - "server": self.server_addr, - }) - except self.factory.channel_layer.ChannelFull: - # Channel is too full; reject request with 503 - self.basic_error(503, b"Service Unavailable", "Request queue full.") + self.path.decode("ascii") + except UnicodeDecodeError: + self.path = b"/" + self.basic_error(400, b"Bad Request", "Invalid characters in path") + return + # Calculate query string + self.query_string = b"" + if b"?" in self.uri: + self.query_string = self.uri.split(b"?", 1)[1] + # Is it WebSocket? IS IT?! + if upgrade_header and upgrade_header.lower() == b"websocket": + # Make WebSocket protocol to hand off to + protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer()) + if not protocol: + # If protocol creation fails, we signal "internal server error" + self.setResponseCode(500) + logger.warn("Could not make WebSocket protocol") + self.finish() + # Port across transport + protocol.set_main_factory(self.factory) + transport, self.transport = self.transport, None + if isinstance(transport, ProtocolWrapper): + # i.e. TLS is a wrapping protocol + transport.wrappedProtocol = protocol + else: + transport.protocol = protocol + protocol.makeConnection(transport) + # Re-inject request + data = self.method + b' ' + self.uri + b' HTTP/1.1\x0d\x0a' + for h in self.requestHeaders.getAllRawHeaders(): + data += h[0] + b': ' + b",".join(h[1]) + b'\x0d\x0a' + data += b"\x0d\x0a" + data += self.content.read() + protocol.dataReceived(data) + # Remove our HTTP reply channel association + logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel) + del self.factory.reply_protocols[self.reply_channel] + self.reply_channel = None + # Boring old HTTP. + else: + # Sanitize and decode headers + self.clean_headers = [] + for name, values in self.requestHeaders.getAllRawHeaders(): + # Prevent CVE-2015-0219 + if b"_" in name: + continue + for value in values: + self.clean_headers.append((name.lower(), value)) + logger.debug("HTTP %s request for %s", self.method, self.reply_channel) + self.content.seek(0, 0) + # Send message + try: + self.factory.channel_layer.send("http.request", { + "reply_channel": self.reply_channel, + # TODO: Correctly say if it's 1.1 or 1.0 + "http_version": "1.1", + "method": self.method.decode("ascii"), + "path": self.unquote(self.path), + "scheme": "http", + "query_string": self.unquote(self.query_string), + "headers": self.clean_headers, + "body": self.content.read(), + "client": self.client_addr, + "server": self.server_addr, + }) + except self.factory.channel_layer.ChannelFull: + # Channel is too full; reject request with 503 + self.basic_error(503, b"Service Unavailable", "Request queue full.") + except Exception as e: + logger.error(traceback.format_exc()) + self.basic_error(500, b"Internal Server Error", "HTTP processing error") @classmethod def unquote(cls, value): @@ -195,14 +207,17 @@ class WebRequest(http.Request): if not message.get("more_content", False): self.finish() logger.debug("HTTP response complete for %s", self.reply_channel) - self.factory.log_action("http", "complete", { - "path": self.path.decode("ascii"), - "status": self.code, - "method": self.method.decode("ascii"), - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - "time_taken": self.duration(), - "size": self.sentLength, - }) + try: + self.factory.log_action("http", "complete", { + "path": self.path.decode("ascii"), + "status": self.code, + "method": self.method.decode("ascii"), + "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, + "time_taken": self.duration(), + "size": self.sentLength, + }) + except Exception as e: + logging.error(traceback.format_exc()) else: logger.debug("HTTP response chunk for %s", self.reply_channel)