Add command-line options for proxy headers

This commit is contained in:
László Károlyi 2018-10-26 21:34:15 +02:00 committed by Andrew Godwin
parent e93643ff5a
commit 20f2bc93d4
2 changed files with 103 additions and 6 deletions

View File

@ -1,6 +1,7 @@
import argparse import argparse
import logging import logging
import sys import sys
from argparse import ArgumentError, Namespace
from .access import AccessLogGenerator from .access import AccessLogGenerator
from .endpoints import build_endpoint_description_strings from .endpoints import build_endpoint_description_strings
@ -126,6 +127,25 @@ class CommandLineInterface(object):
default=False, default=False,
action="store_true", action="store_true",
) )
self.arg_proxy_host = self.parser.add_argument(
"--proxy-headers-host",
dest="proxy_headers_host",
help="Specify which header will be used for getting the host "
"part. Can be omitted, requires --proxy-headers to be specified "
'when passed. "X-Real-IP" (when passed by your webserver) is a '
"good candidate for this.",
default=False,
action="store",
)
self.arg_proxy_port = self.parser.add_argument(
"--proxy-headers-port",
dest="proxy_headers_port",
help="Specify which header will be used for getting the port "
"part. Can be omitted, requires --proxy-headers to be specified "
"when passed.",
default=False,
action="store",
)
self.parser.add_argument( self.parser.add_argument(
"application", "application",
help="The application to dispatch to as path.to.module:instance.path", help="The application to dispatch to as path.to.module:instance.path",
@ -140,6 +160,37 @@ class CommandLineInterface(object):
""" """
cls().run(sys.argv[1:]) cls().run(sys.argv[1:])
def _check_proxy_headers_passed(self, argument: str, args: Namespace):
"""Raise if the `--proxy-headers` weren't specified."""
if args.proxy_headers:
return
raise ArgumentError(
argument=argument,
message="--proxy-headers has to be passed for this parameter.",
)
def _get_forwarded_host(self, args: Namespace):
"""
Return the default host header from which the remote hostname/ip
will be extracted.
"""
if args.proxy_headers_host:
self._check_proxy_headers_passed(argument=self.arg_proxy_host, args=args)
return args.proxy_headers_host
if args.proxy_headers:
return "X-Forwarded-For"
def _get_forwarded_port(self, args: Namespace):
"""
Return the default host header from which the remote hostname/ip
will be extracted.
"""
if args.proxy_headers_port:
self._check_proxy_headers_passed(argument=self.arg_proxy_port, args=args)
return args.proxy_headers_port
if args.proxy_headers:
return "X-Forwarded-Port"
def run(self, args): def run(self, args):
""" """
Pass in raw argument list and it will decode them Pass in raw argument list and it will decode them
@ -211,12 +262,8 @@ class CommandLineInterface(object):
ws_protocols=args.ws_protocols, ws_protocols=args.ws_protocols,
root_path=args.root_path, root_path=args.root_path,
verbosity=args.verbosity, verbosity=args.verbosity,
proxy_forwarded_address_header="X-Forwarded-For" proxy_forwarded_address_header=self._get_forwarded_host(args=args),
if args.proxy_headers proxy_forwarded_port_header=self._get_forwarded_port(args=args),
else None,
proxy_forwarded_port_header="X-Forwarded-Port"
if args.proxy_headers
else None,
proxy_forwarded_proto_header="X-Forwarded-Proto" proxy_forwarded_proto_header="X-Forwarded-Proto"
if args.proxy_headers if args.proxy_headers
else None, else None,

View File

@ -1,6 +1,7 @@
# coding: utf8 # coding: utf8
import logging import logging
from argparse import ArgumentError
from unittest import TestCase from unittest import TestCase
from daphne.cli import CommandLineInterface from daphne.cli import CommandLineInterface
@ -192,3 +193,52 @@ class TestCLIInterface(TestCase):
Tests entirely custom endpoints Tests entirely custom endpoints
""" """
self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]}) self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]})
def test_default_proxyheaders(self):
"""
Passing `--proxy-headers` without a parameter will use the
`X-Forwarded-For` header.
"""
self.assertCLI(
["--proxy-headers"], {"proxy_forwarded_address_header": "X-Forwarded-For"}
)
def test_custom_proxyhost(self):
"""
Passing `--proxy-headers-host` will set the used host header to
the passed one, and `--proxy-headers` is mandatory.
"""
self.assertCLI(
["--proxy-headers", "--proxy-headers-host", "blah"],
{"proxy_forwarded_address_header": "blah"},
)
with self.assertRaises(expected_exception=ArgumentError) as exc:
self.assertCLI(
["--proxy-headers-host", "blah"],
{"proxy_forwarded_address_header": "blah"},
)
self.assertEqual(exc.exception.argument_name, "--proxy-headers-host")
self.assertEqual(
exc.exception.message,
"--proxy-headers has to be passed for this parameter.",
)
def test_custom_proxyport(self):
"""
Passing `--proxy-headers-port` will set the used port header to
the passed one, and `--proxy-headers` is mandatory.
"""
self.assertCLI(
["--proxy-headers", "--proxy-headers-port", "blah2"],
{"proxy_forwarded_port_header": "blah2"},
)
with self.assertRaises(expected_exception=ArgumentError) as exc:
self.assertCLI(
["--proxy-headers-port", "blah2"],
{"proxy_forwarded_address_header": "blah2"},
)
self.assertEqual(exc.exception.argument_name, "--proxy-headers-port")
self.assertEqual(
exc.exception.message,
"--proxy-headers has to be passed for this parameter.",
)