Fixed #12: Crash on receiving high byte in path

This commit is contained in:
Andrew Godwin 2016-05-18 10:08:15 -07:00
parent 81d99a34d3
commit d786329abb

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
import logging
import six
import time
import traceback
from six.moves.urllib_parse import unquote
from twisted.protocols.policies import ProtocolWrapper
@ -53,56 +54,12 @@ class WebRequest(http.Request):
self._got_response_start = False
def process(self):
self.request_start = time.time()
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Calculate query string
self.query_string = b""
if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1]
# Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer())
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
logger.warn("Could not make WebSocket protocol")
self.finish()
# Port across transport
protocol.set_main_factory(self.factory)
transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
transport.wrappedProtocol = protocol
else:
transport.protocol = protocol
protocol.makeConnection(transport)
# Re-inject request
data = self.method + b' ' + self.uri + b' HTTP/1.1\x0d\x0a'
for h in self.requestHeaders.getAllRawHeaders():
data += h[0] + b': ' + b",".join(h[1]) + b'\x0d\x0a'
data += b"\x0d\x0a"
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
del self.factory.reply_protocols[self.reply_channel]
self.reply_channel = None
# Boring old HTTP.
else:
# Sanitize and decode headers
self.clean_headers = []
for name, values in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
continue
for value in values:
self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
self.content.seek(0, 0)
try:
self.request_start = time.time()
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if possible
if hasattr(self.client, "host") and hasattr(self.client, "port"):
self.client_addr = [self.client.host, self.client.port]
@ -110,24 +67,79 @@ class WebRequest(http.Request):
else:
self.client_addr = None
self.server_addr = None
# Send message
# Check for unicodeish path (or it'll crash when trying to parse)
try:
self.factory.channel_layer.send("http.request", {
"reply_channel": self.reply_channel,
# TODO: Correctly say if it's 1.1 or 1.0
"http_version": "1.1",
"method": self.method.decode("ascii"),
"path": self.unquote(self.path),
"scheme": "http",
"query_string": self.unquote(self.query_string),
"headers": self.clean_headers,
"body": self.content.read(),
"client": self.client_addr,
"server": self.server_addr,
})
except self.factory.channel_layer.ChannelFull:
# Channel is too full; reject request with 503
self.basic_error(503, b"Service Unavailable", "Request queue full.")
self.path.decode("ascii")
except UnicodeDecodeError:
self.path = b"/"
self.basic_error(400, b"Bad Request", "Invalid characters in path")
return
# Calculate query string
self.query_string = b""
if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1]
# Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.factory.ws_factory.buildProtocol(self.transport.getPeer())
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
logger.warn("Could not make WebSocket protocol")
self.finish()
# Port across transport
protocol.set_main_factory(self.factory)
transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
transport.wrappedProtocol = protocol
else:
transport.protocol = protocol
protocol.makeConnection(transport)
# Re-inject request
data = self.method + b' ' + self.uri + b' HTTP/1.1\x0d\x0a'
for h in self.requestHeaders.getAllRawHeaders():
data += h[0] + b': ' + b",".join(h[1]) + b'\x0d\x0a'
data += b"\x0d\x0a"
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
logger.debug("Upgraded connection %s to WebSocket %s", self.reply_channel, protocol.reply_channel)
del self.factory.reply_protocols[self.reply_channel]
self.reply_channel = None
# Boring old HTTP.
else:
# Sanitize and decode headers
self.clean_headers = []
for name, values in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
continue
for value in values:
self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.reply_channel)
self.content.seek(0, 0)
# Send message
try:
self.factory.channel_layer.send("http.request", {
"reply_channel": self.reply_channel,
# TODO: Correctly say if it's 1.1 or 1.0
"http_version": "1.1",
"method": self.method.decode("ascii"),
"path": self.unquote(self.path),
"scheme": "http",
"query_string": self.unquote(self.query_string),
"headers": self.clean_headers,
"body": self.content.read(),
"client": self.client_addr,
"server": self.server_addr,
})
except self.factory.channel_layer.ChannelFull:
# Channel is too full; reject request with 503
self.basic_error(503, b"Service Unavailable", "Request queue full.")
except Exception as e:
logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "HTTP processing error")
@classmethod
def unquote(cls, value):
@ -195,14 +207,17 @@ class WebRequest(http.Request):
if not message.get("more_content", False):
self.finish()
logger.debug("HTTP response complete for %s", self.reply_channel)
self.factory.log_action("http", "complete", {
"path": self.path.decode("ascii"),
"status": self.code,
"method": self.method.decode("ascii"),
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"time_taken": self.duration(),
"size": self.sentLength,
})
try:
self.factory.log_action("http", "complete", {
"path": self.path.decode("ascii"),
"status": self.code,
"method": self.method.decode("ascii"),
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"time_taken": self.duration(),
"size": self.sentLength,
})
except Exception as e:
logging.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.reply_channel)