diff --git a/daphne/cli.py b/daphne/cli.py index b27cb06..a149632 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -213,6 +213,7 @@ class CommandLineInterface(object): verbosity=args.verbosity, proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, + proxy_forwarded_proto_header='X-Forwarded-Proto' if args.proxy_headers else None, force_sync=args.force_sync, ) self.server.run() diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 32dc276..df9f93a 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -80,12 +80,16 @@ class WebRequest(http.Request): self.client_addr = None self.server_addr = None + self.client_scheme = 'https' if self.isSecure() else 'http' + if self.factory.proxy_forwarded_address_header: - self.client_addr = parse_x_forwarded_for( + self.client_addr, self.client_scheme = parse_x_forwarded_for( self.requestHeaders, self.factory.proxy_forwarded_address_header, self.factory.proxy_forwarded_port_header, - self.client_addr + self.factory.proxy_forwarded_proto_header, + self.client_addr, + self.client_scheme ) # Check for unicodeish path (or it'll crash when trying to parse) @@ -166,7 +170,7 @@ class WebRequest(http.Request): "method": self.method.decode("ascii"), "path": self.unquote(self.path), "root_path": self.root_path, - "scheme": "https" if self.isSecure() else "http", + "scheme": self.client_scheme, "query_string": self.query_string, "headers": self.clean_headers, "body": self.content.read(), diff --git a/daphne/server.py b/daphne/server.py index 0ae9eda..fc9f761 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -36,6 +36,7 @@ class Server(object): root_path="", proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, + proxy_forwarded_proto_header=None, force_sync=False, verbosity=1, websocket_handshake_timeout=5 @@ -66,6 +67,7 @@ class Server(object): self.ping_timeout = ping_timeout self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_port_header = proxy_forwarded_port_header + self.proxy_forwarded_proto_header = proxy_forwarded_proto_header # If they did not provide a websocket timeout, default it to the # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) @@ -95,6 +97,7 @@ class Server(object): root_path=self.root_path, proxy_forwarded_address_header=self.proxy_forwarded_address_header, proxy_forwarded_port_header=self.proxy_forwarded_port_header, + proxy_forwarded_proto_header=self.proxy_forwarded_proto_header, websocket_handshake_timeout=self.websocket_handshake_timeout ) if self.verbosity <= 1: diff --git a/daphne/tests/test_http_request.py b/daphne/tests/test_http_request.py index ef7922d..dd64ae3 100644 --- a/daphne/tests/test_http_request.py +++ b/daphne/tests/test_http_request.py @@ -171,6 +171,7 @@ class TestProxyHandling(unittest.TestCase): def test_x_forwarded_for_parsed(self): self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' + self.factory.proxy_forwarded_proto_header = 'X-Forwarded-Proto' self.proto.dataReceived( b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + b"Host: somewhere.com\r\n" + @@ -185,6 +186,7 @@ class TestProxyHandling(unittest.TestCase): def test_x_forwarded_for_port_missing(self): self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' + self.factory.proxy_forwarded_proto_header = 'X-Forwarded-Proto' self.proto.dataReceived( b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + b"Host: somewhere.com\r\n" + diff --git a/daphne/tests/test_utils.py b/daphne/tests/test_utils.py index 525fec9..9e962a6 100644 --- a/daphne/tests/test_utils.py +++ b/daphne/tests/test_utils.py @@ -16,11 +16,13 @@ class TestXForwardedForHttpParsing(TestCase): def test_basic(self): headers = Headers({ b'X-Forwarded-For': [b'10.1.2.3'], - b'X-Forwarded-Port': [b'1234'] + b'X-Forwarded-Port': [b'1234'], + b'X-Forwarded-Proto': [b'https'] }) result = parse_x_forwarded_for(headers) - self.assertEqual(result, ['10.1.2.3', 1234]) - self.assertIsInstance(result[0], six.text_type) + self.assertEqual(result, (['10.1.2.3', 1234], 'https')) + self.assertIsInstance(result[0][0], six.text_type) + self.assertIsInstance(result[1], six.text_type) def test_address_only(self): headers = Headers({ @@ -28,7 +30,7 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ['10.1.2.3', 0] + (['10.1.2.3', 0], None) ) def test_v6_address(self): @@ -37,7 +39,7 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ['1043::a321:0001', 0] + (['1043::a321:0001', 0], None) ) def test_multiple_proxys(self): @@ -46,19 +48,39 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ['10.1.2.3', 0] + (['10.1.2.3', 0], None) ) - def test_original(self): + def test_original_addr(self): headers = Headers({}) self.assertEqual( - parse_x_forwarded_for(headers, original=['127.0.0.1', 80]), - ['127.0.0.1', 80] + parse_x_forwarded_for(headers, original_addr=['127.0.0.1', 80]), + (['127.0.0.1', 80], None) + ) + + def test_original_proto(self): + headers = Headers({}) + self.assertEqual( + parse_x_forwarded_for(headers, original_scheme='http'), + (None, 'http') ) def test_no_original(self): headers = Headers({}) - self.assertIsNone(parse_x_forwarded_for(headers)) + self.assertEqual( + parse_x_forwarded_for(headers), + (None, None) + ) + + def test_address_and_proto(self): + headers = Headers({ + b'X-Forwarded-For': [b'10.1.2.3'], + b'X-Forwarded-Proto': [b'https'], + }) + self.assertEqual( + parse_x_forwarded_for(headers), + (['10.1.2.3', 0], 'https') + ) class TestXForwardedForWsParsing(TestCase): @@ -73,7 +95,7 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ['10.1.2.3', 1234] + (['10.1.2.3', 1234], None) ) def test_address_only(self): @@ -82,7 +104,7 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ['10.1.2.3', 0] + (['10.1.2.3', 0], None) ) def test_v6_address(self): @@ -91,7 +113,7 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ['1043::a321:0001', 0] + (['1043::a321:0001', 0], None) ) def test_multiple_proxys(self): @@ -100,16 +122,19 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ['10.1.2.3', 0] + (['10.1.2.3', 0], None) ) def test_original(self): headers = {} self.assertEqual( - parse_x_forwarded_for(headers, original=['127.0.0.1', 80]), - ['127.0.0.1', 80] + parse_x_forwarded_for(headers, original_addr=['127.0.0.1', 80]), + (['127.0.0.1', 80], None) ) def test_no_original(self): headers = {} - self.assertIsNone(parse_x_forwarded_for(headers)) + self.assertEqual( + parse_x_forwarded_for(headers), + (None, None) + ) diff --git a/daphne/utils.py b/daphne/utils.py index cb8043c..95ed331 100644 --- a/daphne/utils.py +++ b/daphne/utils.py @@ -11,18 +11,22 @@ def header_value(headers, header_name): def parse_x_forwarded_for(headers, address_header_name='X-Forwarded-For', port_header_name='X-Forwarded-Port', - original=None): + proto_header_name='X-Forwarded-Proto', + original_addr=None, + original_scheme=None): """ Parses an X-Forwarded-For header and returns a host/port pair as a list. @param headers: The twisted-style object containing a request's headers @param address_header_name: The name of the expected host header @param port_header_name: The name of the expected port header - @param original: A host/port pair that should be returned if the headers are not in the request - @return: A list containing a host (string) as the first entry and a port (int) as the second. + @param proto_header_name: The name of the expected protocol header + @param original_addr: A host/port pair that should be returned if the headers are not in the request + @param original_scheme: A scheme that should be returned if the headers are not in the request + @return: A tuple containing a list [host (string), port (int)] as the first entry and a proto (string) as the second """ if not address_header_name: - return original + return (original_addr, original_scheme) # Convert twisted-style headers into dicts if isinstance(headers, Headers): @@ -32,14 +36,15 @@ def parse_x_forwarded_for(headers, headers = {name.lower(): values for name, values in headers.items()} address_header_name = address_header_name.lower().encode("utf-8") - result = original + result_addr = original_addr + result_scheme = original_scheme if address_header_name in headers: address_value = header_value(headers, address_header_name) if ',' in address_value: address_value = address_value.split(",")[0].strip() - result = [address_value, 0] + result_addr = [address_value, 0] if port_header_name: # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For @@ -48,8 +53,13 @@ def parse_x_forwarded_for(headers, if port_header_name in headers: port_value = header_value(headers, port_header_name) try: - result[1] = int(port_value) + result_addr[1] = int(port_value) except ValueError: pass - return result + if proto_header_name: + proto_header_name = proto_header_name.lower().encode("utf-8") + if proto_header_name in headers: + result_scheme = header_value(headers, proto_header_name) + + return result_addr, result_scheme diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index a432285..08f818b 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -57,10 +57,11 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server_addr = None if self.main_factory.proxy_forwarded_address_header: - self.client_addr = parse_x_forwarded_for( + self.client_addr, = parse_x_forwarded_for( self.http_headers, self.main_factory.proxy_forwarded_address_header, self.main_factory.proxy_forwarded_port_header, + self.main_factory.proxy_forwarded_proto_header, self.client_addr )