mirror of
				https://github.com/django/daphne.git
				synced 2025-11-04 09:37:32 +03:00 
			
		
		
		
	Add x-forwarded-proto support (#219)
This commit is contained in:
		
							parent
							
								
									adb622d4f5
								
							
						
					
					
						commit
						2f94210321
					
				| 
						 | 
					@ -214,5 +214,6 @@ class CommandLineInterface(object):
 | 
				
			||||||
            verbosity=args.verbosity,
 | 
					            verbosity=args.verbosity,
 | 
				
			||||||
            proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
 | 
					            proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
 | 
				
			||||||
            proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
 | 
					            proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
 | 
				
			||||||
 | 
					            proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.server.run()
 | 
					        self.server.run()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -73,13 +73,18 @@ class WebRequest(http.Request):
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                self.client_addr = None
 | 
					                self.client_addr = None
 | 
				
			||||||
                self.server_addr = None
 | 
					                self.server_addr = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.client_scheme = "https" if self.isSecure() else "http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # See if we need to get the address from a proxy header instead
 | 
					            # See if we need to get the address from a proxy header instead
 | 
				
			||||||
            if self.server.proxy_forwarded_address_header:
 | 
					            if self.server.proxy_forwarded_address_header:
 | 
				
			||||||
                self.client_addr = parse_x_forwarded_for(
 | 
					                self.client_addr, self.client_scheme = parse_x_forwarded_for(
 | 
				
			||||||
                    self.requestHeaders,
 | 
					                    self.requestHeaders,
 | 
				
			||||||
                    self.server.proxy_forwarded_address_header,
 | 
					                    self.server.proxy_forwarded_address_header,
 | 
				
			||||||
                    self.server.proxy_forwarded_port_header,
 | 
					                    self.server.proxy_forwarded_port_header,
 | 
				
			||||||
                    self.client_addr
 | 
					                    self.server.proxy_forwarded_proto_header,
 | 
				
			||||||
 | 
					                    self.client_addr,
 | 
				
			||||||
 | 
					                    self.client_scheme
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            # Check for unicodeish path (or it'll crash when trying to parse)
 | 
					            # Check for unicodeish path (or it'll crash when trying to parse)
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
| 
						 | 
					@ -153,7 +158,7 @@ class WebRequest(http.Request):
 | 
				
			||||||
                    "method": self.method.decode("ascii"),
 | 
					                    "method": self.method.decode("ascii"),
 | 
				
			||||||
                    "path": unquote(self.path.decode("ascii")),
 | 
					                    "path": unquote(self.path.decode("ascii")),
 | 
				
			||||||
                    "root_path": self.root_path,
 | 
					                    "root_path": self.root_path,
 | 
				
			||||||
                    "scheme": "https" if self.isSecure() else "http",
 | 
					                    "scheme": self.client_scheme,
 | 
				
			||||||
                    "query_string": self.query_string,
 | 
					                    "query_string": self.query_string,
 | 
				
			||||||
                    "headers": self.clean_headers,
 | 
					                    "headers": self.clean_headers,
 | 
				
			||||||
                    "client": self.client_addr,
 | 
					                    "client": self.client_addr,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -49,6 +49,7 @@ class Server(object):
 | 
				
			||||||
        root_path="",
 | 
					        root_path="",
 | 
				
			||||||
        proxy_forwarded_address_header=None,
 | 
					        proxy_forwarded_address_header=None,
 | 
				
			||||||
        proxy_forwarded_port_header=None,
 | 
					        proxy_forwarded_port_header=None,
 | 
				
			||||||
 | 
					        proxy_forwarded_proto_header=None,
 | 
				
			||||||
        verbosity=1,
 | 
					        verbosity=1,
 | 
				
			||||||
        websocket_handshake_timeout=5,
 | 
					        websocket_handshake_timeout=5,
 | 
				
			||||||
        application_close_timeout=10,
 | 
					        application_close_timeout=10,
 | 
				
			||||||
| 
						 | 
					@ -67,6 +68,7 @@ class Server(object):
 | 
				
			||||||
        self.ping_timeout = ping_timeout
 | 
					        self.ping_timeout = ping_timeout
 | 
				
			||||||
        self.proxy_forwarded_address_header = proxy_forwarded_address_header
 | 
					        self.proxy_forwarded_address_header = proxy_forwarded_address_header
 | 
				
			||||||
        self.proxy_forwarded_port_header = proxy_forwarded_port_header
 | 
					        self.proxy_forwarded_port_header = proxy_forwarded_port_header
 | 
				
			||||||
 | 
					        self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
 | 
				
			||||||
        self.websocket_timeout = websocket_timeout
 | 
					        self.websocket_timeout = websocket_timeout
 | 
				
			||||||
        self.websocket_connect_timeout = websocket_connect_timeout
 | 
					        self.websocket_connect_timeout = websocket_connect_timeout
 | 
				
			||||||
        self.websocket_handshake_timeout = websocket_handshake_timeout
 | 
					        self.websocket_handshake_timeout = websocket_handshake_timeout
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,6 +37,7 @@ class DaphneTestingInstance:
 | 
				
			||||||
        if self.xff:
 | 
					        if self.xff:
 | 
				
			||||||
            kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For"
 | 
					            kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For"
 | 
				
			||||||
            kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port"
 | 
					            kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port"
 | 
				
			||||||
 | 
					            kwargs["proxy_forwarded_proto_header"] = "X-Forwarded-Proto"
 | 
				
			||||||
        if self.http_timeout:
 | 
					        if self.http_timeout:
 | 
				
			||||||
            kwargs["http_timeout"] = self.http_timeout
 | 
					            kwargs["http_timeout"] = self.http_timeout
 | 
				
			||||||
        # Start up process
 | 
					        # Start up process
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,18 +25,22 @@ def header_value(headers, header_name):
 | 
				
			||||||
def parse_x_forwarded_for(headers,
 | 
					def parse_x_forwarded_for(headers,
 | 
				
			||||||
                          address_header_name="X-Forwarded-For",
 | 
					                          address_header_name="X-Forwarded-For",
 | 
				
			||||||
                          port_header_name="X-Forwarded-Port",
 | 
					                          port_header_name="X-Forwarded-Port",
 | 
				
			||||||
                          original=None):
 | 
					                          proto_header_name="X-Forwarded-Proto",
 | 
				
			||||||
 | 
					                          original_addr=None,
 | 
				
			||||||
 | 
					                          original_scheme=None):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Parses an X-Forwarded-For header and returns a host/port pair as a list.
 | 
					    Parses an X-Forwarded-For header and returns a host/port pair as a list.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @param headers: The twisted-style object containing a request's headers
 | 
					    @param headers: The twisted-style object containing a request's headers
 | 
				
			||||||
    @param address_header_name: The name of the expected host header
 | 
					    @param address_header_name: The name of the expected host header
 | 
				
			||||||
    @param port_header_name: The name of the expected port header
 | 
					    @param port_header_name: The name of the expected port header
 | 
				
			||||||
    @param original: A host/port pair that should be returned if the headers are not in the request
 | 
					    @param proto_header_name: The name of the expected proto header
 | 
				
			||||||
 | 
					    @param original_addr: A host/port pair that should be returned if the headers are not in the request
 | 
				
			||||||
 | 
					    @param original_scheme: A scheme that should be returned if the headers are not in the request
 | 
				
			||||||
    @return: A list containing a host (string) as the first entry and a port (int) as the second.
 | 
					    @return: A list containing a host (string) as the first entry and a port (int) as the second.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if not address_header_name:
 | 
					    if not address_header_name:
 | 
				
			||||||
        return original
 | 
					        return original_addr, original_scheme
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Convert twisted-style headers into dicts
 | 
					    # Convert twisted-style headers into dicts
 | 
				
			||||||
    if isinstance(headers, Headers):
 | 
					    if isinstance(headers, Headers):
 | 
				
			||||||
| 
						 | 
					@ -49,14 +53,15 @@ def parse_x_forwarded_for(headers,
 | 
				
			||||||
    assert all(isinstance(name, bytes) for name in headers.keys())
 | 
					    assert all(isinstance(name, bytes) for name in headers.keys())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    address_header_name = address_header_name.lower().encode("utf-8")
 | 
					    address_header_name = address_header_name.lower().encode("utf-8")
 | 
				
			||||||
    result = original
 | 
					    result_addr = original_addr
 | 
				
			||||||
 | 
					    result_scheme = original_scheme
 | 
				
			||||||
    if address_header_name in headers:
 | 
					    if address_header_name in headers:
 | 
				
			||||||
        address_value = header_value(headers, address_header_name)
 | 
					        address_value = header_value(headers, address_header_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if "," in address_value:
 | 
					        if "," in address_value:
 | 
				
			||||||
            address_value = address_value.split(",")[0].strip()
 | 
					            address_value = address_value.split(",")[0].strip()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        result = [address_value, 0]
 | 
					        result_addr = [address_value, 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if port_header_name:
 | 
					        if port_header_name:
 | 
				
			||||||
            # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For
 | 
					            # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For
 | 
				
			||||||
| 
						 | 
					@ -65,8 +70,13 @@ def parse_x_forwarded_for(headers,
 | 
				
			||||||
            if port_header_name in headers:
 | 
					            if port_header_name in headers:
 | 
				
			||||||
                port_value = header_value(headers, port_header_name)
 | 
					                port_value = header_value(headers, port_header_name)
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    result[1] = int(port_value)
 | 
					                    result_addr[1] = int(port_value)
 | 
				
			||||||
                except ValueError:
 | 
					                except ValueError:
 | 
				
			||||||
                    pass
 | 
					                    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return result
 | 
					        if proto_header_name:
 | 
				
			||||||
 | 
					            proto_header_name = proto_header_name.lower().encode("utf-8")
 | 
				
			||||||
 | 
					            if proto_header_name in headers:
 | 
				
			||||||
 | 
					                result_scheme = header_value(headers, proto_header_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return result_addr, result_scheme
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -49,10 +49,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
 | 
				
			||||||
                self.server_addr = None
 | 
					                self.server_addr = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if self.server.proxy_forwarded_address_header:
 | 
					            if self.server.proxy_forwarded_address_header:
 | 
				
			||||||
                self.client_addr = parse_x_forwarded_for(
 | 
					                self.client_addr, self.client_scheme = parse_x_forwarded_for(
 | 
				
			||||||
                    dict(self.clean_headers),
 | 
					                    dict(self.clean_headers),
 | 
				
			||||||
                    self.server.proxy_forwarded_address_header,
 | 
					                    self.server.proxy_forwarded_address_header,
 | 
				
			||||||
                    self.server.proxy_forwarded_port_header,
 | 
					                    self.server.proxy_forwarded_port_header,
 | 
				
			||||||
 | 
					                    self.server.proxy_forwarded_proto_header,
 | 
				
			||||||
                    self.client_addr
 | 
					                    self.client_addr
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            # Decode websocket subprotocol options
 | 
					            # Decode websocket subprotocol options
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,11 +15,13 @@ class TestXForwardedForHttpParsing(TestCase):
 | 
				
			||||||
    def test_basic(self):
 | 
					    def test_basic(self):
 | 
				
			||||||
        headers = Headers({
 | 
					        headers = Headers({
 | 
				
			||||||
            b"X-Forwarded-For": [b"10.1.2.3"],
 | 
					            b"X-Forwarded-For": [b"10.1.2.3"],
 | 
				
			||||||
            b"X-Forwarded-Port": [b"1234"]
 | 
					            b"X-Forwarded-Port": [b"1234"],
 | 
				
			||||||
 | 
					            b"X-Forwarded-Proto": [b"https"]
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
        result = parse_x_forwarded_for(headers)
 | 
					        result = parse_x_forwarded_for(headers)
 | 
				
			||||||
        self.assertEqual(result, ["10.1.2.3", 1234])
 | 
					        self.assertEqual(result, (["10.1.2.3", 1234], "https"))
 | 
				
			||||||
        self.assertIsInstance(result[0], str)
 | 
					        self.assertIsInstance(result[0][0], str)
 | 
				
			||||||
 | 
					        self.assertIsInstance(result[1], str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_address_only(self):
 | 
					    def test_address_only(self):
 | 
				
			||||||
        headers = Headers({
 | 
					        headers = Headers({
 | 
				
			||||||
| 
						 | 
					@ -27,7 +29,7 @@ class TestXForwardedForHttpParsing(TestCase):
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers),
 | 
					            parse_x_forwarded_for(headers),
 | 
				
			||||||
            ["10.1.2.3", 0]
 | 
					            (["10.1.2.3", 0], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_v6_address(self):
 | 
					    def test_v6_address(self):
 | 
				
			||||||
| 
						 | 
					@ -36,7 +38,7 @@ class TestXForwardedForHttpParsing(TestCase):
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers),
 | 
					            parse_x_forwarded_for(headers),
 | 
				
			||||||
            ["1043::a321:0001", 0]
 | 
					            (["1043::a321:0001", 0], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_multiple_proxys(self):
 | 
					    def test_multiple_proxys(self):
 | 
				
			||||||
| 
						 | 
					@ -45,19 +47,19 @@ class TestXForwardedForHttpParsing(TestCase):
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers),
 | 
					            parse_x_forwarded_for(headers),
 | 
				
			||||||
            ["10.1.2.3", 0]
 | 
					            (["10.1.2.3", 0], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_original(self):
 | 
					    def test_original(self):
 | 
				
			||||||
        headers = Headers({})
 | 
					        headers = Headers({})
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers, original=["127.0.0.1", 80]),
 | 
					            parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
 | 
				
			||||||
            ["127.0.0.1", 80]
 | 
					            (["127.0.0.1", 80], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_no_original(self):
 | 
					    def test_no_original(self):
 | 
				
			||||||
        headers = Headers({})
 | 
					        headers = Headers({})
 | 
				
			||||||
        self.assertIsNone(parse_x_forwarded_for(headers))
 | 
					        self.assertEqual(parse_x_forwarded_for(headers), (None, None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestXForwardedForWsParsing(TestCase):
 | 
					class TestXForwardedForWsParsing(TestCase):
 | 
				
			||||||
| 
						 | 
					@ -69,10 +71,11 @@ class TestXForwardedForWsParsing(TestCase):
 | 
				
			||||||
        headers = {
 | 
					        headers = {
 | 
				
			||||||
            b"X-Forwarded-For": b"10.1.2.3",
 | 
					            b"X-Forwarded-For": b"10.1.2.3",
 | 
				
			||||||
            b"X-Forwarded-Port": b"1234",
 | 
					            b"X-Forwarded-Port": b"1234",
 | 
				
			||||||
 | 
					            b"X-Forwarded-Proto": b"https",
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers),
 | 
					            parse_x_forwarded_for(headers),
 | 
				
			||||||
            ["10.1.2.3", 1234]
 | 
					            (["10.1.2.3", 1234], "https")
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_address_only(self):
 | 
					    def test_address_only(self):
 | 
				
			||||||
| 
						 | 
					@ -81,7 +84,7 @@ class TestXForwardedForWsParsing(TestCase):
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers),
 | 
					            parse_x_forwarded_for(headers),
 | 
				
			||||||
            ["10.1.2.3", 0]
 | 
					            (["10.1.2.3", 0], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_v6_address(self):
 | 
					    def test_v6_address(self):
 | 
				
			||||||
| 
						 | 
					@ -90,7 +93,7 @@ class TestXForwardedForWsParsing(TestCase):
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers),
 | 
					            parse_x_forwarded_for(headers),
 | 
				
			||||||
            ["1043::a321:0001", 0]
 | 
					            (["1043::a321:0001", 0], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_multiple_proxies(self):
 | 
					    def test_multiple_proxies(self):
 | 
				
			||||||
| 
						 | 
					@ -99,16 +102,16 @@ class TestXForwardedForWsParsing(TestCase):
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers),
 | 
					            parse_x_forwarded_for(headers),
 | 
				
			||||||
            ["10.1.2.3", 0]
 | 
					            (["10.1.2.3", 0], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_original(self):
 | 
					    def test_original(self):
 | 
				
			||||||
        headers = {}
 | 
					        headers = {}
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            parse_x_forwarded_for(headers, original=["127.0.0.1", 80]),
 | 
					            parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
 | 
				
			||||||
            ["127.0.0.1", 80]
 | 
					            (["127.0.0.1", 80], None)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_no_original(self):
 | 
					    def test_no_original(self):
 | 
				
			||||||
        headers = {}
 | 
					        headers = {}
 | 
				
			||||||
        self.assertIsNone(parse_x_forwarded_for(headers))
 | 
					        self.assertEqual(parse_x_forwarded_for(headers), (None, None))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user