diff --git a/daphne/cli.py b/daphne/cli.py index 6f07b19..873946d 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -96,6 +96,14 @@ class CommandLineInterface(object): help='The setting for the ASGI root_path variable', default="", ) + self.parser.add_argument( + '--proxy-headers', + dest='proxy_headers', + help='Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the ' + 'client address', + default=False, + action='store_true', + ) @classmethod def entrypoint(cls): @@ -154,4 +162,6 @@ class CommandLineInterface(object): ws_protocols=args.ws_protocols, root_path=args.root_path, verbosity=args.verbosity, + proxy_forwarded_address_header='X-Forwarded-For' if args.proxy_headers else None, + proxy_forwarded_port_header='X-Forwarded-Port' if args.proxy_headers else None, ).run() diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 5b6158f..ce281d8 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -9,6 +9,7 @@ from six.moves.urllib_parse import unquote, unquote_plus from twisted.protocols.policies import ProtocolWrapper from twisted.web import http +from .utils import parse_x_forwarded_for from .ws_protocol import WebSocketProtocol, WebSocketFactory logger = logging.getLogger(__name__) @@ -67,6 +68,15 @@ class WebRequest(http.Request): else: self.client_addr = None self.server_addr = None + + if self.factory.proxy_forwarded_address_header: + self.client_addr = parse_x_forwarded_for( + self.requestHeaders, + self.factory.proxy_forwarded_address_header, + self.factory.proxy_forwarded_port_header, + self.client_addr + ) + # Check for unicodeish path (or it'll crash when trying to parse) try: self.path.decode("ascii") @@ -281,7 +291,7 @@ class HTTPFactory(http.HTTPFactory): protocol = HTTPProtocol - def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30): + def __init__(self, channel_layer, action_logger=None, timeout=120, websocket_timeout=86400, ping_interval=20, ping_timeout=30, ws_protocols=None, root_path="", websocket_connect_timeout=30, proxy_forwarded_address_header=None, proxy_forwarded_port_header=None): http.HTTPFactory.__init__(self) self.channel_layer = channel_layer self.action_logger = action_logger @@ -289,6 +299,8 @@ class HTTPFactory(http.HTTPFactory): self.websocket_timeout = websocket_timeout self.websocket_connect_timeout = websocket_connect_timeout self.ping_interval = ping_interval + self.proxy_forwarded_address_header = proxy_forwarded_address_header + self.proxy_forwarded_port_header = proxy_forwarded_port_header # We track all sub-protocols for response channel mapping self.reply_protocols = {} # Make a factory for WebSocket protocols diff --git a/daphne/server.py b/daphne/server.py index dcb98d9..4522181 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -26,6 +26,8 @@ class Server(object): ping_timeout=30, ws_protocols=None, root_path="", + proxy_forwarded_address_header=None, + proxy_forwarded_port_header=None, verbosity=1 ): self.channel_layer = channel_layer @@ -38,6 +40,8 @@ class Server(object): self.http_timeout = http_timeout self.ping_interval = ping_interval self.ping_timeout = ping_timeout + self.proxy_forwarded_address_header = proxy_forwarded_address_header + self.proxy_forwarded_port_header = proxy_forwarded_port_header # If they did not provide a websocket timeout, default it to the # channel layer's group_expiry value if present, or one day if not. self.websocket_timeout = websocket_timeout or getattr(channel_layer, "group_expiry", 86400) @@ -55,6 +59,8 @@ class Server(object): ping_timeout=self.ping_timeout, ws_protocols=self.ws_protocols, root_path=self.root_path, + proxy_forwarded_address_header=self.proxy_forwarded_address_header, + proxy_forwarded_port_header=self.proxy_forwarded_port_header ) if self.verbosity <= 1: # Redirect the Twisted log to nowhere diff --git a/daphne/tests/test_http.py b/daphne/tests/test_http.py index 48e1da2..de9bef3 100644 --- a/daphne/tests/test_http.py +++ b/daphne/tests/test_http.py @@ -93,3 +93,51 @@ class TestHTTPProtocol(TestCase): # Get the disconnection notification _, disconnect_message = self.channel_layer.receive_many(["http.disconnect"]) self.assertEqual(disconnect_message['path'], "/te st-à/") + + def test_x_forwarded_for_ignored(self): + """ + Tests basic HTTP parsing + """ + self.proto.dataReceived( + b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + + b"Host: somewhere.com\r\n" + + b"X-Forwarded-For: 10.1.2.3\r\n" + + b"X-Forwarded-Port: 80\r\n" + + b"\r\n" + ) + # Get the resulting message off of the channel layer + _, message = self.channel_layer.receive_many(["http.request"]) + self.assertEqual(message['client'], ['192.168.1.1', 54321]) + + def test_x_forwarded_for_parsed(self): + """ + Tests basic HTTP parsing + """ + self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' + self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' + self.proto.dataReceived( + b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + + b"Host: somewhere.com\r\n" + + b"X-Forwarded-For: 10.1.2.3\r\n" + + b"X-Forwarded-Port: 80\r\n" + + b"\r\n" + ) + # Get the resulting message off of the channel layer + _, message = self.channel_layer.receive_many(["http.request"]) + self.assertEqual(message['client'], ['10.1.2.3', 80]) + + def test_x_forwarded_for_port_missing(self): + """ + Tests basic HTTP parsing + """ + self.factory.proxy_forwarded_address_header = 'X-Forwarded-For' + self.factory.proxy_forwarded_port_header = 'X-Forwarded-Port' + self.proto.dataReceived( + b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" + + b"Host: somewhere.com\r\n" + + b"X-Forwarded-For: 10.1.2.3\r\n" + + b"\r\n" + ) + # Get the resulting message off of the channel layer + _, message = self.channel_layer.receive_many(["http.request"]) + self.assertEqual(message['client'], ['10.1.2.3', 0]) diff --git a/daphne/tests/test_utils.py b/daphne/tests/test_utils.py new file mode 100644 index 0000000..61fe1a3 --- /dev/null +++ b/daphne/tests/test_utils.py @@ -0,0 +1,61 @@ +# coding: utf8 +from __future__ import unicode_literals +from unittest import TestCase + +from twisted.web.http_headers import Headers + +from ..utils import parse_x_forwarded_for + + +class TestXForwardedForParsing(TestCase): + """ + Tests that the parse_x_forwarded_for util correcly parses headers. + """ + + def test_basic(self): + headers = Headers({ + b'X-Forwarded-For': [b'10.1.2.3'], + b'X-Forwarded-Port': [b'1234'] + }) + self.assertEqual( + parse_x_forwarded_for(headers), + ['10.1.2.3', 1234] + ) + + def test_address_only(self): + headers = Headers({ + b'X-Forwarded-For': [b'10.1.2.3'], + }) + self.assertEqual( + parse_x_forwarded_for(headers), + ['10.1.2.3', 0] + ) + + def test_port_in_address(self): + headers = Headers({ + b'X-Forwarded-For': [b'10.1.2.3:5123'], + }) + self.assertEqual( + parse_x_forwarded_for(headers), + ['10.1.2.3', 5123] + ) + + def test_multiple_proxys(self): + headers = Headers({ + b'X-Forwarded-For': [b'10.1.2.3, 10.1.2.4'], + }) + self.assertEqual( + parse_x_forwarded_for(headers), + ['10.1.2.4', 0] + ) + + def test_original(self): + headers = Headers({}) + self.assertEqual( + parse_x_forwarded_for(headers, original=['127.0.0.1', 80]), + ['127.0.0.1', 80] + ) + + def test_no_original(self): + headers = Headers({}) + self.assertIsNone(parse_x_forwarded_for(headers)) diff --git a/daphne/utils.py b/daphne/utils.py new file mode 100644 index 0000000..98e6d71 --- /dev/null +++ b/daphne/utils.py @@ -0,0 +1,48 @@ + + +def parse_x_forwarded_for(headers, + address_header_name='X-Forwarded-For', + port_header_name='X-Forwarded-Port', + original=None): + """ + Parses an X-Forwarded-For header and returns a host/port pair as a list. + + @param headers: The twisted-style object containing a request's headers + @param address_header_name: The name of the expected host header + @param port_header_name: The name of the expected port header + @param original: A host/port pair that should be returned if the headers are not in the request + @return: A list containing a host (string) as the first entry and a port (int) as the second. + """ + if not address_header_name: + return original + + address_header_name = address_header_name.lower().encode("utf-8") + result = original + if headers.hasHeader(address_header_name): + address_value = headers.getRawHeaders(address_header_name)[0].decode("utf-8") + + if ',' in address_value: + address_value = address_value.split(",")[-1].strip() + + if ':' in address_value: + address_host, address_port = address_value.split(':') + result = [address_host, 0] + try: + result[1] = int(address_port) + except ValueError: + pass + else: + result = [address_value, 0] + + if port_header_name: + # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For + # header to avoid inconsistent results. + port_header_name = port_header_name.lower().encode("utf-8") + if headers.hasHeader(port_header_name): + port_value = headers.getRawHeaders(port_header_name)[0].decode("utf-8") + try: + result[1] = int(port_value) + except ValueError: + pass + + return result diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 60f7ace..87e8e74 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -9,6 +9,8 @@ from twisted.internet import defer from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory, ConnectionDeny +from .utils import parse_x_forwarded_for + logger = logging.getLogger(__name__) @@ -54,6 +56,15 @@ class WebSocketProtocol(WebSocketServerProtocol): else: self.client_addr = None self.server_addr = None + + if self.factory.proxy_forwarded_address_header: + self.client_addr = parse_x_forwarded_for( + self.requestHeaders, + self.main_factory.proxy_forwarded_address_header, + self.main_factory.proxy_forwarded_port_header, + self.client_addr + ) + # Make initial request info dict from request (we only have it here) self.path = request.path.encode("ascii") self.request_info = {