Add x-forwarded-proto support (#219)

This commit is contained in:
Nick Sellen 2018-07-24 22:25:03 +02:00 committed by Andrew Godwin
parent adb622d4f5
commit 2f94210321
7 changed files with 50 additions and 27 deletions

View File

@ -214,5 +214,6 @@ class CommandLineInterface(object):
verbosity=args.verbosity, verbosity=args.verbosity,
proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None, proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None, proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None,
) )
self.server.run() self.server.run()

View File

@ -73,13 +73,18 @@ class WebRequest(http.Request):
else: else:
self.client_addr = None self.client_addr = None
self.server_addr = None self.server_addr = None
self.client_scheme = "https" if self.isSecure() else "http"
# See if we need to get the address from a proxy header instead # See if we need to get the address from a proxy header instead
if self.server.proxy_forwarded_address_header: if self.server.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for( self.client_addr, self.client_scheme = parse_x_forwarded_for(
self.requestHeaders, self.requestHeaders,
self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_port_header,
self.client_addr self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme
) )
# Check for unicodeish path (or it'll crash when trying to parse) # Check for unicodeish path (or it'll crash when trying to parse)
try: try:
@ -153,7 +158,7 @@ class WebRequest(http.Request):
"method": self.method.decode("ascii"), "method": self.method.decode("ascii"),
"path": unquote(self.path.decode("ascii")), "path": unquote(self.path.decode("ascii")),
"root_path": self.root_path, "root_path": self.root_path,
"scheme": "https" if self.isSecure() else "http", "scheme": self.client_scheme,
"query_string": self.query_string, "query_string": self.query_string,
"headers": self.clean_headers, "headers": self.clean_headers,
"client": self.client_addr, "client": self.client_addr,

View File

@ -49,6 +49,7 @@ class Server(object):
root_path="", root_path="",
proxy_forwarded_address_header=None, proxy_forwarded_address_header=None,
proxy_forwarded_port_header=None, proxy_forwarded_port_header=None,
proxy_forwarded_proto_header=None,
verbosity=1, verbosity=1,
websocket_handshake_timeout=5, websocket_handshake_timeout=5,
application_close_timeout=10, application_close_timeout=10,
@ -67,6 +68,7 @@ class Server(object):
self.ping_timeout = ping_timeout self.ping_timeout = ping_timeout
self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header self.proxy_forwarded_port_header = proxy_forwarded_port_header
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
self.websocket_timeout = websocket_timeout self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout self.websocket_handshake_timeout = websocket_handshake_timeout

View File

@ -37,6 +37,7 @@ class DaphneTestingInstance:
if self.xff: if self.xff:
kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For" kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For"
kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port" kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port"
kwargs["proxy_forwarded_proto_header"] = "X-Forwarded-Proto"
if self.http_timeout: if self.http_timeout:
kwargs["http_timeout"] = self.http_timeout kwargs["http_timeout"] = self.http_timeout
# Start up process # Start up process

View File

@ -25,18 +25,22 @@ def header_value(headers, header_name):
def parse_x_forwarded_for(headers, def parse_x_forwarded_for(headers,
address_header_name="X-Forwarded-For", address_header_name="X-Forwarded-For",
port_header_name="X-Forwarded-Port", port_header_name="X-Forwarded-Port",
original=None): proto_header_name="X-Forwarded-Proto",
original_addr=None,
original_scheme=None):
""" """
Parses an X-Forwarded-For header and returns a host/port pair as a list. Parses an X-Forwarded-For header and returns a host/port pair as a list.
@param headers: The twisted-style object containing a request's headers @param headers: The twisted-style object containing a request's headers
@param address_header_name: The name of the expected host header @param address_header_name: The name of the expected host header
@param port_header_name: The name of the expected port header @param port_header_name: The name of the expected port header
@param original: A host/port pair that should be returned if the headers are not in the request @param proto_header_name: The name of the expected proto header
@param original_addr: A host/port pair that should be returned if the headers are not in the request
@param original_scheme: A scheme that should be returned if the headers are not in the request
@return: A list containing a host (string) as the first entry and a port (int) as the second. @return: A list containing a host (string) as the first entry and a port (int) as the second.
""" """
if not address_header_name: if not address_header_name:
return original return original_addr, original_scheme
# Convert twisted-style headers into dicts # Convert twisted-style headers into dicts
if isinstance(headers, Headers): if isinstance(headers, Headers):
@ -49,14 +53,15 @@ def parse_x_forwarded_for(headers,
assert all(isinstance(name, bytes) for name in headers.keys()) assert all(isinstance(name, bytes) for name in headers.keys())
address_header_name = address_header_name.lower().encode("utf-8") address_header_name = address_header_name.lower().encode("utf-8")
result = original result_addr = original_addr
result_scheme = original_scheme
if address_header_name in headers: if address_header_name in headers:
address_value = header_value(headers, address_header_name) address_value = header_value(headers, address_header_name)
if "," in address_value: if "," in address_value:
address_value = address_value.split(",")[0].strip() address_value = address_value.split(",")[0].strip()
result = [address_value, 0] result_addr = [address_value, 0]
if port_header_name: if port_header_name:
# We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For
@ -65,8 +70,13 @@ def parse_x_forwarded_for(headers,
if port_header_name in headers: if port_header_name in headers:
port_value = header_value(headers, port_header_name) port_value = header_value(headers, port_header_name)
try: try:
result[1] = int(port_value) result_addr[1] = int(port_value)
except ValueError: except ValueError:
pass pass
return result if proto_header_name:
proto_header_name = proto_header_name.lower().encode("utf-8")
if proto_header_name in headers:
result_scheme = header_value(headers, proto_header_name)
return result_addr, result_scheme

View File

@ -49,10 +49,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server_addr = None self.server_addr = None
if self.server.proxy_forwarded_address_header: if self.server.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for( self.client_addr, self.client_scheme = parse_x_forwarded_for(
dict(self.clean_headers), dict(self.clean_headers),
self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr self.client_addr
) )
# Decode websocket subprotocol options # Decode websocket subprotocol options

View File

@ -15,11 +15,13 @@ class TestXForwardedForHttpParsing(TestCase):
def test_basic(self): def test_basic(self):
headers = Headers({ headers = Headers({
b"X-Forwarded-For": [b"10.1.2.3"], b"X-Forwarded-For": [b"10.1.2.3"],
b"X-Forwarded-Port": [b"1234"] b"X-Forwarded-Port": [b"1234"],
b"X-Forwarded-Proto": [b"https"]
}) })
result = parse_x_forwarded_for(headers) result = parse_x_forwarded_for(headers)
self.assertEqual(result, ["10.1.2.3", 1234]) self.assertEqual(result, (["10.1.2.3", 1234], "https"))
self.assertIsInstance(result[0], str) self.assertIsInstance(result[0][0], str)
self.assertIsInstance(result[1], str)
def test_address_only(self): def test_address_only(self):
headers = Headers({ headers = Headers({
@ -27,7 +29,7 @@ class TestXForwardedForHttpParsing(TestCase):
}) })
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
["10.1.2.3", 0] (["10.1.2.3", 0], None)
) )
def test_v6_address(self): def test_v6_address(self):
@ -36,7 +38,7 @@ class TestXForwardedForHttpParsing(TestCase):
}) })
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
["1043::a321:0001", 0] (["1043::a321:0001", 0], None)
) )
def test_multiple_proxys(self): def test_multiple_proxys(self):
@ -45,19 +47,19 @@ class TestXForwardedForHttpParsing(TestCase):
}) })
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
["10.1.2.3", 0] (["10.1.2.3", 0], None)
) )
def test_original(self): def test_original(self):
headers = Headers({}) headers = Headers({})
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers, original=["127.0.0.1", 80]), parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
["127.0.0.1", 80] (["127.0.0.1", 80], None)
) )
def test_no_original(self): def test_no_original(self):
headers = Headers({}) headers = Headers({})
self.assertIsNone(parse_x_forwarded_for(headers)) self.assertEqual(parse_x_forwarded_for(headers), (None, None))
class TestXForwardedForWsParsing(TestCase): class TestXForwardedForWsParsing(TestCase):
@ -69,10 +71,11 @@ class TestXForwardedForWsParsing(TestCase):
headers = { headers = {
b"X-Forwarded-For": b"10.1.2.3", b"X-Forwarded-For": b"10.1.2.3",
b"X-Forwarded-Port": b"1234", b"X-Forwarded-Port": b"1234",
b"X-Forwarded-Proto": b"https",
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
["10.1.2.3", 1234] (["10.1.2.3", 1234], "https")
) )
def test_address_only(self): def test_address_only(self):
@ -81,7 +84,7 @@ class TestXForwardedForWsParsing(TestCase):
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
["10.1.2.3", 0] (["10.1.2.3", 0], None)
) )
def test_v6_address(self): def test_v6_address(self):
@ -90,7 +93,7 @@ class TestXForwardedForWsParsing(TestCase):
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
["1043::a321:0001", 0] (["1043::a321:0001", 0], None)
) )
def test_multiple_proxies(self): def test_multiple_proxies(self):
@ -99,16 +102,16 @@ class TestXForwardedForWsParsing(TestCase):
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
["10.1.2.3", 0] (["10.1.2.3", 0], None)
) )
def test_original(self): def test_original(self):
headers = {} headers = {}
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers, original=["127.0.0.1", 80]), parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
["127.0.0.1", 80] (["127.0.0.1", 80], None)
) )
def test_no_original(self): def test_no_original(self):
headers = {} headers = {}
self.assertIsNone(parse_x_forwarded_for(headers)) self.assertEqual(parse_x_forwarded_for(headers), (None, None))