diff --git a/daphne/cli.py b/daphne/cli.py index 1df2ebd..ffe6418 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -6,6 +6,7 @@ from .access import AccessLogGenerator from .endpoints import build_endpoint_description_strings from .server import Server from .utils import import_by_path +from argparse import ArgumentError, Namespace logger = logging.getLogger(__name__) @@ -165,6 +166,40 @@ class CommandLineInterface(object): """ cls().run(sys.argv[1:]) + def _check_proxy_headers_passed(self, argument: str, args: Namespace): + """Raise if the `--proxy-headers` weren't specified.""" + if args.proxy_headers: + return + raise ArgumentError( + argument=argument, + message="--proxy-headers has to be passed for this parameter.") + + def _get_forwarded_host(self, args: Namespace): + """ + Return the default host header from which the remote hostname/ip + will be extracted. + """ + if args.proxy_headers_host: + self._check_proxy_headers_passed( + argument=self.arg_proxy_host, args=args) + return args.proxy_headers_host + if args.proxy_headers: + return "X-Forwarded-For" + + def _get_forwarded_port(self, args: Namespace): + """ + Return the default host header from which the remote hostname/ip + will be extracted. + """ + if args.proxy_headers_port: + self._check_proxy_headers_passed( + argument=self.arg_proxy_port, args=args) + return args.proxy_headers_port + if args.proxy_headers: + return "X-Forwarded-Port" + + def run(self, args): + def run(self, args): """ Pass in raw argument list and it will decode them @@ -231,7 +266,7 @@ class CommandLineInterface(object): ws_protocols=args.ws_protocols, root_path=args.root_path, verbosity=args.verbosity, - proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None, - proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None, + proxy_forwarded_address_header=self._get_forwarded_host(args=args), + proxy_forwarded_port_header=self._get_forwarded_port(args=args), ) self.server.run()