diff --git a/daphne/cli.py b/daphne/cli.py index d9d9fb4..7f42084 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -214,5 +214,6 @@ class CommandLineInterface(object): verbosity=args.verbosity, proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None, proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None, + proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None, ) self.server.run() diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index c4e17fe..915e475 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -73,13 +73,18 @@ class WebRequest(http.Request): else: self.client_addr = None self.server_addr = None + + self.client_scheme = "https" if self.isSecure() else "http" + # See if we need to get the address from a proxy header instead if self.server.proxy_forwarded_address_header: - self.client_addr = parse_x_forwarded_for( + self.client_addr, self.client_scheme = parse_x_forwarded_for( self.requestHeaders, self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_port_header, - self.client_addr + self.server.proxy_forwarded_proto_header, + self.client_addr, + self.client_scheme ) # Check for unicodeish path (or it'll crash when trying to parse) try: @@ -153,7 +158,7 @@ class WebRequest(http.Request): "method": self.method.decode("ascii"), "path": unquote(self.path.decode("ascii")), "root_path": self.root_path, - "scheme": "https" if self.isSecure() else "http", + "scheme": self.client_scheme, "query_string": self.query_string, "headers": self.clean_headers, "client": self.client_addr, diff --git a/daphne/server.py b/daphne/server.py index cff80f2..dfae544 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -49,6 +49,7 @@ class Server(object): root_path="", proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, + proxy_forwarded_proto_header=None, verbosity=1, websocket_handshake_timeout=5, application_close_timeout=10, @@ -67,6 +68,7 @@ class Server(object): self.ping_timeout = ping_timeout self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_port_header = proxy_forwarded_port_header + self.proxy_forwarded_proto_header = proxy_forwarded_proto_header self.websocket_timeout = websocket_timeout self.websocket_connect_timeout = websocket_connect_timeout self.websocket_handshake_timeout = websocket_handshake_timeout diff --git a/daphne/testing.py b/daphne/testing.py index 4c87c51..e606952 100644 --- a/daphne/testing.py +++ b/daphne/testing.py @@ -37,6 +37,7 @@ class DaphneTestingInstance: if self.xff: kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For" kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port" + kwargs["proxy_forwarded_proto_header"] = "X-Forwarded-Proto" if self.http_timeout: kwargs["http_timeout"] = self.http_timeout # Start up process diff --git a/daphne/utils.py b/daphne/utils.py index cd9e86e..ad64439 100644 --- a/daphne/utils.py +++ b/daphne/utils.py @@ -25,18 +25,22 @@ def header_value(headers, header_name): def parse_x_forwarded_for(headers, address_header_name="X-Forwarded-For", port_header_name="X-Forwarded-Port", - original=None): + proto_header_name="X-Forwarded-Proto", + original_addr=None, + original_scheme=None): """ Parses an X-Forwarded-For header and returns a host/port pair as a list. @param headers: The twisted-style object containing a request's headers @param address_header_name: The name of the expected host header @param port_header_name: The name of the expected port header - @param original: A host/port pair that should be returned if the headers are not in the request + @param proto_header_name: The name of the expected proto header + @param original_addr: A host/port pair that should be returned if the headers are not in the request + @param original_scheme: A scheme that should be returned if the headers are not in the request @return: A list containing a host (string) as the first entry and a port (int) as the second. """ if not address_header_name: - return original + return original_addr, original_scheme # Convert twisted-style headers into dicts if isinstance(headers, Headers): @@ -49,14 +53,15 @@ def parse_x_forwarded_for(headers, assert all(isinstance(name, bytes) for name in headers.keys()) address_header_name = address_header_name.lower().encode("utf-8") - result = original + result_addr = original_addr + result_scheme = original_scheme if address_header_name in headers: address_value = header_value(headers, address_header_name) if "," in address_value: address_value = address_value.split(",")[0].strip() - result = [address_value, 0] + result_addr = [address_value, 0] if port_header_name: # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For @@ -65,8 +70,13 @@ def parse_x_forwarded_for(headers, if port_header_name in headers: port_value = header_value(headers, port_header_name) try: - result[1] = int(port_value) + result_addr[1] = int(port_value) except ValueError: pass - return result + if proto_header_name: + proto_header_name = proto_header_name.lower().encode("utf-8") + if proto_header_name in headers: + result_scheme = header_value(headers, proto_header_name) + + return result_addr, result_scheme diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 8ae8749..f0b7bda 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -49,10 +49,11 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server_addr = None if self.server.proxy_forwarded_address_header: - self.client_addr = parse_x_forwarded_for( + self.client_addr, self.client_scheme = parse_x_forwarded_for( dict(self.clean_headers), self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_port_header, + self.server.proxy_forwarded_proto_header, self.client_addr ) # Decode websocket subprotocol options diff --git a/tests/test_utils.py b/tests/test_utils.py index 786b8c9..5dada0f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,11 +15,13 @@ class TestXForwardedForHttpParsing(TestCase): def test_basic(self): headers = Headers({ b"X-Forwarded-For": [b"10.1.2.3"], - b"X-Forwarded-Port": [b"1234"] + b"X-Forwarded-Port": [b"1234"], + b"X-Forwarded-Proto": [b"https"] }) result = parse_x_forwarded_for(headers) - self.assertEqual(result, ["10.1.2.3", 1234]) - self.assertIsInstance(result[0], str) + self.assertEqual(result, (["10.1.2.3", 1234], "https")) + self.assertIsInstance(result[0][0], str) + self.assertIsInstance(result[1], str) def test_address_only(self): headers = Headers({ @@ -27,7 +29,7 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_v6_address(self): @@ -36,7 +38,7 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ["1043::a321:0001", 0] + (["1043::a321:0001", 0], None) ) def test_multiple_proxys(self): @@ -45,19 +47,19 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_original(self): headers = Headers({}) self.assertEqual( - parse_x_forwarded_for(headers, original=["127.0.0.1", 80]), - ["127.0.0.1", 80] + parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), + (["127.0.0.1", 80], None) ) def test_no_original(self): headers = Headers({}) - self.assertIsNone(parse_x_forwarded_for(headers)) + self.assertEqual(parse_x_forwarded_for(headers), (None, None)) class TestXForwardedForWsParsing(TestCase): @@ -69,10 +71,11 @@ class TestXForwardedForWsParsing(TestCase): headers = { b"X-Forwarded-For": b"10.1.2.3", b"X-Forwarded-Port": b"1234", + b"X-Forwarded-Proto": b"https", } self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 1234] + (["10.1.2.3", 1234], "https") ) def test_address_only(self): @@ -81,7 +84,7 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_v6_address(self): @@ -90,7 +93,7 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ["1043::a321:0001", 0] + (["1043::a321:0001", 0], None) ) def test_multiple_proxies(self): @@ -99,16 +102,16 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_original(self): headers = {} self.assertEqual( - parse_x_forwarded_for(headers, original=["127.0.0.1", 80]), - ["127.0.0.1", 80] + parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), + (["127.0.0.1", 80], None) ) def test_no_original(self): headers = {} - self.assertIsNone(parse_x_forwarded_for(headers)) + self.assertEqual(parse_x_forwarded_for(headers), (None, None))