Fixed #12: Crash on receiving high byte in path

This commit is contained in:
Andrew Godwin 2016-05-18 10:08:15 -07:00
parent 81d99a34d3
commit d786329abb

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
import logging import logging
import six import six
import time import time
import traceback
from six.moves.urllib_parse import unquote from six.moves.urllib_parse import unquote
from twisted.protocols.policies import ProtocolWrapper from twisted.protocols.policies import ProtocolWrapper
@ -53,11 +54,26 @@ class WebRequest(http.Request):
self._got_response_start = False self._got_response_start = False
def process(self): def process(self):
try:
self.request_start = time.time() self.request_start = time.time()
# Get upgrade header # Get upgrade header
upgrade_header = None upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"): if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0] upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if possible
if hasattr(self.client, "host") and hasattr(self.client, "port"):
self.client_addr = [self.client.host, self.client.port]
self.server_addr = [self.host.host, self.host.port]
else:
self.client_addr = None
self.server_addr = None
# Check for unicodeish path (or it'll crash when trying to parse)
try:
self.path.decode("ascii")
except UnicodeDecodeError:
self.path = b"/"
self.basic_error(400, b"Bad Request", "Invalid characters in path")
return
# Calculate query string # Calculate query string
self.query_string = b"" self.query_string = b""
if b"?" in self.uri: if b"?" in self.uri:
@ -103,13 +119,6 @@ class WebRequest(http.Request):
self.clean_headers.append((name.lower(), value)) self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.reply_channel) logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
self.content.seek(0, 0) self.content.seek(0, 0)
# Get client address if possible
if hasattr(self.client, "host") and hasattr(self.client, "port"):
self.client_addr = [self.client.host, self.client.port]
self.server_addr = [self.host.host, self.host.port]
else:
self.client_addr = None
self.server_addr = None
# Send message # Send message
try: try:
self.factory.channel_layer.send("http.request", { self.factory.channel_layer.send("http.request", {
@ -128,6 +137,9 @@ class WebRequest(http.Request):
except self.factory.channel_layer.ChannelFull: except self.factory.channel_layer.ChannelFull:
# Channel is too full; reject request with 503 # Channel is too full; reject request with 503
self.basic_error(503, b"Service Unavailable", "Request queue full.") self.basic_error(503, b"Service Unavailable", "Request queue full.")
except Exception as e:
logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "HTTP processing error")
@classmethod @classmethod
def unquote(cls, value): def unquote(cls, value):
@ -195,6 +207,7 @@ class WebRequest(http.Request):
if not message.get("more_content", False): if not message.get("more_content", False):
self.finish() self.finish()
logger.debug("HTTP response complete for %s", self.reply_channel) logger.debug("HTTP response complete for %s", self.reply_channel)
try:
self.factory.log_action("http", "complete", { self.factory.log_action("http", "complete", {
"path": self.path.decode("ascii"), "path": self.path.decode("ascii"),
"status": self.code, "status": self.code,
@ -203,6 +216,8 @@ class WebRequest(http.Request):
"time_taken": self.duration(), "time_taken": self.duration(),
"size": self.sentLength, "size": self.sentLength,
}) })
except Exception as e:
logging.error(traceback.format_exc())
else: else:
logger.debug("HTTP response chunk for %s", self.reply_channel) logger.debug("HTTP response chunk for %s", self.reply_channel)