diff --git a/daphne/cli.py b/daphne/cli.py index 28cf1b9..dc715a6 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -1,6 +1,7 @@ import argparse import logging import sys +from argparse import ArgumentError, Namespace from .access import AccessLogGenerator from .endpoints import build_endpoint_description_strings @@ -126,6 +127,25 @@ class CommandLineInterface(object): default=False, action="store_true", ) + self.arg_proxy_host = self.parser.add_argument( + "--proxy-headers-host", + dest="proxy_headers_host", + help="Specify which header will be used for getting the host " + "part. Can be omitted, requires --proxy-headers to be specified " + 'when passed. "X-Real-IP" (when passed by your webserver) is a ' + "good candidate for this.", + default=False, + action="store", + ) + self.arg_proxy_port = self.parser.add_argument( + "--proxy-headers-port", + dest="proxy_headers_port", + help="Specify which header will be used for getting the port " + "part. Can be omitted, requires --proxy-headers to be specified " + "when passed.", + default=False, + action="store", + ) self.parser.add_argument( "application", help="The application to dispatch to as path.to.module:instance.path", @@ -140,6 +160,37 @@ class CommandLineInterface(object): """ cls().run(sys.argv[1:]) + def _check_proxy_headers_passed(self, argument: str, args: Namespace): + """Raise if the `--proxy-headers` weren't specified.""" + if args.proxy_headers: + return + raise ArgumentError( + argument=argument, + message="--proxy-headers has to be passed for this parameter.", + ) + + def _get_forwarded_host(self, args: Namespace): + """ + Return the default host header from which the remote hostname/ip + will be extracted. + """ + if args.proxy_headers_host: + self._check_proxy_headers_passed(argument=self.arg_proxy_host, args=args) + return args.proxy_headers_host + if args.proxy_headers: + return "X-Forwarded-For" + + def _get_forwarded_port(self, args: Namespace): + """ + Return the default host header from which the remote hostname/ip + will be extracted. + """ + if args.proxy_headers_port: + self._check_proxy_headers_passed(argument=self.arg_proxy_port, args=args) + return args.proxy_headers_port + if args.proxy_headers: + return "X-Forwarded-Port" + def run(self, args): """ Pass in raw argument list and it will decode them @@ -211,12 +262,8 @@ class CommandLineInterface(object): ws_protocols=args.ws_protocols, root_path=args.root_path, verbosity=args.verbosity, - proxy_forwarded_address_header="X-Forwarded-For" - if args.proxy_headers - else None, - proxy_forwarded_port_header="X-Forwarded-Port" - if args.proxy_headers - else None, + proxy_forwarded_address_header=self._get_forwarded_host(args=args), + proxy_forwarded_port_header=self._get_forwarded_port(args=args), proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None, diff --git a/tests/test_cli.py b/tests/test_cli.py index 7bb45dc..2bbcc42 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ # coding: utf8 import logging +from argparse import ArgumentError from unittest import TestCase from daphne.cli import CommandLineInterface @@ -192,3 +193,52 @@ class TestCLIInterface(TestCase): Tests entirely custom endpoints """ self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]}) + + def test_default_proxyheaders(self): + """ + Passing `--proxy-headers` without a parameter will use the + `X-Forwarded-For` header. + """ + self.assertCLI( + ["--proxy-headers"], {"proxy_forwarded_address_header": "X-Forwarded-For"} + ) + + def test_custom_proxyhost(self): + """ + Passing `--proxy-headers-host` will set the used host header to + the passed one, and `--proxy-headers` is mandatory. + """ + self.assertCLI( + ["--proxy-headers", "--proxy-headers-host", "blah"], + {"proxy_forwarded_address_header": "blah"}, + ) + with self.assertRaises(expected_exception=ArgumentError) as exc: + self.assertCLI( + ["--proxy-headers-host", "blah"], + {"proxy_forwarded_address_header": "blah"}, + ) + self.assertEqual(exc.exception.argument_name, "--proxy-headers-host") + self.assertEqual( + exc.exception.message, + "--proxy-headers has to be passed for this parameter.", + ) + + def test_custom_proxyport(self): + """ + Passing `--proxy-headers-port` will set the used port header to + the passed one, and `--proxy-headers` is mandatory. + """ + self.assertCLI( + ["--proxy-headers", "--proxy-headers-port", "blah2"], + {"proxy_forwarded_port_header": "blah2"}, + ) + with self.assertRaises(expected_exception=ArgumentError) as exc: + self.assertCLI( + ["--proxy-headers-port", "blah2"], + {"proxy_forwarded_address_header": "blah2"}, + ) + self.assertEqual(exc.exception.argument_name, "--proxy-headers-port") + self.assertEqual( + exc.exception.message, + "--proxy-headers has to be passed for this parameter.", + )