Parse X-Forwarded-Proto header (#136)

Adds the ability to use this header for HTTPS detection.
This commit is contained in:
Nick Sellen 2017-08-25 18:24:24 +01:00 committed by Andrew Godwin
parent 3161715238
commit 05bd4ac258
7 changed files with 75 additions and 29 deletions

View File

@ -213,6 +213,7 @@ class CommandLineInterface(object):
verbosity=args.verbosity, verbosity=args.verbosity,
proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None,
proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None,
proxy_forwarded_proto_header='X-Forwarded-Proto' if args.proxy_headers else None,
force_sync=args.force_sync, force_sync=args.force_sync,
) )
self.server.run() self.server.run()

View File

@ -80,12 +80,16 @@ class WebRequest(http.Request):
self.client_addr = None self.client_addr = None
self.server_addr = None self.server_addr = None
self.client_scheme = 'https' if self.isSecure() else 'http'
if self.factory.proxy_forwarded_address_header: if self.factory.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for( self.client_addr, self.client_scheme = parse_x_forwarded_for(
self.requestHeaders, self.requestHeaders,
self.factory.proxy_forwarded_address_header, self.factory.proxy_forwarded_address_header,
self.factory.proxy_forwarded_port_header, self.factory.proxy_forwarded_port_header,
self.client_addr self.factory.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme
) )
# Check for unicodeish path (or it'll crash when trying to parse) # Check for unicodeish path (or it'll crash when trying to parse)
@ -166,7 +170,7 @@ class WebRequest(http.Request):
"method": self.method.decode("ascii"), "method": self.method.decode("ascii"),
"path": self.unquote(self.path), "path": self.unquote(self.path),
"root_path": self.root_path, "root_path": self.root_path,
"scheme": "https" if self.isSecure() else "http", "scheme": self.client_scheme,
"query_string": self.query_string, "query_string": self.query_string,
"headers": self.clean_headers, "headers": self.clean_headers,
"body": self.content.read(), "body": self.content.read(),

View File

@ -36,6 +36,7 @@ class Server(object):
root_path="", root_path="",
proxy_forwarded_address_header=None, proxy_forwarded_address_header=None,
proxy_forwarded_port_header=None, proxy_forwarded_port_header=None,
proxy_forwarded_proto_header=None,
force_sync=False, force_sync=False,
verbosity=1, verbosity=1,
websocket_handshake_timeout=5 websocket_handshake_timeout=5
@ -66,6 +67,7 @@ class Server(object):
self.ping_timeout = ping_timeout self.ping_timeout = ping_timeout
self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header self.proxy_forwarded_port_header = proxy_forwarded_port_header
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
# If they did not provide a websocket timeout, default it to the # If they did not provide a websocket timeout, default it to the
# channel layer's group_expiry value if present, or one day if not. # channel layer's group_expiry value if present, or one day if not.
self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400)
@ -95,6 +97,7 @@ class Server(object):
root_path=self.root_path, root_path=self.root_path,
proxy_forwarded_address_header=self.proxy_forwarded_address_header, proxy_forwarded_address_header=self.proxy_forwarded_address_header,
proxy_forwarded_port_header=self.proxy_forwarded_port_header, proxy_forwarded_port_header=self.proxy_forwarded_port_header,
proxy_forwarded_proto_header=self.proxy_forwarded_proto_header,
websocket_handshake_timeout=self.websocket_handshake_timeout websocket_handshake_timeout=self.websocket_handshake_timeout
) )
if self.verbosity <= 1: if self.verbosity <= 1:

View File

@ -171,6 +171,7 @@ class TestProxyHandling(unittest.TestCase):
def test_x_forwarded_for_parsed(self): def test_x_forwarded_for_parsed(self):
self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' self.factory.proxy_forwarded_address_header = 'X-Forwarded-For'
self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port'
self.factory.proxy_forwarded_proto_header = 'X-Forwarded-Proto'
self.proto.dataReceived( self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" + b"Host: somewhere.com\r\n" +
@ -185,6 +186,7 @@ class TestProxyHandling(unittest.TestCase):
def test_x_forwarded_for_port_missing(self): def test_x_forwarded_for_port_missing(self):
self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' self.factory.proxy_forwarded_address_header = 'X-Forwarded-For'
self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port'
self.factory.proxy_forwarded_proto_header = 'X-Forwarded-Proto'
self.proto.dataReceived( self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" + b"Host: somewhere.com\r\n" +

View File

@ -16,11 +16,13 @@ class TestXForwardedForHttpParsing(TestCase):
def test_basic(self): def test_basic(self):
headers = Headers({ headers = Headers({
b'X-Forwarded-For': [b'10.1.2.3'], b'X-Forwarded-For': [b'10.1.2.3'],
b'X-Forwarded-Port': [b'1234'] b'X-Forwarded-Port': [b'1234'],
b'X-Forwarded-Proto': [b'https']
}) })
result = parse_x_forwarded_for(headers) result = parse_x_forwarded_for(headers)
self.assertEqual(result, ['10.1.2.3', 1234]) self.assertEqual(result, (['10.1.2.3', 1234], 'https'))
self.assertIsInstance(result[0], six.text_type) self.assertIsInstance(result[0][0], six.text_type)
self.assertIsInstance(result[1], six.text_type)
def test_address_only(self): def test_address_only(self):
headers = Headers({ headers = Headers({
@ -28,7 +30,7 @@ class TestXForwardedForHttpParsing(TestCase):
}) })
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
['10.1.2.3', 0] (['10.1.2.3', 0], None)
) )
def test_v6_address(self): def test_v6_address(self):
@ -37,7 +39,7 @@ class TestXForwardedForHttpParsing(TestCase):
}) })
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
['1043::a321:0001', 0] (['1043::a321:0001', 0], None)
) )
def test_multiple_proxys(self): def test_multiple_proxys(self):
@ -46,19 +48,39 @@ class TestXForwardedForHttpParsing(TestCase):
}) })
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
['10.1.2.3', 0] (['10.1.2.3', 0], None)
) )
def test_original(self): def test_original_addr(self):
headers = Headers({}) headers = Headers({})
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers, original=['127.0.0.1', 80]), parse_x_forwarded_for(headers, original_addr=['127.0.0.1', 80]),
['127.0.0.1', 80] (['127.0.0.1', 80], None)
)
def test_original_proto(self):
headers = Headers({})
self.assertEqual(
parse_x_forwarded_for(headers, original_scheme='http'),
(None, 'http')
) )
def test_no_original(self): def test_no_original(self):
headers = Headers({}) headers = Headers({})
self.assertIsNone(parse_x_forwarded_for(headers)) self.assertEqual(
parse_x_forwarded_for(headers),
(None, None)
)
def test_address_and_proto(self):
headers = Headers({
b'X-Forwarded-For': [b'10.1.2.3'],
b'X-Forwarded-Proto': [b'https'],
})
self.assertEqual(
parse_x_forwarded_for(headers),
(['10.1.2.3', 0], 'https')
)
class TestXForwardedForWsParsing(TestCase): class TestXForwardedForWsParsing(TestCase):
@ -73,7 +95,7 @@ class TestXForwardedForWsParsing(TestCase):
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
['10.1.2.3', 1234] (['10.1.2.3', 1234], None)
) )
def test_address_only(self): def test_address_only(self):
@ -82,7 +104,7 @@ class TestXForwardedForWsParsing(TestCase):
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
['10.1.2.3', 0] (['10.1.2.3', 0], None)
) )
def test_v6_address(self): def test_v6_address(self):
@ -91,7 +113,7 @@ class TestXForwardedForWsParsing(TestCase):
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
['1043::a321:0001', 0] (['1043::a321:0001', 0], None)
) )
def test_multiple_proxys(self): def test_multiple_proxys(self):
@ -100,16 +122,19 @@ class TestXForwardedForWsParsing(TestCase):
} }
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers), parse_x_forwarded_for(headers),
['10.1.2.3', 0] (['10.1.2.3', 0], None)
) )
def test_original(self): def test_original(self):
headers = {} headers = {}
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers, original=['127.0.0.1', 80]), parse_x_forwarded_for(headers, original_addr=['127.0.0.1', 80]),
['127.0.0.1', 80] (['127.0.0.1', 80], None)
) )
def test_no_original(self): def test_no_original(self):
headers = {} headers = {}
self.assertIsNone(parse_x_forwarded_for(headers)) self.assertEqual(
parse_x_forwarded_for(headers),
(None, None)
)

View File

@ -11,18 +11,22 @@ def header_value(headers, header_name):
def parse_x_forwarded_for(headers, def parse_x_forwarded_for(headers,
address_header_name='X-Forwarded-For', address_header_name='X-Forwarded-For',
port_header_name='X-Forwarded-Port', port_header_name='X-Forwarded-Port',
original=None): proto_header_name='X-Forwarded-Proto',
original_addr=None,
original_scheme=None):
""" """
Parses an X-Forwarded-For header and returns a host/port pair as a list. Parses an X-Forwarded-For header and returns a host/port pair as a list.
@param headers: The twisted-style object containing a request's headers @param headers: The twisted-style object containing a request's headers
@param address_header_name: The name of the expected host header @param address_header_name: The name of the expected host header
@param port_header_name: The name of the expected port header @param port_header_name: The name of the expected port header
@param original: A host/port pair that should be returned if the headers are not in the request @param proto_header_name: The name of the expected protocol header
@return: A list containing a host (string) as the first entry and a port (int) as the second. @param original_addr: A host/port pair that should be returned if the headers are not in the request
@param original_scheme: A scheme that should be returned if the headers are not in the request
@return: A tuple containing a list [host (string), port (int)] as the first entry and a proto (string) as the second
""" """
if not address_header_name: if not address_header_name:
return original return (original_addr, original_scheme)
# Convert twisted-style headers into dicts # Convert twisted-style headers into dicts
if isinstance(headers, Headers): if isinstance(headers, Headers):
@ -32,14 +36,15 @@ def parse_x_forwarded_for(headers,
headers = {name.lower(): values for name, values in headers.items()} headers = {name.lower(): values for name, values in headers.items()}
address_header_name = address_header_name.lower().encode("utf-8") address_header_name = address_header_name.lower().encode("utf-8")
result = original result_addr = original_addr
result_scheme = original_scheme
if address_header_name in headers: if address_header_name in headers:
address_value = header_value(headers, address_header_name) address_value = header_value(headers, address_header_name)
if ',' in address_value: if ',' in address_value:
address_value = address_value.split(",")[0].strip() address_value = address_value.split(",")[0].strip()
result = [address_value, 0] result_addr = [address_value, 0]
if port_header_name: if port_header_name:
# We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For
@ -48,8 +53,13 @@ def parse_x_forwarded_for(headers,
if port_header_name in headers: if port_header_name in headers:
port_value = header_value(headers, port_header_name) port_value = header_value(headers, port_header_name)
try: try:
result[1] = int(port_value) result_addr[1] = int(port_value)
except ValueError: except ValueError:
pass pass
return result if proto_header_name:
proto_header_name = proto_header_name.lower().encode("utf-8")
if proto_header_name in headers:
result_scheme = header_value(headers, proto_header_name)
return result_addr, result_scheme

View File

@ -57,10 +57,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server_addr = None self.server_addr = None
if self.main_factory.proxy_forwarded_address_header: if self.main_factory.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for( self.client_addr, = parse_x_forwarded_for(
self.http_headers, self.http_headers,
self.main_factory.proxy_forwarded_address_header, self.main_factory.proxy_forwarded_address_header,
self.main_factory.proxy_forwarded_port_header, self.main_factory.proxy_forwarded_port_header,
self.main_factory.proxy_forwarded_proto_header,
self.client_addr self.client_addr
) )