mirror of
https://github.com/django/daphne.git
synced 2024-11-22 07:56:34 +03:00
parent
412d9a48dc
commit
80bacf1ea1
|
@ -7,9 +7,9 @@ from twisted.web.http_headers import Headers
|
||||||
from ..utils import parse_x_forwarded_for
|
from ..utils import parse_x_forwarded_for
|
||||||
|
|
||||||
|
|
||||||
class TestXForwardedForParsing(TestCase):
|
class TestXForwardedForHttpParsing(TestCase):
|
||||||
"""
|
"""
|
||||||
Tests that the parse_x_forwarded_for util correcly parses headers.
|
Tests that the parse_x_forwarded_for util correctly parses twisted Header.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
@ -59,3 +59,57 @@ class TestXForwardedForParsing(TestCase):
|
||||||
def test_no_original(self):
|
def test_no_original(self):
|
||||||
headers = Headers({})
|
headers = Headers({})
|
||||||
self.assertIsNone(parse_x_forwarded_for(headers))
|
self.assertIsNone(parse_x_forwarded_for(headers))
|
||||||
|
|
||||||
|
|
||||||
|
class TestXForwardedForWsParsing(TestCase):
|
||||||
|
"""
|
||||||
|
Tests that the parse_x_forwarded_for util correctly parses dict headers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
headers = {
|
||||||
|
b'X-Forwarded-For': b'10.1.2.3',
|
||||||
|
b'X-Forwarded-Port': b'1234',
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.3', 1234]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_address_only(self):
|
||||||
|
headers = {
|
||||||
|
b'X-Forwarded-For': b'10.1.2.3',
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.3', 0]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_port_in_address(self):
|
||||||
|
headers = {
|
||||||
|
b'X-Forwarded-For': b'10.1.2.3:5123',
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.3', 5123]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multiple_proxys(self):
|
||||||
|
headers = {
|
||||||
|
b'X-Forwarded-For': b'10.1.2.3, 10.1.2.4',
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers),
|
||||||
|
['10.1.2.4', 0]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_original(self):
|
||||||
|
headers = {}
|
||||||
|
self.assertEqual(
|
||||||
|
parse_x_forwarded_for(headers, original=['127.0.0.1', 80]),
|
||||||
|
['127.0.0.1', 80]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_original(self):
|
||||||
|
headers = {}
|
||||||
|
self.assertIsNone(parse_x_forwarded_for(headers))
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
|
|
||||||
|
def header_value(headers, header_name):
|
||||||
|
value = headers[header_name]
|
||||||
|
if isinstance(value, list):
|
||||||
|
value = value[0]
|
||||||
|
return value.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def parse_x_forwarded_for(headers,
|
def parse_x_forwarded_for(headers,
|
||||||
address_header_name='X-Forwarded-For',
|
address_header_name='X-Forwarded-For',
|
||||||
port_header_name='X-Forwarded-Port',
|
port_header_name='X-Forwarded-Port',
|
||||||
|
@ -27,7 +34,7 @@ def parse_x_forwarded_for(headers,
|
||||||
address_header_name = address_header_name.lower().encode("utf-8")
|
address_header_name = address_header_name.lower().encode("utf-8")
|
||||||
result = original
|
result = original
|
||||||
if address_header_name in headers:
|
if address_header_name in headers:
|
||||||
address_value = headers[address_header_name][0].decode("utf-8")
|
address_value = header_value(headers, address_header_name)
|
||||||
|
|
||||||
if ',' in address_value:
|
if ',' in address_value:
|
||||||
address_value = address_value.split(",")[-1].strip()
|
address_value = address_value.split(",")[-1].strip()
|
||||||
|
@ -47,7 +54,7 @@ def parse_x_forwarded_for(headers,
|
||||||
# header to avoid inconsistent results.
|
# header to avoid inconsistent results.
|
||||||
port_header_name = port_header_name.lower().encode("utf-8")
|
port_header_name = port_header_name.lower().encode("utf-8")
|
||||||
if port_header_name in headers:
|
if port_header_name in headers:
|
||||||
port_value = headers[port_header_name][0].decode("utf-8")
|
port_value = header_value(headers, port_header_name)
|
||||||
try:
|
try:
|
||||||
result[1] = int(port_value)
|
result[1] = int(port_value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user