demo: parse X-Forwarded-Proto header

This commit is contained in:
Tilmann Becker 2017-08-23 17:01:15 +02:00
parent 3161715238
commit 9ef69b99b2
2 changed files with 16 additions and 3 deletions

View File

@ -81,12 +81,18 @@ class WebRequest(http.Request):
self.server_addr = None
if self.factory.proxy_forwarded_address_header:
self.client_addr = parse_x_forwarded_for(
client, proto = parse_x_forwarded_for(
self.requestHeaders,
self.factory.proxy_forwarded_address_header,
self.factory.proxy_forwarded_port_header,
self.client_addr
original=self.client_addr
)
self.client_addr = client
if proto:
# If we get a proto header, force ssl on or off (affects self.isSecure)
self.setHost(host=self.host.host.encode('utf-8'),
port=self.host.port,
ssl=True if proto == 'https' else False)
# Check for unicodeish path (or it'll crash when trying to parse)
try:

View File

@ -11,6 +11,7 @@ def header_value(headers, header_name):
def parse_x_forwarded_for(headers,
address_header_name='X-Forwarded-For',
port_header_name='X-Forwarded-Port',
proto_header_name='X-Forwarded-Proto',
original=None):
"""
Parses an X-Forwarded-For header and returns a host/port pair as a list.
@ -18,6 +19,7 @@ def parse_x_forwarded_for(headers,
@param headers: The twisted-style object containing a request's headers
@param address_header_name: The name of the expected host header
@param port_header_name: The name of the expected port header
@param proto_header_name: The name of the expected protocol header
@param original: A host/port pair that should be returned if the headers are not in the request
@return: A list containing a host (string) as the first entry and a port (int) as the second.
"""
@ -33,6 +35,7 @@ def parse_x_forwarded_for(headers,
address_header_name = address_header_name.lower().encode("utf-8")
result = original
proto_value = None
if address_header_name in headers:
address_value = header_value(headers, address_header_name)
@ -52,4 +55,8 @@ def parse_x_forwarded_for(headers,
except ValueError:
pass
return result
proto_header_name = proto_header_name.lower().encode("utf-8")
if proto_header_name in headers:
proto_value = header_value(headers, proto_header_name)
return result, proto_value