mirror of
https://github.com/django/daphne.git
synced 2025-04-20 00:32:09 +03:00
Merge a7ccae7025
into 84466d4ae4
This commit is contained in:
commit
c61b1c49dc
|
@ -1,6 +1,7 @@
|
|||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from argparse import ArgumentError, Namespace
|
||||
|
||||
from .access import AccessLogGenerator
|
||||
from .endpoints import build_endpoint_description_strings
|
||||
|
@ -132,6 +133,25 @@ class CommandLineInterface(object):
|
|||
default=False,
|
||||
action="store_true",
|
||||
)
|
||||
self.arg_proxy_host = self.parser.add_argument(
|
||||
"--proxy-headers-host",
|
||||
dest="proxy_headers_host",
|
||||
help="Specify which header will be used for getting the host "
|
||||
"part. Can be omitted, requires --proxy-headers to be specified "
|
||||
"when passed. \"X-Real-IP\" (when passed by your webserver) is a "
|
||||
"good candidate for this.",
|
||||
default=False,
|
||||
action="store",
|
||||
)
|
||||
self.arg_proxy_port = self.parser.add_argument(
|
||||
"--proxy-headers-port",
|
||||
dest="proxy_headers_port",
|
||||
help="Specify which header will be used for getting the port "
|
||||
"part. Can be omitted, requires --proxy-headers to be specified "
|
||||
"when passed.",
|
||||
default=False,
|
||||
action="store",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"application",
|
||||
help="The application to dispatch to as path.to.module:instance.path",
|
||||
|
@ -146,6 +166,38 @@ class CommandLineInterface(object):
|
|||
"""
|
||||
cls().run(sys.argv[1:])
|
||||
|
||||
def _check_proxy_headers_passed(self, argument: str, args: Namespace):
|
||||
"""Raise if the `--proxy-headers` weren't specified."""
|
||||
if args.proxy_headers:
|
||||
return
|
||||
raise ArgumentError(
|
||||
argument=argument,
|
||||
message="--proxy-headers has to be passed for this parameter.")
|
||||
|
||||
def _get_forwarded_host(self, args: Namespace):
|
||||
"""
|
||||
Return the default host header from which the remote hostname/ip
|
||||
will be extracted.
|
||||
"""
|
||||
if args.proxy_headers_host:
|
||||
self._check_proxy_headers_passed(
|
||||
argument=self.arg_proxy_host, args=args)
|
||||
return args.proxy_headers_host
|
||||
if args.proxy_headers:
|
||||
return "X-Forwarded-For"
|
||||
|
||||
def _get_forwarded_port(self, args: Namespace):
|
||||
"""
|
||||
Return the default host header from which the remote hostname/ip
|
||||
will be extracted.
|
||||
"""
|
||||
if args.proxy_headers_port:
|
||||
self._check_proxy_headers_passed(
|
||||
argument=self.arg_proxy_port, args=args)
|
||||
return args.proxy_headers_port
|
||||
if args.proxy_headers:
|
||||
return "X-Forwarded-Port"
|
||||
|
||||
def run(self, args):
|
||||
"""
|
||||
Pass in raw argument list and it will decode them
|
||||
|
@ -212,7 +264,7 @@ class CommandLineInterface(object):
|
|||
ws_protocols=args.ws_protocols,
|
||||
root_path=args.root_path,
|
||||
verbosity=args.verbosity,
|
||||
proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
|
||||
proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
|
||||
proxy_forwarded_address_header=self._get_forwarded_host(args=args),
|
||||
proxy_forwarded_port_header=self._get_forwarded_port(args=args),
|
||||
)
|
||||
self.server.run()
|
||||
|
|
|
@ -15,11 +15,11 @@ def import_by_path(path):
|
|||
return target
|
||||
|
||||
|
||||
def header_value(headers, header_name):
|
||||
def header_value(headers, header_name) -> str:
|
||||
value = headers[header_name]
|
||||
if isinstance(value, list):
|
||||
value = value[0]
|
||||
return value.decode("utf-8")
|
||||
return value.decode("utf-8") if type(value) is bytes else value
|
||||
|
||||
|
||||
def parse_x_forwarded_for(headers,
|
||||
|
@ -43,7 +43,12 @@ def parse_x_forwarded_for(headers,
|
|||
headers = dict(headers.getAllRawHeaders())
|
||||
|
||||
# Lowercase all header names in the dict
|
||||
headers = {name.lower(): values for name, values in headers.items()}
|
||||
new_headers = dict()
|
||||
for name, values in headers.items():
|
||||
name = name.lower()
|
||||
name = name if type(name) is bytes else name.encode("utf-8")
|
||||
new_headers[name] = values
|
||||
headers = new_headers
|
||||
|
||||
# Make sure header names are bytes (values are checked in header_value)
|
||||
assert all(isinstance(name, bytes) for name in headers.keys())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# coding: utf8
|
||||
|
||||
import logging
|
||||
from argparse import ArgumentError
|
||||
from unittest import TestCase
|
||||
|
||||
from daphne.cli import CommandLineInterface
|
||||
|
@ -235,3 +236,61 @@ class TestCLIInterface(TestCase):
|
|||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_default_proxyheaders(self):
|
||||
"""
|
||||
Passing `--proxy-headers` without a parameter will use the
|
||||
`X-Forwarded-For` header.
|
||||
"""
|
||||
self.assertCLI(
|
||||
["--proxy-headers"],
|
||||
{
|
||||
"proxy_forwarded_address_header": "X-Forwarded-For",
|
||||
},
|
||||
)
|
||||
|
||||
def test_custom_proxyhost(self):
|
||||
"""
|
||||
Passing `--proxy-headers-host` will set the used host header to
|
||||
the passed one, and `--proxy-headers` is mandatory.
|
||||
"""
|
||||
self.assertCLI(
|
||||
["--proxy-headers", "--proxy-headers-host", "blah"],
|
||||
{
|
||||
"proxy_forwarded_address_header": "blah",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(expected_exception=ArgumentError) as exc:
|
||||
self.assertCLI(
|
||||
["--proxy-headers-host", "blah"],
|
||||
{
|
||||
"proxy_forwarded_address_header": "blah",
|
||||
},
|
||||
)
|
||||
self.assertEqual(exc.exception.argument_name, "--proxy-headers-host")
|
||||
self.assertEqual(
|
||||
exc.exception.message,
|
||||
"--proxy-headers has to be passed for this parameter.")
|
||||
|
||||
def test_custom_proxyport(self):
|
||||
"""
|
||||
Passing `--proxy-headers-port` will set the used port header to
|
||||
the passed one, and `--proxy-headers` is mandatory.
|
||||
"""
|
||||
self.assertCLI(
|
||||
["--proxy-headers", "--proxy-headers-port", "blah2"],
|
||||
{
|
||||
"proxy_forwarded_port_header": "blah2",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(expected_exception=ArgumentError) as exc:
|
||||
self.assertCLI(
|
||||
["--proxy-headers-port", "blah2"],
|
||||
{
|
||||
"proxy_forwarded_address_header": "blah2",
|
||||
},
|
||||
)
|
||||
self.assertEqual(exc.exception.argument_name, "--proxy-headers-port")
|
||||
self.assertEqual(
|
||||
exc.exception.message,
|
||||
"--proxy-headers has to be passed for this parameter.")
|
||||
|
|
|
@ -30,7 +30,7 @@ class TestXForwardedForHttpParsing(TestCase):
|
|||
["10.1.2.3", 0]
|
||||
)
|
||||
|
||||
def test_v6_address(self):
|
||||
def test_v6_address_1(self):
|
||||
headers = Headers({
|
||||
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
|
||||
})
|
||||
|
@ -84,7 +84,17 @@ class TestXForwardedForWsParsing(TestCase):
|
|||
["10.1.2.3", 0]
|
||||
)
|
||||
|
||||
def test_v6_address(self):
|
||||
def test_non_bytes_header(self):
|
||||
"""The passed headers can be non-bytes too."""
|
||||
headers = {
|
||||
"X-Forwarded-For": "10.1.2.3",
|
||||
}
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
["10.1.2.3", 0]
|
||||
)
|
||||
|
||||
def test_v6_address_2(self):
|
||||
headers = {
|
||||
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user