Handle both dicts and twisted Headers (#84)

Fix #78
This commit is contained in:
NotSqrt 2017-02-15 03:15:00 +01:00 committed by Andrew Godwin
parent 412d9a48dc
commit 80bacf1ea1
2 changed files with 65 additions and 4 deletions

View File

@ -7,9 +7,9 @@ from twisted.web.http_headers import Headers
from ..utils import parse_x_forwarded_for from ..utils import parse_x_forwarded_for
class TestXForwardedForParsing(TestCase): class TestXForwardedForHttpParsing(TestCase):
""" """
Tests that the parse_x_forwarded_for util correcly parses headers. Tests that the parse_x_forwarded_for util correctly parses twisted Header.
""" """
def test_basic(self): def test_basic(self):
@ -59,3 +59,57 @@ class TestXForwardedForParsing(TestCase):
def test_no_original(self): def test_no_original(self):
headers = Headers({}) headers = Headers({})
self.assertIsNone(parse_x_forwarded_for(headers)) self.assertIsNone(parse_x_forwarded_for(headers))
class TestXForwardedForWsParsing(TestCase):
"""
Tests that the parse_x_forwarded_for util correctly parses dict headers.
"""
def test_basic(self):
headers = {
b'X-Forwarded-For': b'10.1.2.3',
b'X-Forwarded-Port': b'1234',
}
self.assertEqual(
parse_x_forwarded_for(headers),
['10.1.2.3', 1234]
)
def test_address_only(self):
headers = {
b'X-Forwarded-For': b'10.1.2.3',
}
self.assertEqual(
parse_x_forwarded_for(headers),
['10.1.2.3', 0]
)
def test_port_in_address(self):
headers = {
b'X-Forwarded-For': b'10.1.2.3:5123',
}
self.assertEqual(
parse_x_forwarded_for(headers),
['10.1.2.3', 5123]
)
def test_multiple_proxys(self):
headers = {
b'X-Forwarded-For': b'10.1.2.3, 10.1.2.4',
}
self.assertEqual(
parse_x_forwarded_for(headers),
['10.1.2.4', 0]
)
def test_original(self):
headers = {}
self.assertEqual(
parse_x_forwarded_for(headers, original=['127.0.0.1', 80]),
['127.0.0.1', 80]
)
def test_no_original(self):
headers = {}
self.assertIsNone(parse_x_forwarded_for(headers))

View File

@ -1,6 +1,13 @@
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
def header_value(headers, header_name):
value = headers[header_name]
if isinstance(value, list):
value = value[0]
return value.decode("utf-8")
def parse_x_forwarded_for(headers, def parse_x_forwarded_for(headers,
address_header_name='X-Forwarded-For', address_header_name='X-Forwarded-For',
port_header_name='X-Forwarded-Port', port_header_name='X-Forwarded-Port',
@ -27,7 +34,7 @@ def parse_x_forwarded_for(headers,
address_header_name = address_header_name.lower().encode("utf-8") address_header_name = address_header_name.lower().encode("utf-8")
result = original result = original
if address_header_name in headers: if address_header_name in headers:
address_value = headers[address_header_name][0].decode("utf-8") address_value = header_value(headers, address_header_name)
if ',' in address_value: if ',' in address_value:
address_value = address_value.split(",")[-1].strip() address_value = address_value.split(",")[-1].strip()
@ -47,7 +54,7 @@ def parse_x_forwarded_for(headers,
# header to avoid inconsistent results. # header to avoid inconsistent results.
port_header_name = port_header_name.lower().encode("utf-8") port_header_name = port_header_name.lower().encode("utf-8")
if port_header_name in headers: if port_header_name in headers:
port_value = headers[port_header_name][0].decode("utf-8") port_value = header_value(headers, port_header_name)
try: try:
result[1] = int(port_value) result[1] = int(port_value)
except ValueError: except ValueError: